/*
 * Copyright (c) 2007, 2008 University of Tsukuba
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 * 3. Neither the name of the University of Tsukuba nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
/*
 * Copyright (c) 2010-2016 Yuichi Watanabe
 */

#include <common.h>
#include <common/calc.h>
#include <core.h>
#include <core/mm.h>
#include <core/mmio.h>
#include <core/gmm.h>
#include <core/process.h>
#include <core/panic.h>
#include <core/printf.h>
#include <core/string.h>
#include <core/spinlock.h>
#include <core/vm.h>
#include <core/vmmerr.h>
#include <io.h>

#include "dmar.h"
#include "pci.h"
#include "pci_match.h"

/*
 * Debug output.  Uncomment the following line to make IOMMU_DBG() print
 * through printf() with an "IOMMU: " prefix; otherwise it expands to nothing.
 */
/* #define IOMMU_DEBUG */
#ifdef IOMMU_DEBUG
#define IOMMU_DBG(...)						\
	do {							\
		printf("IOMMU: " __VA_ARGS__);		\
	} while (0)
#else
#define IOMMU_DBG(...)
#endif

/* Uncomment to compile dump_iopt() and call it from iommu_initvm(). */
/* #define DUMP_IOPT */

#define PAGE_SHIFT PAGESHIFT
#define PAGE_SIZE PAGESIZE
#define PAGE_MASK  (((u64)-1) << PAGE_SHIFT)

/* Range of io page-table depths this driver handles. */
#define MAX_PAGE_LEVEL	6
#define MIN_PAGE_LEVEL	2

/* DMA permission: bit 0 = read, bit 1 = write (see set_pte_perm()). */
#define PERM_DMA_NO 0  // 00b
#define PERM_DMA_RO 1  // 01b
#define PERM_DMA_WO 2  // 10b
#define PERM_DMA_RW 3  // 11b

/*
 * Intel VT-d register offset
 */
#define CAP_REG    0x8     /* Capabilities, 64 bit */
#define ECAP_REG   0x10    /* Extended-capabilities, 64 bit */
#define GCMD_REG   0x18    /* Global command register, 32 bit */
#define GSTS_REG   0x1c    /* Global status register, 32 bit */
#define RTADDR_REG 0x20    /* Root-entry table address, 64 bit */
#define CCMD_REG   0x28    /* Context command register, 64 bit */
#define FSTS_REG   0x34    /* Fault status register, 32 bit */
#define  FSTS_REG_FRI_SHIFT 8
#define  FSTS_REG_FRI_MASK  0xFF
#define FECTL_REG  0x38    /* Fault event control register, 32 bit */

/*
 * Decoding Extended Capability Register
 */
/* Byte offset of the IOTLB registers from the register base */
#define ecap_iro(e)   ((((e) >> 8) & 0x3ff) * 16)
/* Coherency: non-zero when cache-line flushes can be skipped */
#define ecap_c(e)     (((e) >> 0) & 0x1)

/* IOTLB Invalidate Register Field Offset */
#define IOTLB_FLUSH_GLOBAL (((u64)1) << 60)
#define IOTLB_DRAIN_READ   (((u64)1) << 49)
#define IOTLB_DRAIN_WRITE  (((u64)1) << 48)
#define IOTLB_IVT          (((u64)1) << 63)

/*
 * Global Command Register Field Offset
 */
#define GCMD_TE     (((u64)1) << 31)
#define GCMD_SRTP   (((u64)1) << 30)
#define GCMD_WBF    (((u64)1) << 27)

/*
 * Global Status Register Field Offset
 */
#define GSTS_TES    (((u64)1) << 31)
#define GSTS_RTPS   (((u64)1) << 30)
#define GSTS_WBFS   (((u64)1) << 27)

/*
 * Context Command Register Field Offset
 */
#define CCMD_ICC   (((u64)1) << 63)
#define CCMD_GLOBAL_INVL (((u64)1) << 61)

/*
 * Decoding Fault Status Register
 */
#define FSTS_MASK   ((u64)0x7f)

/*
 * Fault Recording Registers
 * FR_F is the fault flag in the upper 64 bits of a 128-bit record;
 * dump_fault_record() writes it back to clear a record.
 */
#define FR_F	0x8000000000000000ULL

/* IO pagetable walk: every level indexes IOPT_LEVEL_STRIDE address bits */
#define IOPT_LEVEL_STRIDE	(9)
#define IOPT_LEVEL_MASK		((1 << IOPT_LEVEL_STRIDE) - 1)
#define IOPT_L1PTN(gpfn)	((gpfn) >> IOPT_LEVEL_STRIDE)

/* 8086:d155 specific definition */
#define VTD_PCI_VTBAR		0x180
#define  VTD_PCI_VTBAR_ADDRESS_MASK 0xFFFFE000
#define  VTD_PCI_VTBAR_ENABLE	0x00000001
#define VTD_MMREG_BASE		0xFD000000
#define VTD_ISOCH_OFFSET	0x1000

/* Value stored in context_entry.h.avail of entry 0 to mark a shared table */
#define SHARED_CTXTBL		1

/*
 * Per-unit state of one remapping hardware unit.  Instances are created by
 * new_iommu() and linked on the file-scope iommu_list.
 */
struct iommu {
	LIST_DEFINE(iommu_list) ;
	phys_t reg_phys;      /* physical base address of the register page */
	void *reg;    /* mapping of reg_phys obtained with mapmem_hphys() */
	u32 gcmd;	      /* maintaining TE field */
	u64 cap;	      /* capability register */
	u64 ecap;	      /* extended capability register */
	int page_level;	      /* result of get_max_suppoted_page_level() */
	spinlock_t unit_lock; /* iommu lock (root/context table updates) */
	spinlock_t reg_lock;  /* register operation lock */
	struct root_entry *root_entry; /* virtual address */
	phys_t root_entry_phys ;/* physical address */
};

// Translation Structure Format

/*
 * Root Entry Structure and Ops.
 *
 * One 128-bit entry per PCI bus number (indexed as root_entry[bus]).
 * The layout is consumed by hardware; do not reorder the fields.
 */
struct root_entry {
	union {
		struct {
			unsigned p: 1 ;		/* present, see root_entry_present() */
			unsigned rsvd: 11 ;
			u64 ctp: 52 ;		/* context table address >> PAGE_SHIFT */
		} v;
		u64 value;
	};
	u64 rsvd;
};

