diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 0ca2f9bd82846a1aa327af68693868eab45af6e3..9c3bfc5cb52721f48b57e94d1ae2a9789a0618dd 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -3948,8 +3948,7 @@ int kvm_tdp_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
 				 max_level, true);
 }
 
-static void nonpaging_init_context(struct kvm_vcpu *vcpu,
-				   struct kvm_mmu *context)
+static void nonpaging_init_context(struct kvm_mmu *context)
 {
 	context->page_fault = nonpaging_page_fault;
 	context->gva_to_gpa = nonpaging_gva_to_gpa;
@@ -4513,14 +4512,13 @@ static void update_last_nonleaf_level(struct kvm_mmu *mmu)
 		mmu->last_nonleaf_level++;
 }
 
-static void paging64_init_context_common(struct kvm_vcpu *vcpu,
-					 struct kvm_mmu *context,
+static void paging64_init_context_common(struct kvm_mmu *context,
 					 int root_level)
 {
-	context->nx = is_nx(vcpu);
+	context->nx = is_efer_nx(context);
 	context->root_level = root_level;
 
-	MMU_WARN_ON(!is_pae(vcpu));
+	WARN_ON_ONCE(!is_cr4_pae(context));
 	context->page_fault = paging64_page_fault;
 	context->gva_to_gpa = paging64_gva_to_gpa;
 	context->sync_page = paging64_sync_page;
@@ -4528,17 +4526,16 @@ static void paging64_init_context_common(struct kvm_vcpu *vcpu,
 	context->direct_map = false;
 }
 
-static void paging64_init_context(struct kvm_vcpu *vcpu,
-				  struct kvm_mmu *context)
+static void paging64_init_context(struct kvm_mmu *context,
+				  struct kvm_mmu_role_regs *regs)
 {
-	int root_level = is_la57_mode(vcpu) ?
-			 PT64_ROOT_5LEVEL : PT64_ROOT_4LEVEL;
+	int root_level = ____is_cr4_la57(regs) ? PT64_ROOT_5LEVEL :
+						 PT64_ROOT_4LEVEL;
 
-	paging64_init_context_common(vcpu, context, root_level);
+	paging64_init_context_common(context, root_level);
 }
 
-static void paging32_init_context(struct kvm_vcpu *vcpu,
-				  struct kvm_mmu *context)
+static void paging32_init_context(struct kvm_mmu *context)
 {
 	context->nx = false;
 	context->root_level = PT32_ROOT_LEVEL;
@@ -4549,10 +4546,9 @@ static void paging32_init_context(struct kvm_vcpu *vcpu,
 	context->direct_map = false;
 }
 
-static void paging32E_init_context(struct kvm_vcpu *vcpu,
-				   struct kvm_mmu *context)
+static void paging32E_init_context(struct kvm_mmu *context)
 {
-	paging64_init_context_common(vcpu, context, PT32E_ROOT_LEVEL);
+	paging64_init_context_common(context, PT32E_ROOT_LEVEL);
 }
 
 static union kvm_mmu_extended_role kvm_calc_mmu_role_ext(struct kvm_vcpu *vcpu,
@@ -4712,13 +4708,13 @@ static void shadow_mmu_init_context(struct kvm_vcpu *vcpu, struct kvm_mmu *conte
 	context->mmu_role.as_u64 = new_role.as_u64;
 
 	if (!____is_cr0_pg(regs))
-		nonpaging_init_context(vcpu, context);
+		nonpaging_init_context(context);
 	else if (____is_efer_lma(regs))
-		paging64_init_context(vcpu, context);
+		paging64_init_context(context, regs);
 	else if (____is_cr4_pae(regs))
-		paging32E_init_context(vcpu, context);
+		paging32E_init_context(context);
 	else
-		paging32_init_context(vcpu, context);
+		paging32_init_context(context);
 
 	if (____is_cr0_pg(regs)) {
 		reset_rsvds_bits_mask(vcpu, context);