// $Id: pe.cpp,v 1.4 2003/01/31 07:22:18 yuya Exp $

////////////////////////////////////////////////////////////////////////////////

#include "exerb.h"

////////////////////////////////////////////////////////////////////////////////

// Internal helpers.  The PE image is handled as a raw file image in memory:
// "base_address" is the start of that buffer, and a section's "delta"
// (VirtualAddress - PointerToRawData) converts between RVAs and file offsets.
static PIMAGE_IMPORT_DESCRIPTOR ExGetFirstImportDescriptor(DWORD base_address, DWORD *import_table_delta);
static bool ExGetImportTable(PIMAGE_NT_HEADERS nt_header, DWORD *address, DWORD *delta);
static void ExReplaceImportDllName(DWORD base_address, DWORD import_table_delta, PDWORD name_pool_address, PDWORD name_pool_size, PIMAGE_IMPORT_DESCRIPTOR first_descriptor, char* src, char* dest);
static void ExReplaceImportFunctionName(DWORD offset_of_name, PIMAGE_IMPORT_DESCRIPTOR first_descriptor, const char* target_dll_name, const char* source, const char* destination);
static PIMAGE_NT_HEADERS ExGetNtHeader(PIMAGE_DOS_HEADER dos_header);
static bool  ExGetSectionUnusedArea(PIMAGE_NT_HEADERS nt_header, char *section_name, PDWORD address, PDWORD size);
static PIMAGE_SECTION_HEADER ExGetEnclosingSectionHeader(PIMAGE_NT_HEADERS nt_header, DWORD rva);
static PIMAGE_SECTION_HEADER ExFindSection(PIMAGE_NT_HEADERS nt_header, char *section_name);

////////////////////////////////////////////////////////////////////////////////

// Module name to substitute for imports of "phi.so"; defined elsewhere.
// ExReplaceImportTable only applies the substitution when it is non-empty.
extern char g_phi_so_filename[MAX_PATH];

////////////////////////////////////////////////////////////////////////////////

bool
ExReplaceImportTable(void *buffer)
{
	const DWORD base_address = (DWORD)buffer;

	DWORD import_table_delta = 0;
	const PIMAGE_IMPORT_DESCRIPTOR descriptor = ::ExGetFirstImportDescriptor(base_address, &import_table_delta);
	if ( !descriptor ) return false;

	PIMAGE_NT_HEADERS nt_header = ::ExGetNtHeader((PIMAGE_DOS_HEADER)base_address);

	DWORD name_pool_address = 0;
	DWORD name_pool_size    = 0;
	::ExGetSectionUnusedArea(nt_header, ".idata", &name_pool_address, &name_pool_size);

	char self_filename[MAX_PATH] = "";
	::ExGetSelfFileName(self_filename, sizeof(self_filename));

#ifdef RUBY18
	::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "msvcrt-ruby17.dll",  self_filename);
	::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "msvcrt-ruby18.dll",  self_filename);
	::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "cygwin-ruby17.dll",  self_filename);
	::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "cygwin-ruby18.dll",  self_filename);
#else
	::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "mswin32-ruby16.dll", self_filename);
	::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "mingw32-ruby16.dll", self_filename);
	::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "cygwin-ruby16.dll",  self_filename);
#endif
	::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "ruby.exe",           self_filename);

	if ( ::strlen(g_phi_so_filename) > 0 ) {
		::ExReplaceImportDllName(base_address, import_table_delta, &name_pool_address, &name_pool_size, descriptor, "phi.so", g_phi_so_filename);
	}

	const DWORD offset_of_name = base_address - import_table_delta;
	::ExReplaceImportFunctionName(offset_of_name, descriptor, self_filename, "rb_require",   "ex_require");
	::ExReplaceImportFunctionName(offset_of_name, descriptor, self_filename, "rb_f_require", "ex_f_require");

	return true;
}

////////////////////////////////////////////////////////////////////////////////

// Returns a pointer to the first IMAGE_IMPORT_DESCRIPTOR of the file-layout
// image starting at `base_address`, or NULL when the import directory is not
// inside any section.  *import_table_delta receives the enclosing section's
// VirtualAddress - PointerToRawData (0 on failure).
static PIMAGE_IMPORT_DESCRIPTOR
ExGetFirstImportDescriptor(DWORD base_address, DWORD *import_table_delta)
{
	DWORD file_offset = 0;
	const bool found = ::ExGetImportTable(::ExGetNtHeader((PIMAGE_DOS_HEADER)base_address), &file_offset, import_table_delta);

	return found ? (PIMAGE_IMPORT_DESCRIPTOR)(base_address + file_offset) : NULL;
}

