/******************************************************************************
 * Copyright (C) 2006 Tetsuya Kimata <kimata@acapulco.dyndns.org>
 *
 * All rights reserved.
 *
 * This software is provided 'as-is', without any express or implied
 * warranty.  In no event will the authors be held liable for any
 * damages arising from the use of this software.
 *
 * Permission is granted to anyone to use this software for any
 * purpose, including commercial applications, and to alter it and
 * redistribute it freely, subject to the following restrictions:
 *
 * 1. The origin of this software must not be misrepresented; you must
 *    not claim that you wrote the original software. If you use this
 *    software in a product, an acknowledgment in the product
 *    documentation would be appreciated but is not required.
 *
 * 2. Altered source versions must be plainly marked as such, and must
 *    not be misrepresented as being the original software.
 *
 * 3. This notice may not be removed or altered from any source
 *    distribution.
 *
 * $Id: ReadWriteAttacker.cpp 1786 2006-10-11 15:24:40Z svn $
 *****************************************************************************/

#include "Environment.h"

#include "apr_shm.h"
#include "apr_atomic.h"
#include "apr_thread_mutex.h"
#include "apr_thread_proc.h"
#include "apr_thread_cond.h"

#include "BBSThreadManager.h"
#include "BBSThreadList.h"
#include "BBSThreadListReader.h"
#include "BBSCommentIterator.h"
#include "TestRunner.h"
#include "CleanPointer.h"
#include "Auxiliary.h"
#include "Message.h"
#include "Macro.h"
#include "SourceInfo.h"

SOURCE_INFO_ADD("$Id: ReadWriteAttacker.cpp 1786 2006-10-11 15:24:40Z svn $");

// Cast a worker function to apr_thread_start_t so it can be handed to
// apr_thread_create(). The workers in this file are declared with a void
// return and a thread_context_t * parameter rather than the exact
// apr_thread_start_t signature.
// NOTE(review): calling a function through a pointer of a different type is
// undefined behavior in C++; this relies on the platform ABI tolerating it.
#define AS_THREAD_START(pointer)    reinterpret_cast<apr_thread_start_t>(pointer)

// Requested number of entries of the view-sorted thread list that each read
// pass covers; passed by address to get_thread_info_sort_by_view(), which may
// adjust it (TODO confirm).
static const apr_size_t READ_THREAD_COUNT   = 10;
// Number of BBS threads created for the write attack; run_attack() also uses
// it as the number of write worker threads.
static const apr_size_t WRITE_THREAD_COUNT  = 3;
// Number of read passes each read worker performs.
static const apr_size_t RUN_ITER_COUNT      = 5;

// State shared (by pointer) between the main thread and every read and write
// worker thread of the attack.
typedef struct thread_context {
    apr_pool_t *pool;                   // pool passed to the read path for its allocations
    BBSThreadManager *thread_manager;   // thread manager under test
    BBSThreadList *thread_list;         // stored by init_context(); not read by the code in this file

    apr_size_t *write_ids;              // ids of the threads the writers add comments to

    apr_size_t *ids_sort_by_view;       // thread ids returned by get_thread_info_sort_by_view()
    apr_size_t *indexes_sort_by_view;   // matching thread indexes from the same call

    apr_size_t write_worker_count;      // number of write worker threads
    apr_size_t write_thread_count;      // number of entries in write_ids
    apr_thread_mutex_t *write_finish_mutex; // NOTE(review): never created or used in this file
    apr_atomic_t write_finish_count;    // write workers that have reported completion

    apr_size_t read_worker_count;       // number of read worker threads
    apr_size_t read_thread_count;       // used as the length of the *_sort_by_view arrays
    apr_atomic_t read_finish_count;     // read workers that have reported completion

    apr_thread_mutex_t *finish_mutex;   // mutex paired with finish_cond; locked by the main thread in init_context()
    apr_thread_cond_t *finish_cond;     // signalled when a whole worker group has finished
} thread_context_t;

/*
 * Print the command-line synopsis of this test program to stderr.
 *
 * prog_name: program name shown in front of the argument list.
 */
void show_usage(const char *prog_name)
{
    cerr << "Usage: ";
    cerr << prog_name;
    cerr << " <DATA_DIR_PATH> <THREAD_COUNT>" << endl;
}