/*
 * Context Entry Structure and Ops.
 *
 * One 128-bit entry per devfn inside a context table (indexed as
 * context[devfn]).  The layout is consumed by hardware; do not reorder.
 */
struct context_entry {
	union {
		struct {
			unsigned p:1 ;		/* present */
			unsigned fpd:1 ;	/* cleared by enable_fault_handling() */
			unsigned t:2 ;		/* translation type; 0 is used here */
			unsigned eh:1 ;
			unsigned alh:1 ;
			unsigned rsvd: 6 ;
			u64 asr: 52 ;		/* io page table root >> PAGE_SHIFT */
		} l;
		u64 low;
	};
	union {	
		struct {
			unsigned aw: 3 ;	/* page_level - MIN_PAGE_LEVEL */
			unsigned avail: 4 ;	/* SHARED_CTXTBL marker lives here */
			unsigned rsfc: 1 ;
			unsigned did:16 ;	/* io domain id */
			u64 rsvd: 40 ;
		} h;
		u64 high;
	};
};

/*
 * Page-Table Entry Structure and Ops.
 *
 * 64-bit io page-table entry, used both for intermediate tables and leaves.
 * The layout is consumed by hardware; do not reorder the fields.
 */
struct iopt_entry {
	union {
		struct {
			unsigned r: 1 ;		/* DMA read allowed */
			unsigned w: 1 ;		/* DMA write allowed */
			unsigned avail1: 5 ;
			unsigned sp: 1;		/* treated as "super page" by the walkers */
			unsigned avail2: 3;
			unsigned snp: 1;
			u64 addr: 40 ;		/* target address >> PAGE_SHIFT */
			unsigned avail3: 10 ;
			unsigned tm: 1;
			unsigned avail4: 1 ;
		};
		u64 value;
	};
};

/*
 * An io domain: one io page table plus the context tables that point at it.
 * Initialized by init_domain().
 */
struct io_domain
{
	unsigned short int  domain_id;
	phys_t pgd;   /* io page directory root */
	/* shared context table, indexed by (page_level - MIN_PAGE_LEVEL) */
	phys_t shared_ctxtbl[MAX_PAGE_LEVEL - MIN_PAGE_LEVEL + 1];
	spinlock_t iopt_lock;  /* io page table lock */
	phys_t max_addr;       /* highest address accepted by iopt_walk() */
	int page_level;        /* largest page_level among all iommu units */
	bool coherency;        /* true only if every unit reports ecap_c() */
};

/* Cache-line flush stride in bytes; computed from CPUID in iommu_drvinit(). */
static unsigned int clflush_size;
/* All remapping units registered by new_iommu(). */
static LIST_DEFINE_HEAD(iommu_list);
/* Domain with id 0, created in iommu_devinit(); used when vm_get_id() == 0. */
static struct io_domain *io_domain0;

/*
 * Flush the cache line that contains address a.
 *
 * The "memory" clobber makes the statement a compiler barrier: stores to
 * the translation structures issued before the call must not be sunk past
 * the flush, otherwise hardware could read stale memory.
 */
static inline void asm_clflush (void *a)
{
	__asm__ volatile ("clflush (%0)" : : "r"(a) : "memory");
}

/*
 * Decoding Capability Register
 */

/* Fault-recording register offset */
#define cap_fro(c)    ((((c) >> 24) & 0x3ff) * 16)

/* Number of fault-recording register */
#define cap_nfr(c)    ((((c) >> 40) & 0xff) + 1)

/* Maximum guest address width */
#define cap_mgaw(c)   ((((c) >> 16) & 0x3f) + 1)

/* Supported adjusted guest address widths */
#define cap_sagaw(c)  (((c) >> 8) & 0x1f)

/* write-buffer flushing requirement */
#define cap_rwbf(c)   (((c) >> 4) & 1)

/*
 * Root entry accessors.  r is a struct root_entry lvalue.
 */
#define root_entry_present(r)	((r).v.p)
#define set_root_present(r)	do {(r).v.p = 1;} while(0)
#define root_entry_ctp(r)	((r).v.ctp << PAGE_SHIFT)
#define clear_root_entry(r)	do {(r).value = 0;} while(0)

#define set_ctp(r, value)		\
	do {(r).v.ctp = (value) >> PAGE_SHIFT;} while(0)

/*
 * Context entry accessors.  c is a struct context_entry lvalue.
 */
#define context_entry_present(c) \
	((c).l.p)

#define set_context_present(c) \
	do {(c).l.p = 1;} while(0)

#define set_context_domid(c, d)	\
	do {(c).h.did = (d) ;} while (0)

#define enable_fault_handling(c) \
	do {(c).l.fpd = 0;} while(0)

#define set_context_trans_type(c, val) \
	do {(c).l.t = (val);} while(0)

#define set_asr(c, val)			\
	do {(c).l.asr = (val) >> PAGE_SHIFT ;} while(0)

#define set_agaw(c, val) \
	do {(c).h.aw = (val) & 7;} while(0)

/* Clears both halves, including the software-available bits in h.avail. */
#define clear_context_entry(c) \
	do {(c).low = 0; (c).high = 0;} while(0)

/*
 * IO page table geometry.  level 1 is the leaf level (PAGE_SIZE pages).
 */

/* Index into the table of the given level for address addr */
#define iopt_level_offset(addr, level) \
	(((addr) >> (PAGE_SHIFT + ((level) - 1) * IOPT_LEVEL_STRIDE)) \
	 & IOPT_LEVEL_MASK)

/* Bytes covered by one entry of the given level */
#define io_page_size(level) \
	(1ULL << (PAGE_SHIFT + ((level) - 1) * IOPT_LEVEL_STRIDE))

/*
 * Highest address reachable with a table of the given depth.  The width is
 * clamped to 64 bits, because shifting a 64-bit value by 64 or more (as
 * level 6 would require) is undefined behavior.
 */
#define io_max_addr(level)						\
	((PAGE_SHIFT + (level) * IOPT_LEVEL_STRIDE) >= 64 ? ~0ULL :	\
	 ((1ULL << (PAGE_SHIFT + (level) * IOPT_LEVEL_STRIDE)) - 1))

#define io_page_mask(level) \
	((phys_t)(io_page_size(level) - 1))

/*
 * Page table entry accessors.  p is a struct iopt_entry lvalue.
 */

