diff --git a/fs/dax.c b/fs/dax.c
index 5b84a46201c2a6d6df977a8d5142f89b9ae5a253..d5f6aca5a4d79850b8f605bec697fa36e8acda4d 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -24,6 +24,7 @@
 #include <linux/memcontrol.h>
 #include <linux/mm.h>
 #include <linux/mutex.h>
+#include <linux/pagevec.h>
 #include <linux/pmem.h>
 #include <linux/sched.h>
 #include <linux/uio.h>
@@ -324,6 +325,199 @@ static int copy_user_bh(struct page *to, struct inode *inode,
 	return 0;
 }
 
+#define NO_SECTOR -1
+#define DAX_PMD_INDEX(page_index) (page_index & (PMD_MASK >> PAGE_CACHE_SHIFT))
+
+static int dax_radix_entry(struct address_space *mapping, pgoff_t index,
+		sector_t sector, bool pmd_entry, bool dirty)
+{
+	struct radix_tree_root *page_tree = &mapping->page_tree;
+	pgoff_t pmd_index = DAX_PMD_INDEX(index);
+	int type, error = 0;
+	void *entry;
+
+	WARN_ON_ONCE(pmd_entry && !dirty);
+	__mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
+
+	spin_lock_irq(&mapping->tree_lock);
+
+	entry = radix_tree_lookup(page_tree, pmd_index);
+	if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD) {
+		index = pmd_index;
+		goto dirty;
+	}
+
+	entry = radix_tree_lookup(page_tree, index);
+	if (entry) {
+		type = RADIX_DAX_TYPE(entry);
+		if (WARN_ON_ONCE(type != RADIX_DAX_PTE &&
+					type != RADIX_DAX_PMD)) {
+			error = -EIO;
+			goto unlock;
+		}
+
+		if (!pmd_entry || type == RADIX_DAX_PMD)
+			goto dirty;
+
+		/*
+		 * We only insert dirty PMD entries into the radix tree.  This
+		 * means we don't need to worry about removing a dirty PTE
+		 * entry and inserting a clean PMD entry, thus reducing the
+		 * range we would flush with a follow-up fsync/msync call.
+		 */
+		radix_tree_delete(&mapping->page_tree, index);
+		mapping->nrexceptional--;
+	}
+
+	if (sector == NO_SECTOR) {
+		/*
+		 * This can happen during correct operation if our pfn_mkwrite
+		 * fault raced against a hole punch operation.  If this
+		 * happens the pte that was hole punched will have been
+		 * unmapped and the radix tree entry will have been removed by
+		 * the time we are called, but the call will still happen.  We
+		 * will return all the way up to wp_pfn_shared(), where the
+		 * pte_same() check will fail, eventually causing page fault
+		 * to be retried by the CPU.
+		 */
+		goto unlock;
+	}
+
+	error = radix_tree_insert(page_tree, index,
+			RADIX_DAX_ENTRY(sector, pmd_entry));
+	if (error)
+		goto unlock;
+
+	mapping->nrexceptional++;
+ dirty:
+	if (dirty)
+		radix_tree_tag_set(page_tree, index, PAGECACHE_TAG_DIRTY);
+ unlock:
+	spin_unlock_irq(&mapping->tree_lock);
+	return error;
+}
+
+static int dax_writeback_one(struct block_device *bdev,
+		struct address_space *mapping, pgoff_t index, void *entry)
+{
+	struct radix_tree_root *page_tree = &mapping->page_tree;
+	int type = RADIX_DAX_TYPE(entry);
+	struct radix_tree_node *node;
+	struct blk_dax_ctl dax;
+	void **slot;
+	int ret = 0;
+
+	spin_lock_irq(&mapping->tree_lock);
+	/*
+	 * Regular page slots are stabilized by the page lock even
+	 * without the tree itself locked.  These unlocked entries
+	 * need verification under the tree lock.
+	 */
+	if (!__radix_tree_lookup(page_tree, index, &node, &slot))
+		goto unlock;
+	if (*slot != entry)
+		goto unlock;
+
+	/* another fsync thread may have already written back this entry */
+	if (!radix_tree_tag_get(page_tree, index, PAGECACHE_TAG_TOWRITE))
+		goto unlock;
+
+	if (WARN_ON_ONCE(type != RADIX_DAX_PTE && type != RADIX_DAX_PMD)) {
+		ret = -EIO;
+		goto unlock;
+	}
+
+	dax.sector = RADIX_DAX_SECTOR(entry);
+	dax.size = (type == RADIX_DAX_PMD ? PMD_SIZE : PAGE_SIZE);
+	spin_unlock_irq(&mapping->tree_lock);
+
+	/*
+	 * We cannot hold tree_lock while calling dax_map_atomic() because it
+	 * eventually calls cond_resched().
+	 */
+	ret = dax_map_atomic(bdev, &dax);
+	if (ret < 0)
+		return ret;
+
+	if (WARN_ON_ONCE(ret < dax.size)) {
+		ret = -EIO;
+		goto unmap;
+	}
+
+	wb_cache_pmem(dax.addr, dax.size);
+
+	spin_lock_irq(&mapping->tree_lock);
+	radix_tree_tag_clear(page_tree, index, PAGECACHE_TAG_TOWRITE);
+	spin_unlock_irq(&mapping->tree_lock);
+ unmap:
+	dax_unmap_atomic(bdev, &dax);
+	return ret;
+
+ unlock:
+	spin_unlock_irq(&mapping->tree_lock);
+	return ret;
+}
+
+/*
+ * Flush the mapping to the persistent domain within the byte range of [start,
+ * end]. This is required by data integrity operations to ensure file data is
+ * on persistent storage prior to completion of the operation.
+ */
+int dax_writeback_mapping_range(struct address_space *mapping, loff_t start,
+		loff_t end)
+{
+	struct inode *inode = mapping->host;
+	struct block_device *bdev = inode->i_sb->s_bdev;
+	pgoff_t start_index, end_index, pmd_index;
+	pgoff_t indices[PAGEVEC_SIZE];
+	struct pagevec pvec;
+	bool done = false;
+	int i, ret = 0;
+	void *entry;
+
+	if (WARN_ON_ONCE(inode->i_blkbits != PAGE_SHIFT))
+		return -EIO;
+
+	start_index = start >> PAGE_CACHE_SHIFT;
+	end_index = end >> PAGE_CACHE_SHIFT;
+	pmd_index = DAX_PMD_INDEX(start_index);
+
+	rcu_read_lock();
+	entry = radix_tree_lookup(&mapping->page_tree, pmd_index);
+	rcu_read_unlock();
+
+	/* see if the start of our range is covered by a PMD entry */
+	if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD)
+		start_index = pmd_index;
+
+	tag_pages_for_writeback(mapping, start_index, end_index);
+
+	pagevec_init(&pvec, 0);
+	while (!done) {
+		pvec.nr = find_get_entries_tag(mapping, start_index,
+				PAGECACHE_TAG_TOWRITE, PAGEVEC_SIZE,
+				pvec.pages, indices);
+
+		if (pvec.nr == 0)
+			break;
+
+		for (i = 0; i < pvec.nr; i++) {
+			if (indices[i] > end_index) {
+				done = true;
+				break;
+			}
+
+			ret = dax_writeback_one(bdev, mapping, indices[i],
+					pvec.pages[i]);
+			if (ret < 0)
+				return ret;
+		}
+	}
+	wmb_pmem();
+	return 0;
+}
+EXPORT_SYMBOL_GPL(dax_writeback_mapping_range);
+
 static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
 			struct vm_area_struct *vma, struct vm_fault *vmf)
 {
@@ -363,6 +557,11 @@ static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
 	}
 	dax_unmap_atomic(bdev, &dax);
 
