/*
 * Copyright (c) 2014 Yuichi Watanabe
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <core/printf.h>
#include "asm.h"
#include "constants.h"
#include "vcpu.h"

/*
 * CHECK_VMCS_1(reg_name, bitmask): read VMCS field reg_name with
 * asm_vmread and print a diagnostic if any bit of bitmask is clear in it
 * (every bit of bitmask must be 1).  The local uses a distinctive name so
 * it does not shadow a caller's variables.
 */
#define CHECK_VMCS_1(reg_name, bitmask)				\
do {								\
	ulong chk_val;						\
	asm_vmread((reg_name), &chk_val);			\
	if ((chk_val & (bitmask)) != (bitmask)) {		\
		printf("%s 0x%lx (0x%lx must be 1)\n",		\
		       #reg_name, chk_val, (ulong)(bitmask));	\
	}							\
} while (0)

/*
 * CHECK_VMCS_0(reg_name, bitmask): read VMCS field reg_name with
 * asm_vmread and print a diagnostic if any bit of bitmask is set in it
 * (every bit of bitmask must be 0).  The local uses a distinctive name so
 * it does not shadow a caller's variables.
 */
#define CHECK_VMCS_0(reg_name, bitmask)				\
do {								\
	ulong chk_val;						\
	asm_vmread((reg_name), &chk_val);			\
	if ((chk_val & (bitmask)) != 0) {			\
		printf("%s 0x%lx (0x%lx must be 0)\n",		\
		       #reg_name, chk_val, (ulong)(bitmask));	\
	}							\
} while (0)

/*
 * CHECK_VMCS_VIRTUAL_8086_SEG(sel, base, limit, access): segment checks
 * for a guest entering virtual-8086 mode (26.3.1.2).  The base field must
 * equal the selector shifted left by 4, the limit must be 0xffff and the
 * access rights must be 0xf3.  Each mismatch is printed.  Locals use
 * distinctive names so they do not shadow a caller's variables.
 */
#define CHECK_VMCS_VIRTUAL_8086_SEG(sel, base, limit, access)	\
do {								\
	ulong v86_sel, v86_val;					\
	asm_vmread((sel), &v86_sel);				\
	asm_vmread((base), &v86_val);				\
	if ((v86_sel << 4) != v86_val) {			\
		printf("%s %lx, %s %lx\n",			\
		       #sel, v86_sel, #base, v86_val);		\
	}							\
	asm_vmread((limit), &v86_val);				\
	if (v86_val != 0xffff) {				\
		printf("%s %lx must be 0xffff\n",		\
		       #limit, v86_val);			\
	}							\
	asm_vmread((access), &v86_val);				\
	if (v86_val != 0xf3) {					\
		printf("%s %lx must be 0xf3\n",			\
		       #access, v86_val);			\
	}							\
} while (0)

/*
 * CHECK_VMCS_SEG_COMMON(sel, base, limit, access): access-rights checks
 * shared by all segment registers (26.3.1.2).  They are applied only
 * while the register is usable (bit 16 clear):
 *  - P (bit 7) must be 1;
 *  - reserved bits must be 0.  Bit 16 is already known to be 0 here, so
 *    the 0xffff0f00 mask effectively covers bits 11:8 and 31:17;
 *  - G (bit 15) must be 0 if any of limit bits 11:0 is 0;
 *  - G (bit 15) must be 1 if any of limit bits 31:20 is 1.
 * sel and base are accepted for symmetry with the other segment macros
 * but are not read.  Locals use distinctive names so they do not shadow
 * a caller's variables.
 */
#define CHECK_VMCS_SEG_COMMON(sel, base, limit, access)		\
do {								\
	ulong common_ar, common_limit;				\
	asm_vmread((access), &common_ar);			\
	if ((common_ar & VMCS_GUEST_ACCESS_RIGHTS_UNUSABLE) == 0) {	\
		if ((common_ar & 0x80) != 0x80) {		\
			printf("Bit 7(P) of %s 0x%lx must be 1\n",	\
			       #access, common_ar);		\
		}						\
		if ((common_ar & 0xffff0f00) != 0) {		\
			printf("Reserved fields of %s 0x%lx must be 0\n", \
			       #access, common_ar);		\
		}						\
		asm_vmread((limit), &common_limit);		\
		if ((common_limit & 0xfff) != 0xfff		\
		    && (common_ar & 0x8000) != 0) {		\
			printf("Bit 15(G) of %s 0x%lx must be 0. "	\
			       "limit 0x%lx\n",			\
			       #access, common_ar, common_limit);	\
		}						\
		if ((common_limit & 0xfff00000) != 0		\
		    && (common_ar & 0x8000) == 0) {		\
			printf("Bit 15(G) of %s 0x%lx must be 1. "	\
			       "limit 0x%lx\n",			\
			       #access, common_ar, common_limit);	\
		}						\
	}							\
} while (0)

/*
 * CHECK_VMCS_SEG(sel, base, limit, access): checks for DS/ES/FS/GS when
 * the guest is not entering virtual-8086 mode (26.3.1.2).  If the
 * register is usable:
 *  - type bit 0 (accessed) must be 1;
 *  - a code segment (type bit 3) must be readable (type bit 1);
 *  - S (bit 4) must be 1;
 *  - for types 0-11 (data or non-conforming code), the DPL (bits 6:5)
 *    cannot be less than the RPL (bits 1:0) of the selector.
 * The common P/reserved/G checks then run.
 *
 * NOTE(review): the SDM applies the DPL/RPL rule only when the
 * "unrestricted guest" control is 0; it is applied unconditionally here.
 * Locals use distinctive names so they do not shadow a caller's
 * variables.
 */
#define CHECK_VMCS_SEG(sel, base, limit, access)		\
do {								\
	ulong seg_ar, seg_sel;					\
	asm_vmread((access), &seg_ar);				\
	if ((seg_ar & VMCS_GUEST_ACCESS_RIGHTS_UNUSABLE) == 0) {	\
		if ((seg_ar & 0x1) != 0x1) {			\
			printf("Bit 0 of %s 0x%lx must be 1\n",	\
			       #access, seg_ar);		\
		}						\
		if ((seg_ar & 0x8) == 0x8 && (seg_ar & 0x2) != 0x2) {	\
			printf("If bit 3 is 1, bit 1 of %s 0x%lx must be 1\n", \
			       #access, seg_ar);		\
		}						\
		if ((seg_ar & 0x10) != 0x10) {			\
			printf("Bit 4(S) of %s 0x%lx must be 1\n",	\
			       #access, seg_ar);		\
		}						\
		asm_vmread((sel), &seg_sel);			\
		if ((seg_ar & 0xf) <= 11			\
		    && ((seg_ar & 0x60) >> 5) < (seg_sel & 0x3)) {	\
			printf("DPL of %s 0x%lx can't be less than "	\
			       "RPL of %s 0x%lx\n",		\
			       #access, seg_ar, #sel, seg_sel);	\
		}						\
	}							\
	CHECK_VMCS_SEG_COMMON(sel, base, limit, access);	\
} while (0)

static void
vt_check_cr_dr_msr(void)
{
	ulong reg_val;

	/*
	 * 26.3.1.1 Checks on Guest Control Registers, Debug Registers,
	 and MSRs
	*/
	CHECK_VMCS_1(VMCS_GUEST_CR0, CR0_PE_BIT | CR0_NE_BIT | CR0_PG_BIT);
	CHECK_VMCS_1(VMCS_GUEST_CR4, CR4_VMXE_BIT);

	/* TODO: check IA32_DEBUGCTL MSR */

	asm_vmread(VMCS_VMENTRY_CTL, &reg_val);
	if (reg_val & VMCS_VMENTRY_CTL_64_GUEST_BIT) {
		CHECK_VMCS_1(VMCS_GUEST_CR0, CR0_PE_BIT);
		CHECK_VMCS_1(VMCS_GUEST_CR4, CR4_PAE_BIT);
	} else {
		CHECK_VMCS_0(VMCS_GUEST_CR4, CR4_PCIDE_BIT);
	}

	asm_vmread(VMCS_VMEXIT_CTL, &reg_val);
	if (reg_val & VMCS_VMEXIT_CTL_LOAD_IA32_PAT_BIT) {
		int i;
		u8 pat_encoding;
		asm_vmread(VMCS_GUEST_IA32_PAT, &reg_val);
		for (i = 0; i < MSR_DATA_PAT_COUNT; i++) {
			pat_encoding = MSR_PAT_FIELD(reg_val, i);
			switch (pat_encoding) {
			case 0:
			case 1:
			case 4:
			case 5:
			case 6:
			case 7:
				break;
			default:
				printf("Unsupported pat encoding "
				       "0x%lx %i 0x%x\n",
				       reg_val, i, pat_encoding);
			}
		}
	}
}

static void
vt_check_seg(void)
{
	ulong reg_val, reg_val2;

	/*
	 * 26.3.1.2 Checks on Guest Segment Registers
	 */

	/* TR */
	CHECK_VMCS_0(VMCS_GUEST_TR_SEL, 0x4); /* TI (bit 2) */

	/* LDTR */
	asm_vmread (VMCS_GUEST_LDTR_ACCESS_RIGHTS, &reg_val);
	if ((reg_val & VMCS_GUEST_ACCESS_RIGHTS_UNUSABLE) == 0) {
		CHECK_VMCS_0(VMCS_GUEST_LDTR_SEL, 0x4); /* TI (bit 2) */
	}

	asm_vmread (VMCS_GUEST_RFLAGS, &reg_val);
	if (reg_val & RFLAGS_VM_BIT) {
		/* virtual 8086 */
		CHECK_VMCS_VIRTUAL_8086_SEG(VMCS_GUEST_CS_SEL,
					    VMCS_GUEST_CS_BASE,
					    VMCS_GUEST_CS_LIMIT,
					    VMCS_GUEST_CS_ACCESS_RIGHTS);
		CHECK_VMCS_VIRTUAL_8086_SEG(VMCS_GUEST_SS_SEL,
					    VMCS_GUEST_SS_BASE,
					    VMCS_GUEST_SS_LIMIT,
					    VMCS_GUEST_SS_ACCESS_RIGHTS);
		CHECK_VMCS_VIRTUAL_8086_SEG(VMCS_GUEST_DS_SEL,
					    VMCS_GUEST_DS_BASE,
					    VMCS_GUEST_DS_LIMIT,
					    VMCS_GUEST_DS_ACCESS_RIGHTS);
		CHECK_VMCS_VIRTUAL_8086_SEG(VMCS_GUEST_ES_SEL,
					    VMCS_GUEST_ES_BASE,
					    VMCS_GUEST_ES_LIMIT,
					    VMCS_GUEST_ES_ACCESS_RIGHTS);
		CHECK_VMCS_VIRTUAL_8086_SEG(VMCS_GUEST_FS_SEL,
					    VMCS_GUEST_FS_BASE,
					    VMCS_GUEST_FS_LIMIT,
					    VMCS_GUEST_FS_ACCESS_RIGHTS);
		CHECK_VMCS_VIRTUAL_8086_SEG(VMCS_GUEST_GS_SEL,
					    VMCS_GUEST_GS_BASE,
					    VMCS_GUEST_GS_LIMIT,
					    VMCS_GUEST_GS_ACCESS_RIGHTS);
	} else {
		asm_vmread(VMCS_GUEST_SS_SEL, &reg_val);
		asm_vmread(VMCS_GUEST_CS_SEL, &reg_val2);
		if ((reg_val & 0x3) != (reg_val2 & 0x3)) {
			printf("RPL of VMCS_GUEST_SS_SEL 0x%lx "
			       "and VMCS_GUEST_CS_SEL 0x%lx "
			       "are not the same.\n",
			       reg_val, reg_val2);

		}

		/* CS */
		asm_vmread(VMCS_GUEST_CS_ACCESS_RIGHTS, &reg_val);
		switch(reg_val & 0xf) {
		case 9:
		case 11:
			asm_vmread(VMCS_GUEST_SS_ACCESS_RIGHTS, &reg_val2);
			if ((reg_val & 0x60) != (reg_val2 & 0x60)) {
				printf("DPL of CS and SS are not the same. "
				       "0x%lx, 0x%lx\n", reg_val, reg_val2);
			}
			break;
		case 13:
		case 15:
			asm_vmread(VMCS_GUEST_SS_ACCESS_RIGHTS, &reg_val2);
			if ((reg_val & 0x60) > (reg_val2 & 0x60)) {
				printf("DPL of CS can't be grater than "
				       "DPL of SS%lx, 0x%lx\n",
				       reg_val, reg_val2);
			}
			break;
		default:
			printf("VMCS_GUEST_CS_ACCESS_RIGHTS 0x%lx (type)\n",
			       reg_val);
			break;
		}
		if ((reg_val & 0x10) != 0x10) {
			printf("VMCS_GUEST_CS_ACCESS_RIGHTS 0x%lx (S)\n",
			       reg_val);
		}
		CHECK_VMCS_SEG_COMMON(VMCS_GUEST_CS_SEL,
				      VMCS_GUEST_CS_BASE,
				      VMCS_GUEST_CS_LIMIT,
				      VMCS_GUEST_CS_ACCESS_RIGHTS);

		/* SS */
		asm_vmread(VMCS_GUEST_SS_ACCESS_RIGHTS, &reg_val);
		if((reg_val & VMCS_GUEST_ACCESS_RIGHTS_UNUSABLE) == 0) {
			switch(reg_val & 0xf) {
			case 3:
			case 7:
				break;
			default:
				printf("VMCS_GUEST_SS_ACCESS_RIGHTS 0x%lx "
				       "(type)\n",
				       reg_val);
				break;
			}
			if ((reg_val & 0x10) != 0x10) {
				printf("VMCS_GUEST_SS_ACCESS_RIGHTS "
				       "0x%lx (S)\n",
				       reg_val);
			}
		}
		asm_vmread(VMCS_GUEST_SS_SEL, &reg_val2);
		if (((reg_val & 0x60) >> 5) != (reg_val2 & 0x03)) {
			printf("DPL of VMCS_GUEST_SS_ACCESS_RIGHTS 0x%lx "
			       "and RPL of VMCS_GUEST_SS_SEL 0x%lx "
			       "are not the same\n",
			       reg_val, reg_val2);

		}
		CHECK_VMCS_SEG_COMMON(VMCS_GUEST_SS_SEL,
				      VMCS_GUEST_SS_BASE,
				      VMCS_GUEST_SS_LIMIT,
				      VMCS_GUEST_SS_ACCESS_RIGHTS);

		CHECK_VMCS_SEG(VMCS_GUEST_DS_SEL,
			       VMCS_GUEST_DS_BASE,
			       VMCS_GUEST_DS_LIMIT,
			       VMCS_GUEST_DS_ACCESS_RIGHTS);
		CHECK_VMCS_SEG(VMCS_GUEST_ES_SEL,
			       VMCS_GUEST_ES_BASE,
			       VMCS_GUEST_ES_LIMIT,
			       VMCS_GUEST_ES_ACCESS_RIGHTS);
		CHECK_VMCS_SEG(VMCS_GUEST_FS_SEL,
			       VMCS_GUEST_FS_BASE,
			       VMCS_GUEST_FS_LIMIT,
			       VMCS_GUEST_FS_ACCESS_RIGHTS);
		CHECK_VMCS_SEG(VMCS_GUEST_GS_SEL,
			       VMCS_GUEST_GS_BASE,
			       VMCS_GUEST_GS_LIMIT,
			       VMCS_GUEST_GS_ACCESS_RIGHTS);
	}
	/* TR */
	asm_vmread(VMCS_GUEST_TR_ACCESS_RIGHTS, &reg_val);
	if((reg_val & VMCS_GUEST_ACCESS_RIGHTS_UNUSABLE) == 0) {
		if ((reg_val & 0x10) != 0) {
			printf("VMCS_GUEST_TR_ACCESS_RIGHTS "
			       "0x%lx (S)\n",
			       reg_val);
		}
	}
	CHECK_VMCS_SEG_COMMON(VMCS_GUEST_TR_SEL,
			      VMCS_GUEST_TR_BASE,
			      VMCS_GUEST_TR_LIMIT,
			      VMCS_GUEST_TR_ACCESS_RIGHTS);
	/* LDTR */
	asm_vmread(VMCS_GUEST_LDTR_ACCESS_RIGHTS, &reg_val);
	if((reg_val & VMCS_GUEST_ACCESS_RIGHTS_UNUSABLE) == 0) {
		switch(reg_val & 0xf) {
		case 2:
			break;
		default:
			printf("VMCS_GUEST_LDTR_ACCESS_RIGHTS 0x%lx "
			       "(type)\n",
			       reg_val);
			break;
		}
		if ((reg_val & 0x10) != 0) {
			printf("VMCS_GUEST_LDTR_ACCESS_RIGHTS "
			       "0x%lx (S)\n",
			       reg_val);
		}
	}
	CHECK_VMCS_SEG_COMMON(VMCS_GUEST_LDTR_SEL,
			      VMCS_GUEST_LDTR_BASE,
			      VMCS_GUEST_LDTR_LIMIT,
			      VMCS_GUEST_LDTR_ACCESS_RIGHTS);

	asm_vmread(VMCS_VMENTRY_CTL, &reg_val);
	if (reg_val & VMCS_VMENTRY_CTL_64_GUEST_BIT) {
		/* IA-32e mode */

		/* CS */
		asm_vmread(VMCS_GUEST_CS_ACCESS_RIGHTS, &reg_val);
		if ((reg_val & 0x2000) == 0x2000
		    && (reg_val & 0x4000) != 0) {
			printf("If bit 13(L) is 1, bit 14(D/B) of "
			       "VMCS_GUEST_CS_ACCESS_RIGHTS 0x%lx (D/B) "
			       "must be 0.\n",
			       reg_val);
		}

		/* TR */
		asm_vmread(VMCS_GUEST_TR_ACCESS_RIGHTS, &reg_val);
		if((reg_val & VMCS_GUEST_ACCESS_RIGHTS_UNUSABLE) == 0) {
			switch(reg_val & 0xf) {
			case 3:
			case 11:
				break;
			default:
				printf("VMCS_GUEST_TR_ACCESS_RIGHTS 0x%lx "
				       "(type)\n",
				       reg_val);
				break;
			}
		}
	} else {
		/* TR */
		asm_vmread(VMCS_GUEST_TR_ACCESS_RIGHTS, &reg_val);
		if((reg_val & VMCS_GUEST_ACCESS_RIGHTS_UNUSABLE) == 0) {
			switch(reg_val & 0xf) {
			case 11:
				break;
			default:
				printf("VMCS_GUEST_TR_ACCESS_RIGHTS 0x%lx "
				       "(type)\n",
				       reg_val);
				break;
			}
		}
	}
}

void
vt_check_rip_rflags(void)
{
	ulong reg_val;

	/*
	 * 26.3.1.4 Checks on Guest RIP and RFLAGS
	 */

	asm_vmread(VMCS_VMENTRY_CTL, &reg_val);
	if (reg_val & VMCS_VMENTRY_CTL_64_GUEST_BIT) {
		asm_vmread(VMCS_GUEST_CS_ACCESS_RIGHTS, &reg_val);
		if ((reg_val & ACCESS_RIGHTS_L_BIT) == 0) {
			CHECK_VMCS_0(VMCS_GUEST_RIP, 0xffffffff00000000);
		}
	} else {
		CHECK_VMCS_0(VMCS_GUEST_RIP, 0xffffffff00000000);
	}

	CHECK_VMCS_1(VMCS_GUEST_RFLAGS, RFLAGS_ALWAYS1_BIT);
	CHECK_VMCS_0(VMCS_GUEST_RFLAGS, RFLAGS_RESERVED_MASK);
}

void
vt_vm_entry_check(void)
{
	vt_check_cr_dr_msr();
	vt_check_seg();
	vt_check_rip_rflags();
}