/* prot is one of PERM_DMA_*: bit 0 -> r, bit 1 -> w */
#define set_pte_perm(p, prot) \
	do { (p).r = (prot) & 1; (p).w = ((prot) & 2) >> 1; } while (0)

#define get_pte_addr(p)	\
	((p).addr << PAGE_SHIFT)

#define set_pte_addr(p, address) \
	do {(p).addr = (address) >> PAGE_SHIFT ;} while(0)

/* Non-zero when the entry grants read or write access */
#define get_pte_vailed(p) \
	((p).r || (p).w)

#define clear_pte(p) \
	do {(p).value = 0;} while(0)

/*
 * Execute CPUID with EAX = op and return the resulting EBX value.
 * The other three output registers are captured and discarded.
 */
static inline unsigned int
cpuid_ebx(unsigned int op)
{
	unsigned int regs[4];	/* eax, ebx, ecx, edx in that order */

	__asm__ volatile ("cpuid"
			  : "=a" (regs[0]), "=b" (regs[1]),
			    "=c" (regs[2]), "=d" (regs[3])
			  : "a" (op));
	return regs[1];
}

/*
 * Flush size bytes starting at addr from the CPU cache, one clflush_size
 * step at a time.  Nothing is done when the unit reports the 'Coherency'
 * field of its ECAP register.
 */
static void
flush_cacheline(struct iommu *iommu, void *addr, int size)
{
	int off;

	if (ecap_c(iommu->ecap)) {
		return;
	}

	for (off = 0; off < size; off += clflush_size) {
		asm_clflush(addr + off);
	}
}

/*
 * Domain flavour of flush_cacheline(): flush size bytes starting at addr
 * unless the domain is marked coherent.
 */
static void
flush_cacheline_dom(struct io_domain *dom, void *addr, int size)
{
	int off;

	if (!dom->coherency) {
		for (off = 0; off < size; off += clflush_size) {
			asm_clflush(addr + off);
		}
	}
}

/*
 * Convenience wrappers around flush_cacheline() / flush_cacheline_dom():
 *   *_dw  flushes 8 bytes starting at addr (a single table entry),
 *   *_pg  flushes one whole page starting at addr.
 */
#define flush_cacheline_dw(iommu, addr) \
	flush_cacheline(iommu, addr, 8)
#define flush_cacheline_pg(iommu, addr) \
	flush_cacheline(iommu, addr, PAGE_SIZE)
#define flush_cacheline_dw_dom(dom, addr) \
	flush_cacheline_dom(dom, addr, 8)
#define flush_cacheline_pg_dom(dom, addr) \
	flush_cacheline_dom(dom, addr, PAGE_SIZE)

/*
 * Flush write buffer in root-complex.
 *
 * Does nothing when the capability register says write-buffer flushing is
 * not required.  Otherwise issues GCMD_WBF and spins until GSTS_WBFS clears.
 */
static void
flush_write_buffer(struct iommu *iommu)
{
	u32 val;

	if (!cap_rwbf(iommu->cap)) {
		// no need for write-buffer flushing to ensure changes to
		// memory-resident structures are visible to hardware

		return;
	}

	spinlock_lock(&iommu->reg_lock);
	/*
	 * iommu->gcmd is modified under reg_lock (enable_dma_remapping()),
	 * so it must be sampled under the same lock; a stale copy could
	 * write TE back as 0.
	 */
	val = iommu->gcmd | GCMD_WBF;
	write32(iommu->reg + GCMD_REG, val);

	// wait until completion
	for (;;) {
		val = read32(iommu->reg + GSTS_REG);
		if (!(val & GSTS_WBFS))
			break;
		cpu_relax();
	}
	spinlock_unlock(&iommu->reg_lock);
}

/*
 * Context-Cache global invalidation.
 * Writes the invalidation command to CCMD_REG and spins until the
 * hardware clears CCMD_ICC.  Always returns 0.
 */
static int
flush_context_cache(struct iommu *iommu)
{
	u64 ccmd = CCMD_GLOBAL_INVL | CCMD_ICC;

	spinlock_lock(&iommu->reg_lock);
	write64(iommu->reg + CCMD_REG, ccmd);

	/* wait until completion */
	do {
		ccmd = read64(iommu->reg + CCMD_REG);
		if (ccmd & CCMD_ICC) {
			cpu_relax();
		}
	} while (ccmd & CCMD_ICC);
	spinlock_unlock(&iommu->reg_lock);

	return 0;
}

/*
 * IOTLB global invalidation.
 * The command also requests read and write DMA draining (applied if the
 * hardware supports it).  Spins until IOTLB_IVT clears.  Always returns 0.
 */
static int
flush_iotlb_global(struct iommu *iommu)
{
	int iro = ecap_iro(iommu->ecap);
	u64 cmd;

	cmd = IOTLB_FLUSH_GLOBAL | IOTLB_IVT;
	cmd |= IOTLB_DRAIN_READ | IOTLB_DRAIN_WRITE;

	spinlock_lock(&iommu->reg_lock);
	write64(iommu->reg + iro + 8, cmd);

	/* wait until completion */
	do {
		cmd = read64(iommu->reg + iro + 8);
		if (cmd & IOTLB_IVT) {
			cpu_relax();
		}
	} while (cmd & IOTLB_IVT);
	spinlock_unlock(&iommu->reg_lock);

	return 0;
}

/*
 * For every registered unit: flush the write buffer, then invalidate the
 * context cache, then invalidate the IOTLB (in that order).
 */
static void
flush_all(void)
{
	struct iommu *iommu;

	LIST_FOREACH(iommu_list, iommu)
	{
		flush_write_buffer(iommu);
		(void)flush_context_cache(iommu);
		(void)flush_iotlb_global(iommu);
	}
}

/*
 * Walk the io page table of dom towards addr.
 *
 * req_level < 0: never allocate; stop at the first missing table.
 * req_level >= 1: allocate missing intermediate tables until req_level
 *                 is reached.
 * The walk also stops at an entry whose sp bit is set.
 *
 * Returns a pointer to the entry for addr at the level where the walk
 * stopped (stored in *result_level), or NULL if addr exceeds
 * dom->max_addr or a table could not be allocated (*result_level is left
 * untouched in that case).
 *
 * A caller should hold dom->iopt_lock
 */
