diff --git a/block/kyber-iosched.c b/block/kyber-iosched.c
index b4df317c291692f01138b91608dc6c80f71bb9aa..f95c60774ce8ca613417d3ccf54bee52010752ee 100644
--- a/block/kyber-iosched.c
+++ b/block/kyber-iosched.c
@@ -100,9 +100,13 @@ struct kyber_hctx_data {
 	unsigned int cur_domain;
 	unsigned int batching;
 	wait_queue_entry_t domain_wait[KYBER_NUM_DOMAINS];
+	struct sbq_wait_state *domain_ws[KYBER_NUM_DOMAINS];
 	atomic_t wait_index[KYBER_NUM_DOMAINS];
 };
 
+static int kyber_domain_wake(wait_queue_entry_t *wait, unsigned mode, int flags,
+			     void *key);
+
 static int rq_sched_domain(const struct request *rq)
 {
 	unsigned int op = rq->cmd_flags;
@@ -385,6 +389,9 @@ static int kyber_init_hctx(struct blk_mq_hw_ctx *hctx, unsigned int hctx_idx)
 
 	for (i = 0; i < KYBER_NUM_DOMAINS; i++) {
 		INIT_LIST_HEAD(&khd->rqs[i]);
+		init_waitqueue_func_entry(&khd->domain_wait[i],
+					  kyber_domain_wake);
+		khd->domain_wait[i].private = hctx;
 		INIT_LIST_HEAD(&khd->domain_wait[i].entry);
 		atomic_set(&khd->wait_index[i], 0);
 	}
@@ -524,35 +531,39 @@ static int kyber_get_domain_token(struct kyber_queue_data *kqd,
 	int nr;
 
 	nr = __sbitmap_queue_get(domain_tokens);
-	if (nr >= 0)
-		return nr;
 
 	/*
 	 * If we failed to get a domain token, make sure the hardware queue is
 	 * run when one becomes available. Note that this is serialized on
 	 * khd->lock, but we still need to be careful about the waker.
 	 */
-	if (list_empty_careful(&wait->entry)) {
-		init_waitqueue_func_entry(wait, kyber_domain_wake);
-		wait->private = hctx;
+	if (nr < 0 && list_empty_careful(&wait->entry)) {
 		ws = sbq_wait_ptr(domain_tokens,
 				  &khd->wait_index[sched_domain]);
+		khd->domain_ws[sched_domain] = ws;
 		add_wait_queue(&ws->wait, wait);
 
 		/*
 		 * Try again in case a token was freed before we got on the wait
-		 * queue. The waker may have already removed the entry from the
-		 * wait queue, but list_del_init() is okay with that.
+		 * queue.
 		 */
 		nr = __sbitmap_queue_get(domain_tokens);
-		if (nr >= 0) {
-			unsigned long flags;
+	}
 
-			spin_lock_irqsave(&ws->wait.lock, flags);
-			list_del_init(&wait->entry);
-			spin_unlock_irqrestore(&ws->wait.lock, flags);
-		}
+	/*
+	 * If we got a token while we were on the wait queue, remove ourselves
+	 * from the wait queue to ensure that all wake ups make forward
+	 * progress. It's possible that the waker already deleted the entry
+	 * between the !list_empty_careful() check and us grabbing the lock, but
+	 * list_del_init() is okay with that.
+	 */
+	if (nr >= 0 && !list_empty_careful(&wait->entry)) {
+		ws = khd->domain_ws[sched_domain];
+		spin_lock_irq(&ws->wait.lock);
+		list_del_init(&wait->entry);
+		spin_unlock_irq(&ws->wait.lock);
 	}
+
 	return nr;
 }