+	error = dax_radix_entry(mapping, vmf->pgoff, dax.sector, false,
+			vmf->flags & FAULT_FLAG_WRITE);
+	if (error)
+		goto out;
+
 	error = vm_insert_mixed(vma, vaddr, dax.pfn);
 
  out:
@@ -487,6 +686,7 @@ int __dax_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 		delete_from_page_cache(page);
 		unlock_page(page);
 		page_cache_release(page);
+		page = NULL;
 	}
 
 	/*
@@ -589,9 +789,9 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 	bool write = flags & FAULT_FLAG_WRITE;
 	struct block_device *bdev;
 	pgoff_t size, pgoff;
-	loff_t lstart, lend;
 	sector_t block;
-	int result = 0;
+	int error, result = 0;
+	bool alloc = false;
 
 	/* dax pmd mappings require pfn_t_devmap() */
 	if (!IS_ENABLED(CONFIG_FS_DAX_PMD))
@@ -629,10 +829,17 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 	block = (sector_t)pgoff << (PAGE_SHIFT - blkbits);
 
 	bh.b_size = PMD_SIZE;
-	if (get_block(inode, block, &bh, write) != 0)
+
+	if (get_block(inode, block, &bh, 0) != 0)
 		return VM_FAULT_SIGBUS;
+
+	if (!buffer_mapped(&bh) && write) {
+		if (get_block(inode, block, &bh, 1) != 0)
+			return VM_FAULT_SIGBUS;
+		alloc = true;
+	}
+
 	bdev = bh.b_bdev;