// Locates the import directory of the image described by `nt_header`.
//
// *delta   : VirtualAddress - PointerToRawData of the section containing the
//            import directory (0 if no section contains it).
// *address : file offset of the import directory (its RVA minus *delta), or
//            0 on failure.
// returns  : true when an enclosing section was found.
static bool
ExGetImportTable(PIMAGE_NT_HEADERS nt_header, DWORD *address, DWORD *delta)
{
	const IMAGE_DATA_DIRECTORY &import_directory = nt_header->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT];
	const DWORD rva = import_directory.VirtualAddress;

	const PIMAGE_SECTION_HEADER section = ::ExGetEnclosingSectionHeader(nt_header, rva);
	const DWORD section_delta = section ? (section->VirtualAddress - section->PointerToRawData) : 0;

	*delta   = section_delta;
	*address = section ? (rva - section_delta) : 0;

	return section != NULL;
}

// Makes every import descriptor whose DLL name equals `src`
// (case-insensitively) refer to `dest` instead.
//
// If `dest` is no longer than the current name, it overwrites it in place.
// Otherwise `dest` (with its NUL) is placed in the name pool: the
// *name_pool_size free bytes ending at file offset *name_pool_address.  The
// descriptor's Name RVA is redirected there and both pool values shrink
// accordingly.  Raises LoadError when neither option has room.
//
// base_address       : start of the file-layout image buffer.
// import_table_delta : RVA minus file offset for the section holding the
//                      import table.
static void
ExReplaceImportDllName(DWORD base_address, DWORD import_table_delta, PDWORD name_pool_address, PDWORD name_pool_size, PIMAGE_IMPORT_DESCRIPTOR first_descriptor, char* src, char* dest)
{
	DEBUGMSG2("ExReplaceImportDllName(..., '%s', '%s')\n", src, dest);

	// Adding this to an RVA yields a pointer into the buffer.
	const DWORD offset_of_name = base_address - import_table_delta;
	const DWORD length_of_dest = ::strlen(dest);
	const DWORD bytes_needed   = length_of_dest + 1; // name plus terminating NUL

	// The descriptor array ends with an entry whose Name is zero.
	for ( PIMAGE_IMPORT_DESCRIPTOR descriptor = first_descriptor; descriptor->Name; descriptor++ ) {
		char *dll_name = (char*)(offset_of_name + descriptor->Name);

		if ( ::stricmp(dll_name, src) != 0 ) {
			continue;
		}

		if ( length_of_dest <= ::strlen(dll_name) ) {
			::strcpy(dll_name, dest);
		} else if ( bytes_needed <= *name_pool_size ) {
			const DWORD address = *name_pool_address - bytes_needed;
			descriptor->Name = address + import_table_delta;

			// Copy the terminator as well; the pool's slack bytes are not
			// guaranteed to be zero.
			::memcpy((void*)(base_address + address), dest, bytes_needed);

			*name_pool_address -= bytes_needed;
			*name_pool_size    -= bytes_needed;
		} else {
			::rb_raise(rb_eLoadError, "Fail to modify the import table. exe/dll file name is too long.");
		}
	}
}

// Within every import descriptor for `target_dll_name` (case-insensitive),
// renames imports-by-name called `source` to `destination`, in place.
// Imports by ordinal are left alone.  The new name overwrites the old string
// in the image, so a `destination` longer than the existing name raises
// LoadError instead of overflowing into the neighbouring entry.
//
// offset_of_name : value to add to an RVA to get a pointer into the image
//                  buffer (base address minus the import section's delta).
static void
ExReplaceImportFunctionName(DWORD offset_of_name, PIMAGE_IMPORT_DESCRIPTOR first_descriptor, const char* target_dll_name, const char* source, const char* destination)
{
	const size_t length_of_destination = ::strlen(destination);

	for ( PIMAGE_IMPORT_DESCRIPTOR pDescriptor = first_descriptor; pDescriptor->Name; pDescriptor++ ) {
		const char *dll_name = (char*)(offset_of_name + pDescriptor->Name);

		if ( ::stricmp(dll_name, target_dll_name) != 0 ) {
			continue;
		}

		// Walk the lookup table (Characteristics/OriginalFirstThunk); fall
		// back to FirstThunk when it is zero.  The explicit cast keeps this
		// compiling with headers that declare FirstThunk as a pointer.
		DWORD thunk_rva = (DWORD)pDescriptor->Characteristics;
		if ( !thunk_rva ) {
			thunk_rva = (DWORD)pDescriptor->FirstThunk;
		}
		if ( !thunk_rva ) {
			continue;
		}

		// The thunk array ends with a zero entry.
		for ( PIMAGE_THUNK_DATA thunk = (PIMAGE_THUNK_DATA)(thunk_rva + offset_of_name); thunk->u1.AddressOfData; thunk++ ) {
			if ( thunk->u1.Ordinal & IMAGE_ORDINAL_FLAG ) {
				continue; // imported by ordinal: there is no name to rewrite
			}

			const PIMAGE_IMPORT_BY_NAME import_by_name = (PIMAGE_IMPORT_BY_NAME)((DWORD)(thunk->u1.AddressOfData) + offset_of_name);
			const LPSTR function_name = (LPSTR)import_by_name->Name;
			if ( ::strcmp(function_name, source) != 0 ) {
				continue;
			}

			if ( length_of_destination > ::strlen(function_name) ) {
				::rb_raise(rb_eLoadError, "Fail to modify the import table. function name is too long.");
			}
			::strcpy(function_name, destination);
		}
	}
}