/*
 * Add exactly one test comment to one of the given threads.
 *
 * A target is picked at random from thread_ids[0 .. thread_count-1]. If
 * add_comment() rejects it by throwing a const char * (presumably because the
 * thread is already full - confirm in BBSThreadManager), another random
 * target is tried; the function only returns once a comment was accepted.
 *
 * pool is unused.
 */
static void run_write_by_id(apr_pool_t *pool,
                            BBSThreadManager *thread_manager,
                            apr_size_t thread_count, apr_size_t *thread_ids)
{
    bool added = false;

    do {
        const apr_size_t target_id = thread_ids[rand() % thread_count];

        try {
            thread_manager->add_comment(target_id,
                                        MESSAGE_TEST_NAME, MESSAGE_TEST_TRIP,
                                        MESSAGE_TEST_EMAIL,
                                        MESSAGE_TEST_MESSAGE, apr_time_now(),
                                        true);
            added = true;
        } catch (const char *) {
            // Rejected: loop around and retry with another random thread.
        }
    } while (!added);
}

/*
 * Body of each write worker thread.
 *
 * Adds TRD_MAX_COMMENT_COUNT - 1 comments, each to a randomly chosen thread
 * from context->write_ids, then reports completion by incrementing
 * context->write_finish_count; the worker that brings the count up to
 * context->write_worker_count signals context->finish_cond.
 *
 * Completion is reported even if the write loop fails, so the main thread
 * never waits forever for this worker (a missing comment is then detected by
 * the post-attack check). Errors are printed to stderr, never propagated.
 */
static void APR_THREAD_FUNC write_attack_worker_by_id(apr_thread_t *thread,
                                                      thread_context_t *context)
{
    try {
        for (apr_size_t i = 1; i < TRD_MAX_COMMENT_COUNT; i++) {
            run_write_by_id(context->pool,
                            context->thread_manager,
                            context->write_thread_count,
                            context->write_ids);
        }
    } catch(const char *message) {
        cerr << "Error: " << message << endl;
    }

    try {
        // Increment and signal while holding finish_mutex. The main thread
        // holds that mutex whenever it is not blocked in the condition wait,
        // so the signal cannot slip in between its check and its wait and
        // be lost. It also makes "increment, then compare" atomic with
        // respect to the other writers, so only the last one signals.
        // If locking fails we still increment and signal (best effort)
        // rather than leave the main thread waiting for this worker.
        bool locked = (apr_thread_mutex_lock(context->finish_mutex)
                       == APR_SUCCESS);
        apr_status_t signal_status = APR_SUCCESS;

        apr_atomic_inc(&(context->write_finish_count));
        if (apr_atomic_read(&(context->write_finish_count))
            == context->write_worker_count) {
            signal_status = apr_thread_cond_signal(context->finish_cond);
        }

        if (locked) {
            apr_thread_mutex_unlock(context->finish_mutex);
        } else {
            THROW(MESSAGE_MUTEX_LOCK_FAILED);
        }
        if (signal_status != APR_SUCCESS) {
            THROW(MESSAGE_CONDITION_SIGNAL_FAILED);
        }
    } catch(const char *message) {
        cerr << "Error: " << message << endl;
    }
}

/*
 * Perform one read pass of thread_count lookups through
 * BBSThreadManager::get_thread_by_index().
 *
 * Lookup number k uses entry (k + rand()) % thread_count of the parallel
 * arrays indexes_sort_by_view / ids_sort_by_view. A single iterator buffer of
 * BBSCommentIterator::MEMORY_SIZE bytes is allocated from pool and reused;
 * after each lookup it is wrapped in a CleanPointer scoped to that lookup
 * (presumably releasing the iterator's resources at scope exit - confirm).
 */
static void run_read_by_index(apr_pool_t *pool,
                              BBSThreadManager *thread_manager,
                              apr_size_t thread_count,
                              apr_size_t *ids_sort_by_view,
                              apr_size_t *indexes_sort_by_view)
{
    BBSCommentIterator *iter;

    APR_PALLOC(iter, BBSCommentIterator *, pool,
               BBSCommentIterator::MEMORY_SIZE);

    apr_size_t done = 0;
    while (done < thread_count) {
        const apr_size_t pick = (done + rand()) % thread_count;

        thread_manager->get_thread_by_index(pool,
                                            indexes_sort_by_view[pick],
                                            ids_sort_by_view[pick],
                                            iter);
        CleanPointer<BBSCommentIterator> guard(iter);

        ++done;
    }
}