static void *
iopt_walk(struct io_domain *dom, phys_t addr, int req_level, int *result_level)
{
	struct iopt_entry *table;
	struct iopt_entry *entry;
	void *next_virt = NULL;
	phys_t next_phys;
	int level;

	if (addr > dom->max_addr) {
		return NULL;
	}

	table = (void *)phys_to_virt(dom->pgd);
	level = dom->page_level;
	while (level > 1) {
		entry = &table[iopt_level_offset(addr, level)];

		if (get_pte_addr(*entry) != 0) {
			/* lower table (or super page) already present */
			if (entry->sp) {
				break;
			}
			next_phys = get_pte_addr(*entry);
			next_virt = (void *)phys_to_virt(next_phys);
		} else {
			/* lower table is missing: stop, or create it */
			if (req_level < 0 || level <= req_level) {
				break;
			}
			if (alloc_page(&next_virt, &next_phys) != 0) {
				return NULL;
			}
			memset(next_virt, 0, PAGE_SIZE);
			flush_cacheline_pg_dom(dom, next_virt);

			set_pte_addr(*entry, (next_phys & PAGE_MASK));
			set_pte_perm(*entry, PERM_DMA_RW);
			flush_cacheline_dw_dom(dom, entry);
		}

		table = (struct iopt_entry *)next_virt;
		level--;
	}

	*result_level = level;
	return table + iopt_level_offset(addr, level);
}

/*
 * Return the physical address of the table that serves as the root of a
 * page_level-deep walk of dom's io page table.  Units that support fewer
 * levels than the domain use the table reached by following entry 0 of
 * each higher level; missing tables on that path are allocated.
 *
 * Panics if page_level exceeds dom->page_level, if allocation fails, or
 * if an entry on the path has its sp bit set.  Takes dom->iopt_lock.
 */
static phys_t
get_iopt_phys(struct io_domain *dom, int page_level)
{
	struct iopt_entry *table, *first;
	void *table_virt = NULL;
	phys_t table_phys;
	int cur;

	spinlock_lock(&dom->iopt_lock);
	if (page_level > dom->page_level) {
		panic("requested level %d is more than page_level %d",
		      page_level, dom->page_level);
	}

	table_phys = dom->pgd;
	table = (void *)phys_to_virt(dom->pgd);
	cur = dom->page_level;
	while (cur > page_level) {
		first = table;	/* always descend through entry 0 */

		if (get_pte_addr(*first) != 0) {
			if (first->sp) {
				panic("super page found. page_lavel %d",
				      page_level);
			}
			table_phys = get_pte_addr(*first);
			table_virt = (void *)phys_to_virt(table_phys);
		} else {
			if (alloc_page(&table_virt, &table_phys) != 0) {
				panic ("Failed to allocate an io page table");
			}
			memset(table_virt, 0, PAGE_SIZE);
			flush_cacheline_pg_dom(dom, table_virt);

			set_pte_addr(*first, (table_phys & PAGE_MASK));
			set_pte_perm(*first, PERM_DMA_RW);
			flush_cacheline_dw_dom(dom, first);
		}

		table = (struct iopt_entry *)table_virt;
		cur--;
	}
	spinlock_unlock(&dom->iopt_lock);
	return table_phys;
}

/*
 * Set Root-entry table address.
 *
 * Allocates and zeroes the root entry table on first use, programs its
 * physical address into RTADDR_REG, issues GCMD_SRTP and spins until
 * GSTS_RTPS is reported.
 *
 * Returns 0 on success, VMMERR_INVAL if iommu is NULL, VMMERR_NOMEM if
 * the table cannot be allocated.
 */
static vmmerr_t
set_root_entry_table(struct iommu *iommu)
{
	u32 cmd, stat;

	if (iommu == NULL) {
		printf("set_root_entry_table: iommu == NULL\n");
		return VMMERR_INVAL;
	}

	if (iommu->root_entry == NULL) {
		void *table;

		if (alloc_page(&table, &(iommu->root_entry_phys)) != 0) {
			printf("Failed to allocate root entry table.");
			return VMMERR_NOMEM;
		}

		memset((u8*)table, 0, PAGE_SIZE);
		flush_cacheline_pg(iommu, table);

		iommu->root_entry = (struct root_entry *)table;
	}

	spinlock_lock(&iommu->reg_lock);
	write64(iommu->reg + RTADDR_REG, iommu->root_entry_phys);

	cmd = iommu->gcmd | GCMD_SRTP;
	write32(iommu->reg + GCMD_REG, cmd);

	/* wait until completion */
	do {
		stat = read32(iommu->reg + GSTS_REG);
		if (!(stat & GSTS_RTPS)) {
			cpu_relax();
		}
	} while (!(stat & GSTS_RTPS));
	spinlock_unlock(&iommu->reg_lock);

	return 0;
}

/*
 * Enable DMA Remapping.
 * Records GCMD_TE in iommu->gcmd, writes it to GCMD_REG and spins until
 * the status register reports GSTS_TES.
 */
static void
enable_dma_remapping(struct iommu *iommu)
{
	u32 gsts;

	spinlock_lock(&iommu->reg_lock);

	/* Enable translation */
	iommu->gcmd |= GCMD_TE;
	write32(iommu->reg + GCMD_REG, iommu->gcmd);

	/* Wait until completion */
	do {
		gsts = read32(iommu->reg + GSTS_REG);
		if (!(gsts & GSTS_TES)) {
			cpu_relax();
		}
	} while (!(gsts & GSTS_TES));

	spinlock_unlock(&iommu->reg_lock);
}

/*
 * Map the guest-physical range [gphys, gphys + len) into dom's io page
 * table with read/write permission, one PAGE_SIZE leaf at a time.  The
 * range is widened to page boundaries.  Pages for which gmm_gp2hp()
 * reports fakerom are skipped.
 *
 * An empty range (len == 0) maps nothing.  Without this guard a call with
 * gphys == 0 and len == 0 would compute end = (phys_t)-1 and walk the
 * whole address space until iopt_walk() fails.
 *
 * Panics if a leaf entry cannot be obtained.
 */