-	i_mmap_lock_read(mapping);
 
 	/*
 	 * If the filesystem isn't willing to tell us the length of a hole,
@@ -641,15 +848,20 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 	 */
 	if (!buffer_size_valid(&bh) || bh.b_size < PMD_SIZE) {
 		dax_pmd_dbg(&bh, address, "allocated block too small");
-		goto fallback;
+		return VM_FAULT_FALLBACK;
+	}
+
+	/*
+	 * If we allocated new storage, make sure no process has any
+	 * zero pages covering this hole
+	 */
+	if (alloc) {
+		loff_t lstart = pgoff << PAGE_SHIFT;
+		loff_t lend = lstart + PMD_SIZE - 1; /* inclusive */
+
+		truncate_pagecache_range(inode, lstart, lend);
 	}
 
-	/* make sure no process has any zero pages covering this hole */
-	lstart = pgoff << PAGE_SHIFT;
-	lend = lstart + PMD_SIZE - 1; /* inclusive */
-	i_mmap_unlock_read(mapping);
-	unmap_mapping_range(mapping, lstart, PMD_SIZE, 0);
-	truncate_inode_pages_range(mapping, lstart, lend);
 	i_mmap_lock_read(mapping);
 
 	/*
@@ -733,6 +945,31 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 		}
 		dax_unmap_atomic(bdev, &dax);
 
+		/*
+		 * For PTE faults we insert a radix tree entry for reads, and
+		 * leave it clean.  Then on the first write we dirty the radix
+		 * tree entry via the dax_pfn_mkwrite() path.  This sequence
+		 * allows the dax_pfn_mkwrite() call to be simpler and avoid a
+		 * call into get_block() to translate the pgoff to a sector in
+		 * order to be able to create a new radix tree entry.
+		 *
+		 * The PMD path doesn't have an equivalent to
+		 * dax_pfn_mkwrite(), though, so for a read followed by a
+		 * write we traverse all the way through __dax_pmd_fault()
+		 * twice.  This means we can just skip inserting a radix tree
+		 * entry completely on the initial read and just wait until
+		 * the write to insert a dirty entry.
+		 */
+		if (write) {
+			error = dax_radix_entry(mapping, pgoff, dax.sector,
+					true, true);
+			if (error) {
+				dax_pmd_dbg(&bh, address,
+						"PMD radix insertion failed");
+				goto fallback;
+			}
+		}
+
 		dev_dbg(part_to_dev(bdev->bd_part),
 				"%s: %s addr: %lx pfn: %lx sect: %llx\n",
 				__func__, current->comm, address,
@@ -791,15 +1028,20 @@ EXPORT_SYMBOL_GPL(dax_pmd_fault);
  * dax_pfn_mkwrite - handle first write to DAX page
  * @vma: The virtual memory area where the fault occurred
  * @vmf: The description of the fault
- *
  */
 int dax_pfn_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-	struct super_block *sb = file_inode(vma->vm_file)->i_sb;
+	struct file *file = vma->vm_file;
 
-	sb_start_pagefault(sb);
-	file_update_time(vma->vm_file);
-	sb_end_pagefault(sb);
+	/*
+	 * We pass NO_SECTOR to dax_radix_entry() because we expect that a
+	 * RADIX_DAX_PTE entry already exists in the radix tree from a
+	 * previous call to __dax_fault().  We just want to look up that PTE
+	 * entry using vmf->pgoff and make sure the dirty tag is set.  This
+	 * saves us from having to make a call to get_block() here to look
+	 * up the sector.
+	 */
+	dax_radix_entry(file->f_mapping, vmf->pgoff, NO_SECTOR, false, true);
 	return VM_FAULT_NOPAGE;
 }
 EXPORT_SYMBOL_GPL(dax_pfn_mkwrite);
diff --git a/include/linux/dax.h b/include/linux/dax.h
index e9d57f680f5034c5b1182ae8ca77cd82956c42a8..8204c3dc3800f37839513d7f2c0567d6a41f87a7 100644
--- a/include/linux/dax.h
+++ b/include/linux/dax.h
@@ -41,4 +41,6 @@ static inline bool dax_mapping(struct address_space *mapping)
 {
 	return mapping->host && IS_DAX(mapping->host);
 }
+int dax_writeback_mapping_range(struct address_space *mapping, loff_t start,
+		loff_t end);
 #endif
diff --git a/mm/filemap.c b/mm/filemap.c
index 1e215fc36c835eb6c8d79bc16b160aca0e44fd07..2e7c8d980d5e8bd96d6f68f5d1b1bdec936740b5 100644
--- a/mm/filemap.c
+++ b/mm/filemap.c
@@ -482,6 +482,12 @@ int filemap_write_and_wait_range(struct address_space *mapping,
 {
 	int err = 0;
 
+	if (dax_mapping(mapping) && mapping->nrexceptional) {
+		err = dax_writeback_mapping_range(mapping, lstart, lend);
+		if (err)
+			return err;
+	}
+
 	if (mapping->nrpages) {
 		err = __filemap_fdatawrite_range(mapping, lstart, lend,
 						 WB_SYNC_ALL);