/*
 * Body of each read worker thread.
 *
 * Performs RUN_ITER_COUNT read passes (run_read_by_index()) over the
 * view-sorted thread list held in context, then reports completion by
 * incrementing context->read_finish_count; the worker that brings the count
 * up to context->read_worker_count signals context->finish_cond.
 *
 * Completion is reported even when a read pass throws. Previously an error
 * skipped the increment, leaving the main thread waiting for this worker
 * forever. Errors are printed to stderr, never propagated.
 *
 * NOTE(review): context->pool is shared by all read workers and is passed to
 * run_read_by_index(), which allocates from it. APR pools are not
 * thread-safe, so this is a data race. Consider a per-thread pool (e.g.
 * apr_thread_pool_get(thread)) once it is confirmed that BBSThreadManager
 * keeps nothing allocated from the pool it is given.
 */
static void APR_THREAD_FUNC read_attack_worker_by_index(apr_thread_t *thread,
                                                        thread_context_t *context)
{
    try {
        for (apr_size_t i = 0; i < RUN_ITER_COUNT; i++) {
            run_read_by_index(context->pool,
                              context->thread_manager,
                              context->read_thread_count,
                              context->ids_sort_by_view,
                              context->indexes_sort_by_view);
        }
    } catch(const char *message) {
        cerr << "Error: " << message << endl;
    }

    try {
        // Increment and signal while holding finish_mutex. The main thread
        // holds that mutex whenever it is not blocked in the condition wait,
        // so the signal cannot slip in between its check and its wait and
        // be lost. It also makes "increment, then compare" atomic with
        // respect to the other readers, so only the last one signals.
        // If locking fails we still increment and signal (best effort)
        // rather than leave the main thread waiting for this worker.
        bool locked = (apr_thread_mutex_lock(context->finish_mutex)
                       == APR_SUCCESS);
        apr_status_t signal_status = APR_SUCCESS;

        apr_atomic_inc(&(context->read_finish_count));
        if (apr_atomic_read(&(context->read_finish_count))
            == context->read_worker_count) {
            signal_status = apr_thread_cond_signal(context->finish_cond);
        }

        if (locked) {
            apr_thread_mutex_unlock(context->finish_mutex);
        } else {
            THROW(MESSAGE_MUTEX_LOCK_FAILED);
        }
        if (signal_status != APR_SUCCESS) {
            THROW(MESSAGE_CONDITION_SIGNAL_FAILED);
        }
    } catch(const char *message) {
        cerr << "Error: " << message << endl;
    }
}

/*
 * Fill in the shared worker context.
 *
 * Copies the configuration arguments into context and zeroes both completion
 * counters. It then asks the thread manager for the view-sorted thread
 * id/index arrays and creates finish_mutex and finish_cond, both allocated
 * from pool.
 *
 * On return the calling thread holds context->finish_mutex, as a later wait
 * on context->finish_cond requires.
 *
 * Throws (via THROW) MESSAGE_MUTEX_CREATION_FAILED,
 * MESSAGE_CONDITION_CREATION_FAILED or MESSAGE_MUTEX_LOCK_FAILED when the
 * corresponding APR call fails.
 */
static void init_context(apr_pool_t *pool,
                         apr_size_t write_thread_count,
                         apr_size_t read_thread_count,
                         apr_size_t write_worker_count,
                         apr_size_t read_worker_count,
                         apr_size_t *write_ids,
                         BBSThreadManager *thread_manager,
                         BBSThreadList *thread_list,
                         thread_context_t *context)
{
    apr_status_t rv;

    // Shared objects and configuration handed to every worker.
    context->pool = pool;
    context->thread_manager = thread_manager;
    context->thread_list = thread_list;
    context->write_ids = write_ids;
    context->write_thread_count = write_thread_count;
    context->write_worker_count = write_worker_count;
    context->read_worker_count = read_worker_count;

    // No worker has finished yet.
    context->write_finish_count = 0;
    context->read_finish_count = 0;

    // Seed read_thread_count with the requested value. It is passed by
    // address, presumably so the manager can adjust it to the number of
    // entries it actually returns (confirm in BBSThreadManager).
    context->read_thread_count = read_thread_count;
    thread_manager->get_thread_info_sort_by_view(pool,
                                                 &(context->read_thread_count),
                                                 &(context->ids_sort_by_view),
                                                 &(context->indexes_sort_by_view));

    rv = apr_thread_mutex_create(&(context->finish_mutex),
                                 APR_THREAD_MUTEX_DEFAULT, pool);
    if (rv != APR_SUCCESS) {
        THROW(MESSAGE_MUTEX_CREATION_FAILED);
    }

    rv = apr_thread_cond_create(&(context->finish_cond), pool);
    if (rv != APR_SUCCESS) {
        THROW(MESSAGE_CONDITION_CREATION_FAILED);
    }

    rv = apr_thread_mutex_lock(context->finish_mutex);
    if (rv != APR_SUCCESS) {
        THROW(MESSAGE_MUTEX_LOCK_FAILED);
    }
}