static void
map_guest_pages(struct io_domain *dom, phys_t gphys, phys_t len)
{
	phys_t start, end;
	phys_t hphys;
	struct iopt_entry *pte = NULL;
	int level;
	bool fakerom;

	if (len == 0) {
		return;
	}

	start = ROUND_DOWN(gphys, PAGE_SIZE);
	end = ROUND_UP(gphys + len, PAGE_SIZE) - 1;

	IOMMU_DBG("map 0x%llx - 0x%llx\n", start, end);

	while (start <= end) {
		hphys = gmm_gp2hp(start, &fakerom);
		if (fakerom) {
			start += PAGE_SIZE;
			continue;
		}
		spinlock_lock(&dom->iopt_lock);
		/* request a leaf (level 1) entry, allocating tables as needed */
		pte = iopt_walk(dom, start, 1, &level);
		if (pte == NULL) {
			panic("Failed to get io-pte.");
		}
		if (level != 1) {
			panic("IOPT level is not 1.");
		}
		clear_pte(*pte);
		set_pte_addr(*pte, hphys);
		set_pte_perm(*pte, PERM_DMA_RW);
		flush_cacheline_dw_dom(dom, pte);
		spinlock_unlock(&dom->iopt_lock);
		start += io_page_size(level);
	}
}

/*
 * Map the first megabyte plus every memory-map range reported by
 * gmm_get_mem_map() that is not flagged restrict_access.
 */
static void
map_all_guest_pages(struct io_domain *dom)
{
	phys_t base, size;
	u32 type;
	bool restricted;
	int n;

	/*
	 * 0MB - 1MB area need to be mapped because
	 * BIOS may use the area as a DMA target.
	 */
	map_guest_pages(dom, 0, 0x100000);

	for (n = 0; gmm_get_mem_map(n, &base, &size, &type, &restricted);
	     n++) {
		if (!restricted) {
			map_guest_pages(dom, base, size);
		}
	}
}

#ifdef DUMP_IOPT
/*
 * Debug helper: print the guest-physical -> host-physical ranges mapped in
 * dom's io page table, merging runs of contiguous mappings into one line.
 * next_hphys == -1 means "no range is currently open".
 */
static void
dump_iopt(struct io_domain *dom)
{
	phys_t top_of_gphys;
	phys_t gphys, hphys, next_gphys = -1, next_hphys = -1;
	struct iopt_entry *ptep;
	struct iopt_entry pte;
	static spinlock_t dump_lock = SPINLOCK_INITIALIZER;
	int level = 1;

	spinlock_lock(&dump_lock);
	spinlock_lock(&dom->iopt_lock);
	top_of_gphys = gmm_top_of_high_avail_mem();
	if (top_of_gphys == 0) {
		top_of_gphys = gmm_top_of_low_avail_mem();
	}
	printf("Dumping iopt 0x0 to 0x%llx\n", top_of_gphys);
	for (gphys = 0; gphys < top_of_gphys; gphys += io_page_size(level)) {
		/* req_level -1: look up only, never allocate tables */
		ptep = iopt_walk(dom, gphys, -1, &level);
		if (ptep == NULL) {
			/*
			 * gphys is above dom->max_addr, and so is every
			 * following address: nothing more to dump.
			 */
			break;
		}
		pte = *ptep;
		if (!get_pte_vailed(pte)) {
			if (next_hphys != -1) {
				printf(" 0x%llx -> 0x%llx\n",
				       next_gphys - 1,
				       next_hphys - 1);
			}
			next_hphys = -1;
			continue;
		}
		hphys = get_pte_addr(pte);
		if (hphys != next_hphys) {
			if (next_hphys != -1) {
				printf(" 0x%llx -> 0x%llx\n",
				       next_gphys - 1,
				       next_hphys - 1);
			}
			printf("    0x%llx -> 0x%llx", gphys, hphys);
		}
		next_gphys = gphys + io_page_size(level);
		next_hphys = hphys + io_page_size(level);
	}
	/* close the range that is still open, if any */
	if (next_hphys != -1) {
		printf(" 0x%llx -> 0x%llx\n",
		       next_gphys - 1,
		       next_hphys - 1);
	}
	spinlock_unlock(&dom->iopt_lock);
	spinlock_unlock(&dump_lock);
}
#endif

/*
 * Setup device context information
 *
 * Returns the context entry for (bus, devfn) on the given unit.  If the
 * bus still points at a context table whose first entry carries the
 * SHARED_CTXTBL marker, the table is first duplicated into a private
 * page and the root entry is redirected to the copy, so that the caller
 * may modify the returned entry without affecting other buses.
 *
 * Returns NULL if the private page cannot be allocated.  Panics if the
 * root entry of the bus is not present.
 *
 * A caller should hold iommu->unit_lock.
 */
static struct context_entry *
get_context_entry(struct iommu *iommu, u8 bus, u8 devfn)
{
	struct root_entry *root;
	struct context_entry *context;
	void *virt;
	phys_t phys;
	vmmerr_t ret;

	root = &iommu->root_entry[bus];

	if (!root_entry_present(*root)) {
		panic("Root entry not present");
	}
	context = (struct context_entry *)phys_to_virt(
		(phys_t)root_entry_ctp(*root));
	if (context->h.avail == SHARED_CTXTBL) {
		ret = alloc_page(&virt, &phys);
		if (ret != 0) {
			printf("get_context_entry: Can't alloc a page.\n");
			return NULL;
		}
		/*
		 * Copy on write
		 */
		IOMMU_DBG("Copy on write context table 0x%02x\n",
			  bus);
		memcpy(virt, context, PAGE_SIZE);
		context = virt;
		/*
		 * Drop the shared marker from the copy before the page is
		 * flushed and published, so no write to the new table is
		 * left unflushed afterwards.
		 */
		context->h.avail = 0;
		flush_cacheline_pg(iommu, (void *)(virt));
		clear_root_entry(*root);
		set_ctp(*root, phys);
		set_root_present(*root);
		flush_cacheline_dw(iommu, root);
	}

	return &context[devfn];
}

/*
 * Point the context entry of (bus_no, devfn) on one unit at dom's io page
 * table.  Logs and leaves the entry untouched if no context entry can be
 * obtained.  Takes iommu->unit_lock.
 */
static void
map_device_to_domain_iommu(struct iommu *iommu, struct io_domain *dom,
			   u8 bus_no, u8 devfn)
{
	struct context_entry *ce;

	spinlock_lock(&iommu->unit_lock);

	ce = get_context_entry(iommu, bus_no, devfn);
	if (ce != NULL) {
		clear_context_entry(*ce);
		set_agaw(*ce, iommu->page_level - MIN_PAGE_LEVEL);
		set_asr(*ce, get_iopt_phys(dom, iommu->page_level));
		/* 0x0 means "ASR field points to a multi-level page-table" */
		set_context_trans_type(*ce, 0x0);

		enable_fault_handling(*ce);
		set_context_domid(*ce, dom->domain_id);
		set_context_present(*ce);

		flush_cacheline_dw(iommu, ce);
	} else {
		printf("map_device_to_domain: "
		       "Can't get context entry.\n");
	}

	spinlock_unlock(&iommu->unit_lock);
}