// Returns the NT headers of the image whose DOS header is `dos_header`.
// They start e_lfanew bytes past the DOS header.
static PIMAGE_NT_HEADERS
ExGetNtHeader(PIMAGE_DOS_HEADER dos_header)
{
	const DWORD image_start = (DWORD)dos_header;
	const DWORD nt_offset   = dos_header->e_lfanew;

	return (PIMAGE_NT_HEADERS)(image_start + nt_offset);
}

// Reports the slack at the end of the raw data of the section named
// `section_name`: the bytes between its used size (Misc.VirtualSize) and its
// on-disk size (SizeOfRawData).
//
// *address : file offset just past the end of the section's raw data.
// *size    : number of slack bytes (0 if none can be determined safely).
// returns  : false, with both outputs zeroed, if no such section exists.
static bool
ExGetSectionUnusedArea(PIMAGE_NT_HEADERS nt_header, char *section_name, PDWORD address, PDWORD size)
{
	const PIMAGE_SECTION_HEADER section = ::ExFindSection(nt_header, section_name);
	if ( !section ) {
		*address = 0;
		*size    = 0;
		return false;
	}

	const DWORD raw_size  = section->SizeOfRawData;
	const DWORD used_size = section->Misc.VirtualSize;

	*address = section->PointerToRawData + raw_size;

	// When VirtualSize >= SizeOfRawData the raw data is fully in use, and the
	// unsigned subtraction would wrap to a huge value.  A VirtualSize of 0
	// says nothing about how much is used.  In both cases report no free
	// space rather than let callers overwrite live section data.
	*size = ( used_size != 0 && used_size < raw_size ) ? (raw_size - used_size) : 0;

	return true;
}

// Returns the header of the section whose virtual range contains `rva`, or
// NULL if none does.  A section whose Misc.VirtualSize is 0 is treated as
// spanning SizeOfRawData bytes, because some linkers leave VirtualSize unset.
static PIMAGE_SECTION_HEADER
ExGetEnclosingSectionHeader(PIMAGE_NT_HEADERS nt_header, DWORD rva)
{
	PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(nt_header);

	for ( int i = 0; i < nt_header->FileHeader.NumberOfSections; i++, section++ ) {
		DWORD size = section->Misc.VirtualSize;
		if ( size == 0 ) {
			size = section->SizeOfRawData;
		}

		// Test the offset into the section rather than comparing against
		// VirtualAddress + size, which could wrap around in 32 bits.
		if ( (rva >= section->VirtualAddress) && ((rva - section->VirtualAddress) < size) ) {
			return section;
		}
	}

	return NULL;
}

// Returns the header of the first section whose name matches `section_name`
// case-insensitively, comparing at most the 8 bytes of the Name field, or
// NULL if no section matches.
static PIMAGE_SECTION_HEADER
ExFindSection(PIMAGE_NT_HEADERS nt_header, char *section_name)
{
	const int section_count = nt_header->FileHeader.NumberOfSections;
	PIMAGE_SECTION_HEADER first_section = IMAGE_FIRST_SECTION(nt_header);

	int index = 0;
	while ( index < section_count ) {
		PIMAGE_SECTION_HEADER candidate = first_section + index;
		if ( ::strnicmp(section_name, (char*)candidate->Name, 8) == 0 ) {
			return candidate;
		}
		++index;
	}

	return NULL;
}

////////////////////////////////////////////////////////////////////////////////