/*
 * Create write_thread_count new BBS threads using the fixed test subject,
 * author and message, storing each returned thread id in thread_ids.
 *
 * thread_ids must have room for write_thread_count entries.
 */
static void create_thread(BBSThreadManager *thread_manager,
                          apr_size_t write_thread_count,
                          apr_size_t *thread_ids)
{
    apr_size_t *slot = thread_ids;
    apr_size_t *const slot_end = thread_ids + write_thread_count;

    while (slot != slot_end) {
        *slot = thread_manager->create_thread(MESSAGE_TEST_SUBJECT,
                                              MESSAGE_TEST_NAME,
                                              MESSAGE_TEST_TRIP,
                                              MESSAGE_TEST_EMAIL,
                                              MESSAGE_TEST_MESSAGE,
                                              apr_time_now(),
                                              true);
        ++slot;
    }
}

/*
 * Record the start time of the write attack, then start write_worker_count
 * threads running write_attack_worker_by_id(), all sharing context.
 *
 * The thread handles are stored in workers[0 .. write_worker_count-1].
 * Throws (via THROW) MESSAGE_THREAD_CREATION_FAILED if a thread cannot be
 * created.
 */
static void start_write_attack(apr_pool_t *pool,
                               apr_size_t write_worker_count,
                               thread_context_t *context,
                               apr_thread_t **workers,
                               volatile double *start_time)
{
    const apr_thread_start_t entry =
        AS_THREAD_START(write_attack_worker_by_id);

    *start_time = get_time_sec();

    // Launch the write workers.
    for (apr_size_t n = 0; n != write_worker_count; ++n) {
        const apr_status_t rv = apr_thread_create(&workers[n], NULL, entry,
                                                  context, pool);
        if (rv != APR_SUCCESS) {
            THROW(MESSAGE_THREAD_CREATION_FAILED);
        }
    }
}

/*
 * Record the start time of the read attack, then start read_worker_count
 * threads running read_attack_worker_by_index(), all sharing context.
 *
 * The thread handles are stored in workers[0 .. read_worker_count-1].
 * Throws (via THROW) MESSAGE_THREAD_CREATION_FAILED if a thread cannot be
 * created.
 */
static void start_read_attack(apr_pool_t *pool,
                              apr_size_t read_worker_count,
                              thread_context_t *context,
                              apr_thread_t **workers,
                              volatile double *start_time)
{
    const apr_thread_start_t entry =
        AS_THREAD_START(read_attack_worker_by_index);

    *start_time = get_time_sec();

    // Launch the read workers.
    for (apr_size_t n = 0; n != read_worker_count; ++n) {
        const apr_status_t rv = apr_thread_create(&workers[n], NULL, entry,
                                                  context, pool);
        if (rv != APR_SUCCESS) {
            THROW(MESSAGE_THREAD_CREATION_FAILED);
        }
    }
}

/*
 * Block until every read worker and every write worker has reported
 * completion. Record when each group finished in *read_end_time /
 * *write_end_time, then join all worker threads.
 *
 * Must be called with context->finish_mutex held (init_context() locks it);
 * the mutex is held again on return.
 *
 * Each group's completion is tracked with its own flag and re-checked before
 * every wait, instead of being inferred from the wake-up itself. As a result:
 *   - a group that finished before we started waiting is noticed,
 *   - a spurious wake-up is never mistaken for a finished group,
 *   - each group's end time is recorded once, and each thread joined once.
 * Each wait is bounded by FINISH_POLL_INTERVAL: a worker may signal while this
 * thread is not blocked in the wait, in which case the signal is lost, and
 * polling still guarantees progress. The end time of a group can then be late
 * by at most one interval.
 *
 * Throws (via THROW) MESSAGE_CONDITION_WAIT_FAILED if the wait itself fails.
 * pool is unused.
 */