/*
 * Apply map_device_to_domain_iommu() for (bus_no, devfn) on every
 * registered unit.
 */
static void
map_device_to_domain(struct io_domain *dom, u8 bus_no, u8 devfn)
{
	struct iommu *iommu;

	LIST_FOREACH(iommu_list, iommu)
	{
		map_device_to_domain_iommu(iommu, dom, bus_no, devfn);
	}
}

/*
 * Map every MMIO resource of dev into dom's io page table.  Resources of
 * any other type are ignored.
 */
static void
map_mmio_resource(struct io_domain *dom, struct pci_device *dev)
{
	struct resource *res;
	int n;

	for (n = 0; n < PCI_RESOURCE_NUM; n++) {
		res = dev->resource + n;
		if (res->type == RESOURCE_TYPE_MMIO) {
			map_guest_pages(dom, res->start,
					res->end - res->start + 1);
		}
	}
}

/*
 * Redirect the root entry of bus_no on one unit to dom's shared context
 * table for that unit's page level.  Panics if the domain has no such
 * table.  Takes iommu->unit_lock.
 */
static void
map_bus_to_domain_iommu(struct iommu *iommu, struct io_domain *dom,
			u8 bus_no)
{
	struct root_entry *re;
	phys_t shared;

	spinlock_lock(&iommu->unit_lock);

	re = &iommu->root_entry[bus_no];
	clear_root_entry(*re);

	shared = dom->shared_ctxtbl[iommu->page_level - MIN_PAGE_LEVEL];
	if (shared == 0x0) {
		panic("The context table does not exists. page_level %d",
		      iommu->page_level);
	}

	set_ctp(*re, shared);
	set_root_present(*re);
	flush_cacheline_dw(iommu, re);

	spinlock_unlock(&iommu->unit_lock);
}

/*
 * Apply map_bus_to_domain_iommu() for bus_no on every registered unit.
 */
static void
map_bus_to_domain(struct io_domain *dom, u8 bus_no)
{
	struct iommu *iommu;

	LIST_FOREACH(iommu_list, iommu)
	{
		map_bus_to_domain_iommu(iommu, dom, bus_no);
	}
}

/*
 * For every assigned PCI device: map its MMIO resources into dom, and,
 * unless the current VM id is 0, bind the device (and, for a header
 * type 1 device, every bus from sec_bus to sub_bus) to dom.
 */
static void
map_assigned_devices_to_domain(struct io_domain *dom)
{
	struct pci_device *dev;
	int bus;

	dev = pci_next_assgined_pci_device(NULL);
	while (dev) {
		map_mmio_resource(dom, dev);

		if (vm_get_id() != 0) {
			IOMMU_DBG("map bus 0x%x devfn 0x%x\n",
				  dev->bus_no, dev->devfn);
			map_device_to_domain(dom, dev->bus_no, dev->devfn);
			/*
			 * If assigned device is PCI-PCI bridge, Assign all
			 * devices behind it to a specified io domain.
			 */
			if (dev->type == PCI_CONFIG_HEADER_TYPE_1) {
				IOMMU_DBG("map bus 0x%x - 0x%x\n",
					dev->sec_bus, dev->sub_bus);
				bus = dev->sec_bus;
				while (bus <= dev->sub_bus) {
					map_bus_to_domain(dom, bus);
					bus++;
				}
			}
		}

		dev = pci_next_assgined_pci_device(dev);
	}
}

static void
clear_fault_record(struct iommu *iommu)
{
	u64 fro; // Fault Register Offset
	int nfo; // Number of fault register offset
	int i;

	spinlock_lock(&iommu->reg_lock);
	fro = cap_fro(iommu->cap);
	nfo = cap_nfr(iommu->cap);
	for (i = 0; i < nfo; i++) {
		/* writing back to clear fault status */
		write64(iommu->reg + fro + i * 16 + 8, 0x80000000);
	}
	/* Clearing lower 7bits of Fault Status Register */
	write32(iommu->reg + FSTS_REG, FSTS_MASK);
	spinlock_unlock(&iommu->reg_lock);
}

/*
 * Print the fault status register and every pending fault record, then
 * clear them.  Records are visited starting at the index taken from the
 * status register, wrapping around, until a record without FR_F is found
 * or all records have been visited.
 */
static void
dump_fault_record(struct iommu *iommu)
{
	u32 fsts;
	u64 word;
	void *frcd;	/* address of the current 16-byte fault record */
	int nfr;	/* number of fault records */
	int idx;	/* current record index */
	int seen;

	spinlock_lock(&iommu->reg_lock);
	fsts = read32(iommu->reg + FSTS_REG);
	printf("VT-d fault status 0x%08X\n", fsts);
	idx = (fsts >> FSTS_REG_FRI_SHIFT) & FSTS_REG_FRI_MASK;

	nfr = cap_nfr(iommu->cap);
	for (seen = 0; seen < nfr; seen++, idx++) {
		if (idx >= nfr) {
			idx = 0;
		}
		frcd = iommu->reg + cap_fro(iommu->cap) + idx * 16;

		word = read64(frcd + 8);
		if (!(word & FR_F)) {
			break;
		}
		printf("VT-d fault record[%d] 0x%016llX", idx, word);
		word = read64(frcd);
		printf("_%016llX\n", word);
		/* write the flag back to clear this record */
		write64(frcd + 8, FR_F);
	}
	write32(iommu->reg + FSTS_REG, FSTS_MASK);
	spinlock_unlock(&iommu->reg_lock);
}

/*
 * Return the deepest page-table level (MIN_PAGE_LEVEL..MAX_PAGE_LEVEL)
 * whose bit is set in the SAGAW field of the unit's capability register,
 * or 0 if no bit in that range is set.
 */
static int
get_max_suppoted_page_level(struct iommu * iommu)
{
	u64 supported = cap_sagaw(iommu->cap);
	int level;

	for (level = MAX_PAGE_LEVEL; level >= MIN_PAGE_LEVEL; level--) {
		if (supported & (1UL << (level - MIN_PAGE_LEVEL))) {
			return level;
		}
	}
	return 0;
}

