
#include "kernel.h"
#include "screen.h"
#include "memmgr.h"

// Global memory manager instance; the operator new / operator delete
// definitions in this file forward their requests to it.
MemMgr MemoryManager;

// Builds the manager with every Region descriptor parked on the
// "not assigned" list, ready to be handed out by init() / alloc().
//
// NOTE(review): `allocs` is not named in the initializer list and therefore
// relies on default initialization - confirm that RegionList's constructor
// (or static zero-initialization of the global instance) clears top/bottom.
MemMgr::MemMgr(void) : notassigns(), frees() {
    uint32 index = 0;
    while (index < REGION_NUM) {
        notassigns.add(regions[index]);
        ++index;
    }
}

// Initializes the manager from the Multiboot information block.
//
// Records the kernel load address (`_start`) and the amount of upper memory
// reported by the boot loader, works out how many bytes counted from
// `_start` are already occupied (kernel image plus boot modules, if any)
// and publishes the remainder as a single region on the free list.
//
// mbi: Multiboot information structure handed over by the boot loader.
void MemMgr::init(multiboot_info_t* mbi) {
    extern int _start;
    extern void* _end;

    start = uint32(&_start);
    size  = mbi->mem_upper * 1024;  // mem_upper is reported in KiB

    // By default the occupied area ends at the linker-provided `_end` symbol.
    size_t kernel_size = uint32(&_end) - start;

    // The module table may be flagged valid while holding zero entries; only
    // index it when there is at least one module, otherwise
    // `mods_count - 1` underflows and reads far outside the table.
    if (((mbi->flags & MBI_FLAG_MODS_VALID) != 0) && (mbi->mods_count > 0)) {
        module_t* mods = (module_t*)mbi->mods_addr;
        // NOTE(review): assumes the last table entry is the module loaded at
        // the highest address - confirm against the boot loader's behavior.
        kernel_size = mods[mbi->mods_count - 1].mod_end - start;
    }

    // create free list
    // Skip this when the occupied area is not smaller than the reported
    // memory: `size - kernel_size` would wrap around and publish a bogus,
    // enormous free region.
    // NOTE(review): the arithmetic treats `start` as the base of the area
    // measured by mem_upper - confirm the kernel is linked at that base.
    if (kernel_size < size) {
        Region* free_region = notassigns.get(0);
        if (free_region != 0) {
            notassigns.remove(*free_region);
            free_region->start = start + kernel_size;
            free_region->size  = size - kernel_size;
            frees.add(*free_region);
        }
    }
}

void* MemMgr::alloc(size_t size) {
    void* addr = 0;

    if (size != 0) {
        IntMgr::disable();

        size = getAllocSize(size);
        Region* free_region  = frees.searchBySize(size);
        Region* alloc_region = notassigns.get(0);

        if ((free_region != 0) && (alloc_region != 0)) {
            notassigns.remove(*alloc_region);
            alloc_region->start = free_region->start;
            alloc_region->size  = size;
            allocs.add(*alloc_region);
            addr = (void*)(alloc_region->start);

            free_region->start += size;
            free_region->size  -= size;
            if (free_region->size == 0) {
                frees.remove(*free_region);
                notassigns.add(*free_region);
            }
        }

        IntMgr::enable();
    }

    return addr;
}

void MemMgr::free(void* addr) {
    if (addr != 0) {
        IntMgr::disable();

        Region* alloc_region = allocs.searchByAddr(addr);
        if (alloc_region != 0) {
            allocs.remove(*alloc_region);

            // merge free regions, if you can.
            Region* merge_region;
            while ((merge_region = frees.merge(*alloc_region)) != 0) {
                notassigns.add(*alloc_region);
                alloc_region = merge_region;
                frees.remove(*alloc_region);
            }
            frees.add(*alloc_region);
        }

        IntMgr::enable();
    }
}

// Appends `region` to the end (bottom) of the list.
void RegionList::add(Region& region) {
    Region* const old_tail = bottom;

    region.prev = old_tail;
    region.next = 0;

    if (old_tail != 0) {
        old_tail->next = &region;
    } else {
        // List was empty: the new node is also the head.
        top = &region;
    }
    bottom = &region;
}

// Unlinks `region` from the list and clears its prev/next links.
// The node's own links are used to find its neighbors.
void RegionList::remove(Region& region) {
    Region* const before = region.prev;
    Region* const after  = region.next;

    // Fix up the forward link (or the head when `region` was first).
    if (before != 0) {
        before->next = after;
    } else {
        top = after;
    }

    // Fix up the backward link (or the tail when `region` was last).
    if (after != 0) {
        after->prev = before;
    } else {
        bottom = before;
    }

    region.next = 0;
    region.prev = 0;
}

// Folds `region` into the first list entry that is directly adjacent to it.
//
// Only the matching list entry is modified; `region` itself is left
// untouched and is not linked into the list. Returns the enlarged entry, or
// 0 when no entry borders `region`.
Region* RegionList::merge(Region& region) {
    Region* entry = top;
    while (entry != 0) {
        if (entry->start == region.start + region.size) {
            // `region` ends where `entry` begins: extend `entry` downwards.
            entry->start = region.start;
            entry->size += region.size;
            return entry;
        }
        if (entry->start + entry->size == region.start) {
            // `entry` ends where `region` begins: extend `entry` upwards.
            entry->size += region.size;
            return entry;
        }
        entry = entry->next;
    }

    return 0;
}

// Returns the i-th entry counted from the top of the list (0 = first),
// or 0 when the list has `i` or fewer entries.
Region* RegionList::get(uint32 i) {
    Region* cursor = top;
    uint32 steps_left = i;
    while ((steps_left > 0) && (cursor != 0)) {
        cursor = cursor->next;
        --steps_left;
    }
    return cursor;
}

// First-fit lookup: returns the first entry, walking from the top, whose
// size is at least `size`, or 0 when no entry qualifies.
Region* RegionList::searchBySize(size_t size) {
    Region* candidate = top;
    while ((candidate != 0) && (candidate->size < size)) {
        candidate = candidate->next;
    }
    return candidate;
}

// Returns the first entry whose start address equals `addr`, or 0 when no
// entry starts there. Addresses inside an entry do not match.
Region* RegionList::searchByAddr(void* addr) {
    const uint32 wanted = uint32(addr);
    for (Region* entry = top; entry != 0; entry = entry->next) {
        if (entry->start == wanted) {
            return entry;
        }
    }
    return 0;
}

void RegionList::show(void) {
    Console.printf("list :: top(%x), bottom(%x)\n", uint32(top), uint32(bottom));
    Region* region = top;
    while (region != 0) {
        Console.printf("  reg[%x] :: start(%x), size(%x)\n", 
                       uint32(region), region->start, region->size);
        region = region->next;
    }
}

// Global operator new / delete, forwarded to MemoryManager.

// Global operator new: obtains `size` bytes from MemoryManager.
// Returns whatever MemoryManager.alloc() returns, which is 0 on failure.
void* operator new(size_t size) {
    void* const block = MemoryManager.alloc(size);
    return block;
}

void operator delete(void* addr) {
    MemoryManager.free(addr);
}
