mmu: get virtual alignment from physical address

On ARM64 platforms, when mapping multiple memory zones whose sizes are
not multiples of the L2 block size (2MiB), all subsequent mappings are
likely to fall back to L3 tables.

A single huge mapping can then consume all available L3 tables.

In order to reduce L3 table usage, introduce a new optional
architecture-specific call, arch_virt_region_align(), which may return
a larger, more optimal virtual address alignment than the default
MMU_PAGE_SIZE.

This alignment is used in virt_region_alloc() by:
- requesting extra pages from virt_region_bitmap to make sure the
  allocation covers an address aligned to this boundary
- freeing back the extra pages used only to reach this alignment,
  as sketched below

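As an illustration, an architecture could implement this hook along the
following lines (a sketch only, with a hardcoded L2 block size; a real
implementation would likely derive the block sizes from its translation
table levels):

    #define L2_BLOCK_SIZE (2UL * 1024 * 1024)

    size_t arch_virt_region_align(uintptr_t phys, size_t size)
    {
            /* A 2MiB virtual alignment only helps when the physical
             * base and the size are also 2MiB aligned, so that the
             * region can be mapped with L2 blocks instead of L3
             * tables.
             */
            if (((phys | size) & (L2_BLOCK_SIZE - 1)) == 0) {
                    return L2_BLOCK_SIZE;
            }

            return CONFIG_MMU_PAGE_SIZE;
    }
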
Suggested-by: Nicolas Pitre <npitre@baylibre.com>
Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/include/sys/arch_interface.h b/include/sys/arch_interface.h
index 3a49deb..6b280d8 100644
--- a/include/sys/arch_interface.h
+++ b/include/sys/arch_interface.h
@@ -650,6 +650,21 @@
 int arch_buffer_validate(void *addr, size_t size, int write);
 
 /**
+ * Get the optimal virtual region alignment for the MMU table layout
+ *
+ * Some MMU hardware can map a region using larger intermediate-level
+ * blocks when the region is aligned to such a block boundary, which
+ * reduces page table usage. This call returns the optimal virtual
+ * address alignment for the given physical region, so that the
+ * subsequent MMU mapping call can take advantage of it.
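+ * For instance, on ARM64 a region whose physical address and size are
+ * 2MiB aligned can be mapped with 2MiB L2 blocks instead of L3 tables,
+ * provided its virtual address is also 2MiB aligned.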
+ *
+ * @param[in] phys Physical address of region to be mapped, aligned to MMU_PAGE_SIZE
+ * @param[in] size Size of region to be mapped, aligned to MMU_PAGE_SIZE
+ *
+ * @retval alignment to apply to the virtual address of this region
+ */
+size_t arch_virt_region_align(uintptr_t phys, size_t size);
+
+/**
  * Perform a one-way transition from supervisor to kernel mode.
  *
  * Implementations of this function must do the following:
diff --git a/kernel/mmu.c b/kernel/mmu.c
index 72f8bf9..5e5f78b 100644
--- a/kernel/mmu.c
+++ b/kernel/mmu.c
@@ -226,40 +226,6 @@
 	virt_region_inited = true;
 }
 
-static void *virt_region_alloc(size_t size)
-{
-	uintptr_t dest_addr;
-	size_t offset;
-	size_t num_bits;
-	int ret;
-
-	if (unlikely(!virt_region_inited)) {
-		virt_region_init();
-	}
-
-	num_bits = size / CONFIG_MMU_PAGE_SIZE;
-	ret = sys_bitarray_alloc(&virt_region_bitmap, num_bits, &offset);
-	if (ret != 0) {
-		LOG_ERR("insufficient virtual address space (requested %zu)",
-			size);
-		return NULL;
-	}
-
-	/* Remember that bit #0 in bitmap corresponds to the highest
-	 * virtual address. So here we need to go downwards (backwards?)
-	 * to get the starting address of the allocated region.
-	 */
-	dest_addr = virt_from_bitmap_offset(offset, size);
-
-	/* Need to make sure this does not step into kernel memory */
-	if (dest_addr < POINTER_TO_UINT(Z_VIRT_REGION_START_ADDR)) {
-		(void)sys_bitarray_free(&virt_region_bitmap, size, offset);
-		return NULL;
-	}
-
-	return UINT_TO_POINTER(dest_addr);
-}
-
 static void virt_region_free(void *vaddr, size_t size)
 {
 	size_t offset, num_bits;
@@ -282,6 +248,86 @@
 	(void)sys_bitarray_free(&virt_region_bitmap, num_bits, offset);
 }
 
+static void *virt_region_alloc(size_t size, size_t align)
+{
+	uintptr_t dest_addr;
+	size_t alloc_size;
+	size_t offset;
+	size_t num_bits;
+	int ret;
+
+	if (unlikely(!virt_region_inited)) {
+		virt_region_init();
+	}
+
+	/* Possibly request more pages to ensure we can get an aligned virtual address */
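+	/* In the worst case, the first aligned address inside the
+	 * allocation lies (align - CONFIG_MMU_PAGE_SIZE) bytes past its
+	 * start, hence the extra pages requested below.
+	 */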
+	num_bits = (size + align - CONFIG_MMU_PAGE_SIZE) / CONFIG_MMU_PAGE_SIZE;
+	alloc_size = num_bits * CONFIG_MMU_PAGE_SIZE;
+	ret = sys_bitarray_alloc(&virt_region_bitmap, num_bits, &offset);
+	if (ret != 0) {
+		LOG_ERR("insufficient virtual address space (requested %zu)",
+			size);
+		return NULL;
+	}
+
+	/* Remember that bit #0 in bitmap corresponds to the highest
+	 * virtual address. So here we need to go downwards (backwards?)
+	 * to get the starting address of the allocated region.
+	 */
+	dest_addr = virt_from_bitmap_offset(offset, alloc_size);
+
+	if (alloc_size > size) {
+		uintptr_t aligned_dest_addr = ROUND_UP(dest_addr, align);
+
+		/* Here is the memory organization when trying to get an aligned
+		 * virtual address:
+		 *
+		 * +--------------+ <- Z_VIRT_RAM_START
+		 * | Undefined VM |
+		 * +--------------+ <- Z_KERNEL_VIRT_START (often == Z_VIRT_RAM_START)
+		 * | Mapping for  |
+		 * | main kernel  |
+		 * | image        |
+		 * |              |
+		 * |              |
+		 * +--------------+ <- Z_FREE_VM_START
+		 * | ...          |
+		 * +==============+ <- dest_addr
+		 * | Unused       |
+		 * |..............| <- aligned_dest_addr
+		 * |              |
+		 * | Aligned      |
+		 * | Mapping      |
+		 * |              |
+		 * |..............| <- aligned_dest_addr + size
+		 * | Unused       |
+		 * +==============+ <- offset from Z_VIRT_RAM_END == dest_addr + alloc_size
+		 * | ...          |
+		 * +--------------+
+		 * | Mapping      |
+		 * +--------------+
+		 * | Reserved     |
+		 * +--------------+ <- Z_VIRT_RAM_END
+		 */
+
+		/* Free the two unused regions */
+		virt_region_free(UINT_TO_POINTER(dest_addr),
+				 aligned_dest_addr - dest_addr);
+		virt_region_free(UINT_TO_POINTER(aligned_dest_addr + size),
+				 (dest_addr + alloc_size) - (aligned_dest_addr + size));
+
+		dest_addr = aligned_dest_addr;
+	}
+
+	/* Need to make sure this does not step into kernel memory */
+	if (dest_addr < POINTER_TO_UINT(Z_VIRT_REGION_START_ADDR)) {
+		/* Only [dest_addr, dest_addr + size) is still allocated
+		 * once the alignment slack has been returned, so free it
+		 * via virt_region_free() rather than with the now stale
+		 * bitmap offset.
+		 */
+		virt_region_free(UINT_TO_POINTER(dest_addr), size);
+		return NULL;
+	}
+
+	return UINT_TO_POINTER(dest_addr);
+}
+
 /*
  * Free page frames management
  *
@@ -492,7 +538,7 @@
 	 */
 	total_size = size + CONFIG_MMU_PAGE_SIZE * 2;
 
-	dst = virt_region_alloc(total_size);
+	dst = virt_region_alloc(total_size, CONFIG_MMU_PAGE_SIZE);
 	if (dst == NULL) {
 		/* Address space has no free region */
 		goto out;
@@ -638,6 +684,23 @@
 	return ret * (size_t)CONFIG_MMU_PAGE_SIZE;
 }
 
+/* Get the default virtual region alignment: the default MMU page size
+ *
+ * @param[in] phys Physical address of region to be mapped, aligned to MMU_PAGE_SIZE
+ * @param[in] size Size of region to be mapped, aligned to MMU_PAGE_SIZE
+ *
+ * @retval alignment to apply to the virtual address of this region
+ */
+static size_t virt_region_align(uintptr_t phys, size_t size)
+{
+	ARG_UNUSED(phys);
+	ARG_UNUSED(size);
+
+	return CONFIG_MMU_PAGE_SIZE;
+}
+
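+/* Unless an architecture provides its own arch_virt_region_align(), this
+ * weak alias resolves it to the default implementation above.
+ */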
+__weak FUNC_ALIAS(virt_region_align, arch_virt_region_align, size_t);
+
 /* This may be called from arch early boot code before z_cstart() is invoked.
  * Data will be copied and BSS zeroed, but this must not rely on any
  * initialization functions being called prior to work correctly.
@@ -645,7 +708,7 @@
 void z_phys_map(uint8_t **virt_ptr, uintptr_t phys, size_t size, uint32_t flags)
 {
 	uintptr_t aligned_phys, addr_offset;
-	size_t aligned_size;
+	size_t aligned_size, align_boundary;
 	k_spinlock_key_t key;
 	uint8_t *dest_addr;
 
@@ -657,9 +720,11 @@
 		 "wraparound for physical address 0x%lx (size %zu)",
 		 aligned_phys, aligned_size);
 
+	align_boundary = arch_virt_region_align(aligned_phys, aligned_size);
+
 	key = k_spin_lock(&z_mm_lock);
 	/* Obtain an appropriately sized chunk of virtual memory */
-	dest_addr = virt_region_alloc(aligned_size);
+	dest_addr = virt_region_alloc(aligned_size, align_boundary);
 	if (!dest_addr) {
 		goto fail;
 	}