/*
 * Create and register the state for the remapping unit whose registers
 * live at physical address reg_base.
 *
 * Maps one page of registers, caches CAP/ECAP, installs an empty root
 * entry table, clears pending fault records and unmasks fault events.
 *
 * Returns the new unit, or NULL if memory or the register mapping could
 * not be obtained.  Panics if the unit reports no supported page level or
 * if the root entry table cannot be set.
 */
static struct iommu *
new_iommu(phys_t reg_base)
{
	struct iommu *iommu;
	vmmerr_t ret;

	iommu = alloc(sizeof(struct iommu));
	if (!iommu) {
		return NULL;
	}

	memset(iommu, 0, sizeof(struct iommu));

	iommu->reg_phys = reg_base;
	iommu->reg = mapmem_hphys(reg_base, PAGE_SIZE,
				  MAPMEM_WRITE | MAPMEM_UC);
	if (iommu->reg == NULL) {
		free(iommu);
		return NULL;
	}
	iommu->cap = read64(iommu->reg + CAP_REG);
	iommu->ecap = read64(iommu->reg + ECAP_REG);
	iommu->page_level = get_max_suppoted_page_level(iommu);

	printf("IOMMU: base 0x%llx, cap 0x%llx, ecap 0x%llx, page_level %d\n",
	       reg_base, iommu->cap, iommu->ecap, iommu->page_level);

	/*
	 * page_level is later used as (page_level - MIN_PAGE_LEVEL) to index
	 * io_domain.shared_ctxtbl[]; a value of 0 would index out of bounds.
	 * Stop here with a clear message instead.
	 */
	if (iommu->page_level < MIN_PAGE_LEVEL) {
		panic("IOMMU: no supported page level. cap 0x%llx\n",
		      iommu->cap);
	}

	spinlock_init(&iommu->unit_lock);
	spinlock_init(&iommu->reg_lock);

	LIST_APPEND(iommu_list, iommu);

	ret = set_root_entry_table(iommu);
	if (ret) {
		panic("IOMMU: set root entry failed\n");
	}
	clear_fault_record(iommu);
	write32(iommu->reg + FECTL_REG, 0);  /* clearing IM field */
	return iommu;
}

/*
 * PCI driver callback for the 8086:d155 device.
 *
 * Reads VTBAR, chooses the register base (VTD_MMREG_BASE when the BAR is
 * unprogrammed), sets the enable bit, verifies the write by reading back,
 * and registers two units: one at reg_base + VTD_ISOCH_OFFSET and one at
 * reg_base.  Only the first matching device is handled.
 */
static void
vtd_pci_new(struct pci_device *dev)
{
	static int vtd_pci_count = 0;
	phys_t reg_base;
	u32 vtbar, wanted;

	printf("VT-d PCI device %s found.\n", dev->name);
	if (++vtd_pci_count > 1) {
		printf("Multi VT-d PCI device is not supported.\n");
		return;
	}

	vtbar = pci_read_config32(dev, VTD_PCI_VTBAR);
	if (vtbar == 0xffffffff) {
		printf("Failed to access VTBAR register.\n");
		return;
	}

	reg_base = vtbar & VTD_PCI_VTBAR_ADDRESS_MASK;
	if (reg_base != 0) {
		printf("VT-d memory mapped registers is locate at 0x%llx\n",
		       reg_base);
		wanted = vtbar | VTD_PCI_VTBAR_ENABLE;
	} else {
		reg_base = VTD_MMREG_BASE;
		printf("Configure VT-d memory mapped registers to 0x%llx\n",
		       reg_base);
		wanted = (vtbar & ~VTD_PCI_VTBAR_ADDRESS_MASK) |
			(reg_base & VTD_PCI_VTBAR_ADDRESS_MASK)
			| VTD_PCI_VTBAR_ENABLE;
	}

	/* program the BAR and make sure the value sticks */
	pci_write_config32(dev, VTD_PCI_VTBAR, wanted);
	vtbar = pci_read_config32(dev, VTD_PCI_VTBAR);
	if (wanted != vtbar) {
		printf("Can't write VTD_PCI_VTBAR. writing 0x%X reading 0x%X\n",
		       wanted, vtbar);
		return;
	}

	new_iommu(reg_base + VTD_ISOCH_OFFSET);
	new_iommu(reg_base);
}

/*
 * PCI driver descriptor for the 8086:d155 device.  Registered by
 * iommu_drvinit() only when dmar_drhd_exists() reports no DRHD.
 */
struct pci_driver vtd_pci_driver = {
	.device = "id=8086:d155",
	.new = vtd_pci_new,
	.name = "vtd_pci",
	.longname = "VT-d PCI driver"
};

/*
 * Handler registered under the name "vtddump": dumps and clears the fault
 * records of every unit.  Both arguments are ignored.  Always returns 0.
 */
static int
vtddump_msghandler(int m, int c)
{
	struct iommu *iommu;

	LIST_FOREACH(iommu_list, iommu)
	{
		dump_fault_record(iommu);
	}

	return 0;
}

/*
 * Point every not-yet-present root entry (bus 0..255) of the unit at dom's
 * shared context table for the unit's page level, then flush the whole
 * root entry table.  Entries that are already present are reported and
 * left alone.  Panics if the domain has no shared context table for that
 * page level.
 */
static void
setup_all_root_entry_to_domain(struct iommu *iommu,
			       struct io_domain *dom)
{
	int bus_no;
	struct root_entry *root;
	phys_t ctxtbl;

	for (bus_no = 0; bus_no < 256; bus_no++) {
		root = &iommu->root_entry[bus_no];
		if (root_entry_present(*root)) {
			/*
			 * The format was "0x02%x", which printed a literal
			 * "0x02" followed by an unpadded number.
			 */
			printf("setup_all_root_entry_to_domain: "
			       "root entry is already set. 0x%02x\n",
			       bus_no);
			continue;
		}
		clear_root_entry(*root);
		ctxtbl = dom->shared_ctxtbl[iommu->page_level - MIN_PAGE_LEVEL];
		if (ctxtbl == 0x0) {
			panic("The context table does not exists. "
			      "page_level %d",
			      iommu->page_level);
		}
		set_ctp(*root, ctxtbl);
		set_root_present(*root);
	}
	flush_cacheline_pg(iommu, iommu->root_entry);
}