static void wait_attack_finish(apr_pool_t *pool,
                               apr_size_t write_worker_count,
                               apr_size_t read_worker_count,
                               thread_context_t *context,
                               apr_thread_t **write_workers,
                               apr_thread_t **read_workers,
                               volatile double *write_end_time,
                               volatile double *read_end_time)
{
    // Longest single wait, in microseconds (the apr_interval_time_t unit).
    static const apr_interval_time_t FINISH_POLL_INTERVAL = 100 * 1000;

    apr_status_t status;
    bool read_finished = false;
    bool write_finished = false;

    for (;;) {
        // Only record end times here; joining is deferred so that a slow
        // join cannot delay noticing the other group's completion.
        if (!read_finished &&
            (apr_atomic_read(&(context->read_finish_count))
             == context->read_worker_count)) {
            *read_end_time = get_time_sec();
            read_finished = true;
        }
        if (!write_finished &&
            (apr_atomic_read(&(context->write_finish_count))
             == context->write_worker_count)) {
            *write_end_time = get_time_sec();
            write_finished = true;
        }
        if (read_finished && write_finished) {
            break;
        }

        status = apr_thread_cond_timedwait(context->finish_cond,
                                           context->finish_mutex,
                                           FINISH_POLL_INTERVAL);
        if ((status != APR_SUCCESS) && !APR_STATUS_IS_TIMEUP(status)) {
            THROW(MESSAGE_CONDITION_WAIT_FAILED);
        }
    }

    // Every worker has reported completion; join each thread exactly once.
    for (apr_size_t i = 0; i < read_worker_count; i++) {
        apr_thread_join(&status, read_workers[i]);
    }
    for (apr_size_t i = 0; i < write_worker_count; i++) {
        apr_thread_join(&status, write_workers[i]);
    }
}

/*
 * Verify the write attack: every thread in context->write_ids must now hold
 * exactly TRD_MAX_COMMENT_COUNT comments.
 *
 * A single iterator buffer is allocated from pool and reused for each
 * thread. Throws (via THROW) MESSAGE_BBS_COMMENT_WRITE_FAILED on the first
 * thread whose comment count differs.
 */
static void check_write_attack(apr_pool_t *pool,
                               BBSThreadManager *thread_manager,
                               thread_context_t *context)
{
    BBSCommentIterator *bcomment_iter;

    APR_PALLOC(bcomment_iter, BBSCommentIterator *, pool,
               BBSCommentIterator::MEMORY_SIZE);

    for (apr_size_t i = 0; i < context->write_thread_count; i++) {
        thread_manager->get_thread_by_id(pool,
                                         context->write_ids[i],
                                         bcomment_iter);
        // Wrap the iterator before anything can throw. Previously the guard
        // was created after the check, so a failing check skipped the
        // iterator's cleanup.
        CleanPointer<BBSCommentIterator> clean_ptr(bcomment_iter);

        if (bcomment_iter->get_comment_count() != TRD_MAX_COMMENT_COUNT) {
            THROW(MESSAGE_BBS_COMMENT_WRITE_FAILED);
        }
    }
}

/*
 * Run the concurrent read/write attack and report its timings.
 *
 * Creates WRITE_THREAD_COUNT BBS threads, then starts read_worker_count read
 * workers and WRITE_THREAD_COUNT write workers sharing one context. It waits
 * for all of them, checks that every written thread is full, and prints the
 * average time per written comment and per read thread in milliseconds.
 */
