diff --git a/arch/x86/kvm/mmu/tdp_iter.c b/arch/x86/kvm/mmu/tdp_iter.c
index b3ed302c1a359fbe5b46b9aad8f6c0bcc5e7944d..caa96c270b9541bd5769a976bbe9f41227c37b0e 100644
--- a/arch/x86/kvm/mmu/tdp_iter.c
+++ b/arch/x86/kvm/mmu/tdp_iter.c
@@ -26,6 +26,7 @@ static gfn_t round_gfn_for_level(gfn_t gfn, int level)
  */
 void tdp_iter_restart(struct tdp_iter *iter)
 {
+	iter->yielded = false;
 	iter->yielded_gfn = iter->next_last_level_gfn;
 	iter->level = iter->root_level;
 
@@ -160,6 +161,11 @@ static bool try_step_up(struct tdp_iter *iter)
  */
 void tdp_iter_next(struct tdp_iter *iter)
 {
+	if (iter->yielded) {
+		tdp_iter_restart(iter);
+		return;
+	}
+
 	if (try_step_down(iter))
 		return;
 
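
Deferring the restart to tdp_iter_next() is safe because every TDP MMU walk
advances through it: the for_each_tdp_pte* macros in tdp_iter.h use
tdp_iter_next() as the loop's increment step. For reference, the min_level
variant of the macro (unchanged by this patch) looks roughly like:

	#define for_each_tdp_pte_min_level(iter, root, root_level, min_level, start, end) \
		for (tdp_iter_start(&iter, root, root_level, min_level, start);	\
		     iter.valid && iter.gfn < end;					\
		     tdp_iter_next(&iter))
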
diff --git a/arch/x86/kvm/mmu/tdp_iter.h b/arch/x86/kvm/mmu/tdp_iter.h
index b1748b988d3aef556087b26ffad29accbd2756ff..e19cabbcb65c846e6065406c336848fb38a959fc 100644
--- a/arch/x86/kvm/mmu/tdp_iter.h
+++ b/arch/x86/kvm/mmu/tdp_iter.h
@@ -45,6 +45,12 @@ struct tdp_iter {
 	 * iterator walks off the end of the paging structure.
 	 */
 	bool valid;
+	/*
+	 * True if KVM dropped mmu_lock and yielded in the middle of a walk, in
+	 * which case tdp_iter_next() needs to restart the walk at the root
+	 * level instead of advancing to the next entry.
+	 */
+	bool yielded;
 };
 
 /*
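
The new flag closes a window where a walk could skip entries: previously,
tdp_mmu_iter_cond_resched() restarted the walk immediately, and the caller's
"continue" then ran the loop's tdp_iter_next(), advancing the iterator past
the top-level SPTE of the restarted walk and everything beneath it. A sketch
of the pattern (hypothetical caller, arguments elided):

	for_each_tdp_pte(iter, ...) {
		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
			continue;	/*
					 * Old: iter was already restarted, so
					 * tdp_iter_next() stepped over the
					 * first entry of the new walk.
					 * New: iter is untouched until
					 * tdp_iter_next() does the restart.
					 */
		/* ... process iter.old_spte ... */
	}
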
diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c
index 1db8496259add5411626a99311d4ed6b372aed5d..1beb4ca9056092cca031b9ebd65c12112852bdd8 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.c
+++ b/arch/x86/kvm/mmu/tdp_mmu.c
@@ -502,6 +502,8 @@ static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
 					   struct tdp_iter *iter,
 					   u64 new_spte)
 {
+	WARN_ON_ONCE(iter->yielded);
+
 	lockdep_assert_held_read(&kvm->mmu_lock);
 
 	/*
@@ -575,6 +577,8 @@ static inline void __tdp_mmu_set_spte(struct kvm *kvm, struct tdp_iter *iter,
 				      u64 new_spte, bool record_acc_track,
 				      bool record_dirty_log)
 {
+	WARN_ON_ONCE(iter->yielded);
+
 	lockdep_assert_held_write(&kvm->mmu_lock);
 
 	/*
@@ -640,18 +644,19 @@ static inline void tdp_mmu_set_spte_no_dirty_log(struct kvm *kvm,
  * If this function should yield and flush is set, it will perform a remote
  * TLB flush before yielding.
  *
- * If this function yields, it will also reset the tdp_iter's walk over the
- * paging structure and the calling function should skip to the next
- * iteration to allow the iterator to continue its traversal from the
- * paging structure root.
+ * If this function yields, iter->yielded is set and the caller must skip to
+ * the next iteration, where tdp_iter_next() will restart the walk over the
+ * paging structure from the root so that the iterator can complete its
+ * traversal.
  *
- * Return true if this function yielded and the iterator's traversal was reset.
- * Return false if a yield was not needed.
+ * Returns true if this function yielded.
  */
-static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
-					     struct tdp_iter *iter, bool flush,
-					     bool shared)
+static inline bool __must_check tdp_mmu_iter_cond_resched(struct kvm *kvm,
+							  struct tdp_iter *iter,
+							  bool flush, bool shared)
 {
+	WARN_ON(iter->yielded);
+
 	/* Ensure forward progress has been made before yielding. */
 	if (iter->next_last_level_gfn == iter->yielded_gfn)
 		return false;
@@ -671,12 +676,10 @@ static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
 
 		WARN_ON(iter->gfn > iter->next_last_level_gfn);
 
-		tdp_iter_restart(iter);
-
-		return true;
+		iter->yielded = true;
 	}
 
-	return false;
+	return iter->yielded;
 }
 
 /*
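
Under the new contract, a walker checks the resched helper at the top of its
loop body and simply continues on yield; the restart then happens inside
tdp_iter_next() in the loop's increment step. A minimal sketch modeled on the
zap loops in tdp_mmu.c (flag names illustrative):

	for_each_tdp_pte_min_level(iter, root, root_level, min_level, start, end) {
		if (can_yield &&
		    tdp_mmu_iter_cond_resched(kvm, &iter, flush, shared)) {
			flush = false;	/* the yield path did the remote flush */
			continue;
		}
		/* ... read iter.old_spte, update via tdp_mmu_set_spte*() ... */
	}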