static vmmerr_t
init_domain(struct io_domain *dom, unsigned short int domain_id)
{
	struct iommu *iommu = NULL;
	struct context_entry *context;
	void *vaddr;
	vmmerr_t ret;
	int page_level = 0, level;
	u32 page_level_mask = 0;
	bool coherency = 1;
	int i;

	iommu = LIST_HEAD(iommu_list);
	ASSERT(iommu);

	LIST_FOREACH(iommu_list, iommu) {
		page_level_mask |= (1U << iommu->page_level);
		if (page_level < iommu->page_level) {
			page_level = iommu->page_level;
		}
		coherency = coherency && ecap_c(iommu->ecap);
	}

	dom->domain_id = domain_id;
	dom->page_level = page_level;
	dom->max_addr = io_max_addr(page_level);
	dom->coherency = coherency;
	printf("IOMMU: io-domain %d "
	       "page_level %d max_addr 0x%llx coherency %d\n",
	       dom->domain_id, dom->page_level, dom->max_addr,
	       dom->coherency);

	/* Allocate a io page directory root */
	ret = alloc_page(&vaddr, &dom->pgd);
	if (ret != 0) {
		return VMMERR_NOMEM;
	}
	memset(vaddr, 0, PAGE_SIZE);
	flush_cacheline_pg_dom(dom, vaddr);

	/* Setup a shared context table */
	for (level = MIN_PAGE_LEVEL; level <= MAX_PAGE_LEVEL; level++ ) {
		if ((page_level_mask & (1U << level)) == 0) {
			continue;
		}
		ret = alloc_page(&vaddr, &dom->shared_ctxtbl[level
			- MIN_PAGE_LEVEL]);
		if (ret != 0) {
			return VMMERR_NOMEM;
		}
		memset(vaddr, 0, PAGE_SIZE);
		context = vaddr;
		context->h.avail = SHARED_CTXTBL;
		for (i = 0; i < 256; i++) {
			clear_context_entry(context[i]);
			set_agaw(context[i], level - MIN_PAGE_LEVEL);
			set_asr(context[i], get_iopt_phys(dom, level));
			/* Set multi-level page-table */
			set_context_trans_type(context[i], 0x0);
			enable_fault_handling(context[i]);
			set_context_domid(context[i], dom->domain_id);
			set_context_present(context[i]);
		}
		flush_cacheline_pg_dom(dom, vaddr);
	}
	spinlock_init(&dom->iopt_lock);

	return 0;
}

/*
 * Allocate, zero and initialize an io domain with the given id.
 * Returns NULL if allocation or initialization fails.
 */
static struct io_domain *
create_domain(unsigned short int domain_id)
{
	struct io_domain *dom = alloc(sizeof(struct io_domain));

	if (dom == NULL) {
		return NULL;
	}

	memset(dom, 0, sizeof(*dom));
	if (init_domain(dom, domain_id) == 0) {
		return dom;
	}

	/* Init. failed. */
	free(dom);
	return NULL;
}

/*
 * Driver init hook: compute the cache-line flush stride from CPUID leaf 1
 * (EBX bits 15:8, times 8), run dmar_init(), and fall back to the PCI
 * based VT-d driver when no DRHD exists.
 */
static void
iommu_drvinit(void)
{
	unsigned int chunks = (cpuid_ebx(1) >> 8) & 0xff;

	clflush_size = chunks * 8;

	dmar_init();
	if (dmar_drhd_exists()) {
		return;
	}
	pci_register_driver(&vtd_pci_driver);
	pci_match_add_compat("driver=vtd_pci");
}

/*
 * Device init hook: create a unit for every DRHD, then, if at least one
 * unit exists, create io domain 0, point all root entries of every unit
 * at it and register the "vtddump" message handler.
 */
static void
iommu_devinit(void)
{
	struct dmar_drhd_u *drhd;
	struct iommu *iommu;

	if (dmar_drhd_exists()) {
		drhd = dmar_get_next_drhd(NULL);
		while (drhd) {
			iommu = new_iommu(drhd->address);
			if (iommu) {
				drhd->iommu = iommu;
			}
			drhd = dmar_get_next_drhd(drhd);
		}
	}

	if (!LIST_HEAD(iommu_list)) {
		/* no remapping hardware: nothing to set up */
		return;
	}

	io_domain0 = create_domain(0);
	if (io_domain0 == NULL) {
		panic("Failed to create io domain 0");
	}

	LIST_FOREACH(iommu_list, iommu)
	{
		setup_all_root_entry_to_domain(iommu, io_domain0);
	}
	msgregister ("vtddump", vtddump_msghandler);
}

/*
 * VM init hook.  VM 0 uses io domain 0 and gets the unit register pages
 * covered by mmio_do_nothing handlers; every other VM gets a domain of
 * its own.  Guest memory and assigned devices are then mapped into the
 * domain and all units are flushed.
 */
static void
iommu_initvm(void)
{
	struct io_domain *dom;
	struct iommu *iommu;

	if (!LIST_HEAD(iommu_list)) {
		return;
	}

	if (vm_get_id() != 0) {
		dom = create_domain(vm_get_id());
		if (!dom) {
			panic("Failed to create domain.");
		}
	} else {
		/*
		 * Hide iommu from OS.
		 */
		LIST_FOREACH(iommu_list, iommu)
		{
			mmio_register(iommu->reg_phys, PAGESIZE,
				      mmio_do_nothing, NULL);
		}
		dom = io_domain0;
	}

	map_all_guest_pages(dom);
	map_assigned_devices_to_domain(dom);
#ifdef DUMP_IOPT
	dump_iopt(dom);
#endif

	flush_all();
	vm_iommu_enabled();
}

/*
 * Start hook: if any unit exists, flush all caches and turn on DMA
 * remapping on every unit.
 */
static void
iommu_enable(void)
{
	struct iommu *iommu;

	if (LIST_HEAD(iommu_list)) {
		flush_all();

		LIST_FOREACH(iommu_list, iommu) {
			enable_dma_remapping(iommu);
		}
		printf("VT-d IOMMU enabled.\n");
	}
}

/*
 * Hook registrations.  NOTE(review): presumably the framework calls these
 * in the order init -> devinit -> vminit -> start; confirm against the
 * DRIVER_* macro definitions.
 */
DRIVER_INIT(iommu_drvinit);
DRIVER_DEVINIT(iommu_devinit);
DRIVER_VMINIT(iommu_initvm);
DRIVER_START(iommu_enable);