static void run_attack(apr_pool_t *pool,
                       apr_size_t read_worker_count,
                       BBSThreadManager *thread_manager,
                       BBSThreadList *thread_list)
{
    apr_size_t write_ids[WRITE_THREAD_COUNT];
    apr_size_t write_worker_count;
    thread_context_t context;
    volatile double write_start_time;
    volatile double write_end_time;
    volatile double read_start_time;
    volatile double read_end_time;
    apr_thread_t **write_workers;
    apr_thread_t **read_workers;

    show_test_name("read_write_attack");

    srand(static_cast<unsigned int>(time(NULL)));

    write_worker_count = WRITE_THREAD_COUNT;
    create_thread(thread_manager, WRITE_THREAD_COUNT, write_ids);

    init_context(pool,
                 WRITE_THREAD_COUNT, READ_THREAD_COUNT,
                 write_worker_count, read_worker_count,
                 write_ids, thread_manager, thread_list,
                 &context);

    APR_PALLOC(write_workers, apr_thread_t **, pool,
               write_worker_count*sizeof(apr_thread_t *));
    APR_PALLOC(read_workers, apr_thread_t **, pool,
               read_worker_count*sizeof(apr_thread_t *));

    start_read_attack(pool, read_worker_count, &context, read_workers,
                      &read_start_time);
    start_write_attack(pool, write_worker_count, &context, write_workers,
                       &write_start_time);

    wait_attack_finish(pool, write_worker_count, read_worker_count,
                       &context, write_workers, read_workers,
                       &write_end_time, &read_end_time);

    // init_context() locked finish_mutex, and the condition wait inside
    // wait_attack_finish() re-acquires it before returning, so it is still
    // held here. Release it: the mutex was created from pool and is destroyed
    // by the pool's cleanup, and destroying a locked mutex is undefined
    // (POSIX pthread_mutex_destroy).
    apr_thread_mutex_unlock(context.finish_mutex);

    check_write_attack(pool, thread_manager, &context);

    // Each write worker adds TRD_MAX_COMMENT_COUNT - 1 comments.
    show_item("write by id",
              (write_end_time - write_start_time) *
              1000/(TRD_MAX_COMMENT_COUNT-1) /
              context.write_thread_count,
              " msec");
    // Each read worker performs RUN_ITER_COUNT passes of read_thread_count reads.
    show_item("read by index",
              (read_end_time - read_start_time) *
              1000/RUN_ITER_COUNT /
              context.read_thread_count/read_worker_count,
              " msec");
}

/*
 * Entry point of the read/write attack test.
 *
 * argv[1]: path of the BBS data directory.
 * argv[2]: number of concurrent read workers (a positive decimal integer).
 *
 * Creates shared memory for the thread manager and thread list, loads the
 * thread list from the data directory, and prints the test parameters. It
 * then runs the attack, calls sync_all_thread(), and finally checks the
 * manager's final state.
 *
 * Throws (via THROW) MESSAGE_ARGUMENT_INVALID on a wrong argument count or a
 * malformed / non-positive worker count, and MESSAGE_DAT_DIR_NOT_FOUND when
 * the data directory does not exist.
 */
void run_all(apr_pool_t *pool, int argc, const char * const *argv)
{
    const char *data_dir_path;
    apr_size_t read_worker_count;
    apr_shm_t *manager_shm;
    apr_shm_t *list_shm;

    if (argc != 3) {
        THROW(MESSAGE_ARGUMENT_INVALID);
    }

    data_dir_path = argv[1];

    // atoi() cannot report errors: it turns "abc" into 0 and a negative
    // value into an enormous apr_size_t worker count. Accept only a
    // positive decimal integer with no trailing garbage.
    {
        char *count_end;
        long parsed_count = strtol(argv[2], &count_end, 10);

        if ((count_end == argv[2]) || (*count_end != '\0') ||
            (parsed_count <= 0)) {
            THROW(MESSAGE_ARGUMENT_INVALID);
        }
        read_worker_count = static_cast<apr_size_t>(parsed_count);
    }

    if (!is_exist(pool, data_dir_path)) {
        THROW(MESSAGE_DAT_DIR_NOT_FOUND);
    }

    manager_shm = create_shm(pool, BBSThreadManager::get_memory_size());
    list_shm = create_shm(pool, BBSThreadList::get_memory_size());

    BBSThreadManager *thread_manager
        = BBSThreadManager::get_instance(manager_shm, pool, data_dir_path);
    BBSThreadList *thread_list = BBSThreadListReader::read(pool, data_dir_path,
                                                           list_shm);
    thread_manager->init(thread_list);

    show_item("data_dir", data_dir_path);
    show_item("manager memory",
              BBSThreadManager::get_memory_size()/1024/static_cast<double>(1024),
              " MB");
    show_item("on memory thread", thread_manager->get_thread_count(), "");
    show_item("worker", read_worker_count);
    show_line();

    run_attack(pool, read_worker_count, thread_manager, thread_list);

    thread_manager->sync_all_thread();

    // Check the manager's final state.
    BBSThreadManager::check_finalize_state(thread_manager);
}

// Local Variables:
// mode: c++
// coding: utf-8-dos
// End:
