diff --git a/include/net/sock.h b/include/net/sock.h
index c005c3c750e89b6d4d36d36ccfad392a7e733603..dc3f8169312ef92f43ef54688029d1396f99347e 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1623,7 +1623,36 @@ void release_sock(struct sock *sk);
 				SINGLE_DEPTH_NESTING)
 #define bh_unlock_sock(__sk)	spin_unlock(&((__sk)->sk_lock.slock))
 
-bool lock_sock_fast(struct sock *sk) __acquires(&sk->sk_lock.slock);
+bool __lock_sock_fast(struct sock *sk) __acquires(&sk->sk_lock.slock);
+
+/**
+ * lock_sock_fast - fast version of lock_sock
+ * @sk: socket
+ *
+ * This version should be used for very small section, where process wont block
+ * return false if fast path is taken:
+ *
+ *   sk_lock.slock locked, owned = 0, BH disabled
+ *
+ * return true if slow path is taken:
+ *
+ *   sk_lock.slock unlocked, owned = 1, BH enabled
+ */
+static inline bool lock_sock_fast(struct sock *sk)
+{
+	/* The sk_lock has mutex_lock() semantics here. */
+	mutex_acquire(&sk->sk_lock.dep_map, 0, 0, _RET_IP_);
+
+	return __lock_sock_fast(sk);
+}
+
+/* fast socket lock variant for caller already holding a [different] socket lock */
+static inline bool lock_sock_fast_nested(struct sock *sk)
+{
+	mutex_acquire(&sk->sk_lock.dep_map, SINGLE_DEPTH_NESTING, 0, _RET_IP_);
+
+	return __lock_sock_fast(sk);
+}
 
 /**
  * unlock_sock_fast - complement of lock_sock_fast
diff --git a/net/core/sock.c b/net/core/sock.c
index 512e629f97803f3216db8f054e97fd1b7a6e6c63..7060d183216e0133377e77fd9f551deeb3f55888 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -3210,24 +3210,8 @@ void release_sock(struct sock *sk)
 }
 EXPORT_SYMBOL(release_sock);
 
-/**
- * lock_sock_fast - fast version of lock_sock
- * @sk: socket
- *
- * This version should be used for very small section, where process wont block
- * return false if fast path is taken:
- *
- *   sk_lock.slock locked, owned = 0, BH disabled
- *
- * return true if slow path is taken:
- *
- *   sk_lock.slock unlocked, owned = 1, BH enabled
- */
-bool lock_sock_fast(struct sock *sk) __acquires(&sk->sk_lock.slock)
+bool __lock_sock_fast(struct sock *sk) __acquires(&sk->sk_lock.slock)
 {
-	/* The sk_lock has mutex_lock() semantics here. */
-	mutex_acquire(&sk->sk_lock.dep_map, 0, 0, _RET_IP_);
-
 	might_sleep();
 	spin_lock_bh(&sk->sk_lock.slock);
 
@@ -3256,7 +3240,7 @@ bool lock_sock_fast(struct sock *sk) __acquires(&sk->sk_lock.slock)
 	spin_unlock_bh(&sk->sk_lock.slock);
 	return true;
 }
-EXPORT_SYMBOL(lock_sock_fast);
+EXPORT_SYMBOL(__lock_sock_fast);
 
 int sock_gettstamp(struct socket *sock, void __user *userstamp,
 		   bool timeval, bool time32)
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index dbcebf56798fa80a87eb8abd89b35fa13ed85d60..e5df0b5971c86adbeaae5fc8ad86c0e8ca375ea2 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -2735,7 +2735,7 @@ static void mptcp_close(struct sock *sk, long timeout)
 	inet_csk(sk)->icsk_mtup.probe_timestamp = tcp_jiffies32;
 	mptcp_for_each_subflow(mptcp_sk(sk), subflow) {
 		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
-		bool slow = lock_sock_fast(ssk);
+		bool slow = lock_sock_fast_nested(ssk);
 
 		sock_orphan(ssk);
 		unlock_sock_fast(ssk, slow);