/*
 * @file  packet_editor.cpp
 * @brief Packet Editor implementation
 *
 * Copyright (C) 2009  NTT COMWARE Corporation.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA
 *
 **********************************************************************/

#include <boost/regex.hpp>
#include "packet_editor.h"
#include "sslproxy.h"
#include "sslproxyserver.h"
#include "sslproxysession.h"

/*!
 * Construct a packet editor bound to a session.
 *
 * @param[in]   session     session whose remote endpoint is substituted for
 *                          the %{CLIENT_ADDR} / %{CLIENT_PORT} macros.
 *                          Only the pointer is kept; the destructor does not
 *                          delete it.
 */
packet_editor::packet_editor(const sslproxy_session* session)
{
    /*-------- DEBUG LOG --------*/
    if (logger_get_log_level(LOG_CAT_PACKET_EDIT) == LOG_LV_DEBUG) {
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 1,
            "in_function : Constructor packet_editor::packet_editor(const sslproxysession* session)");
    }
    /*------ DEBUG LOG END ------*/

    // Remember the owning session for later macro expansion.
    this->session = session;

    /*-------- DEBUG LOG --------*/
    if (logger_get_log_level(LOG_CAT_PACKET_EDIT) == LOG_LV_DEBUG) {
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 2,
            "out_function : Constructor packet_editor::packet_editor(const sslproxysession* session)");
    }
    /*------ DEBUG LOG END ------*/
}

/*!
 * Packet editor destructor.
 * Nothing is released here: the session pointer is not owned by the editor.
 */
packet_editor::~packet_editor()
{
    /*-------- DEBUG LOG --------*/
    if (logger_get_log_level(LOG_CAT_PACKET_EDIT) != LOG_LV_DEBUG)
        return;
    LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 3,
        "in/out_function : Destructor packet_editor::~packet_editor(void)");
    /*------ DEBUG LOG END ------*/
}

/*!
 * Edit client message function.
 *
 * Parses client_msg as an HTTP request, applies every rule in the global
 * ::http_request_header list ("set", "unset", "add", "replace") and writes
 * the re-serialized request back into client_msg.
 *
 * @param[in/out]   client_msg      message buffer; treated as length-delimited
 *                                  (client_length bytes), not NUL-terminated
 * @param[in/out]   client_length   message length; updated to the edited length
 *
 * If the edited message would exceed MAX_BUFFER_SIZE, an error is logged and
 * client_msg / client_length are left unchanged.
 */
void packet_editor::edit_client(char* client_msg, size_t& client_length)
{
    /*-------- DEBUG LOG --------*/
    if (LOG_LV_DEBUG == logger_get_log_level(LOG_CAT_PACKET_EDIT)) {
        // Print exactly client_length bytes: the buffer is not guaranteed to
        // be NUL-terminated, and size_t must not be passed for %d.
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 4,
        "in_function : void packet_editor::edit_client(char* client_msg, size_t& client_length) : "
        "client_msg(%.*s), client_length(%lu)",
        static_cast<int>(client_length), client_msg,
        static_cast<unsigned long>(client_length));
    }
    /*------ DEBUG LOG END ------*/

    // Edit HTTP (request header)
    http_request request(std::string(client_msg, client_length));
    std::list<std::pair<std::string, std::string> >::const_iterator it  = ::http_request_header.begin();
    std::list<std::pair<std::string, std::string> >::const_iterator end = ::http_request_header.end();
    for (; it != end; ++it) {
        const std::string& command = it->first;
        // Set request header field
        if (command == "set") {
            // "Header-Field-Name":"Set-Value"
            std::vector<std::string> set_vector = split(it->second, ":", 2);
            // Skip malformed rules that have no value part.
            if (set_vector.size() == 2) {
                expand_macro(set_vector.at(1));
                // Overwrite or insert.
                request.header(set_vector.at(0), set_vector.at(1));
            }
        }
        // Remove request header field
        else if (command == "unset") {
            // "Header-Field-Name"
            request.header(it->second, "");
        }
        // Add request header field
        else if (command == "add") {
            // "Header-Field-Name":"Set-Value"
            std::vector<std::string> add_vector = split(it->second, ":", 2);
            if (add_vector.size() == 2) {
                field_range current_range = request.header(add_vector.at(0));
                expand_macro(add_vector.at(1));
                // If header field already exists, concatenate values.
                if (current_range.first != current_range.second) {
                    std::string new_value = current_range.first->second;
                    new_value += "," + add_vector.at(1);
                    request.header(add_vector.at(0), new_value);
                // otherwise insert new header field.
                } else {
                    request.header(add_vector.at(0), add_vector.at(1));
                }
            }
        }
        // Replace request header field using regular expression
        else if (command == "replace") {
            // "Header-Field-Name":"From-Value(regex)":"To-Value"
            std::vector<std::string> replace_vector = split(it->second, ":", 3);
            if (replace_vector.size() == 3) {
                field_range current_range = request.header(replace_vector.at(0));
                expand_macro(replace_vector.at(1));
                expand_macro(replace_vector.at(2));
                // NOTE(review): boost::regex throws boost::regex_error on an
                // invalid pattern; it is not caught here.
                boost::regex exp(replace_vector.at(1));
                // NOTE(review): header(name, value) is called while iterating a
                // range obtained from the same request; confirm it does not
                // invalidate current_range's iterators.
                for (; current_range.first != current_range.second; ++current_range.first) {
                    std::string new_value = current_range.first->second;
                    // Replace only if exist
                    if (boost::regex_search(new_value, exp)) {
                        new_value = boost::regex_replace(new_value, exp, replace_vector.at(2));
                        request.header(replace_vector.at(0), new_value);
                    }
                }
            }
        }
    }

    /*
     * Insert other protocol editor.
     */

    std::string edited = request.as_string();
    // New client message is too long (over buffer size): keep the original.
    if (edited.size() > MAX_BUFFER_SIZE) {
        LOGGER_PUT_LOG_ERROR(LOG_CAT_PACKET_EDIT, 1, "Edited message is too long. Drop message.");
    }
    else {
        // Set new client message size.
        client_length = edited.size();
        // Set new client message.
        memcpy(client_msg, edited.c_str(), client_length);
    }

    /*-------- DEBUG LOG --------*/
    if (LOG_LV_DEBUG == logger_get_log_level(LOG_CAT_PACKET_EDIT)) {
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 5,
        "out_function : void packet_editor::edit_client(char* client_msg, size_t& client_length)");
    }
    /*------ DEBUG LOG END ------*/
}

/*!
 * Edit server message function.
 *
 * Parses server_msg as an HTTP response, applies every rule in the global
 * ::http_response_header list ("set", "unset", "add", "replace") and writes
 * the re-serialized response back into server_msg.
 *
 * @param[in/out]   server_msg      message buffer; treated as length-delimited
 *                                  (server_length bytes), not NUL-terminated
 * @param[in/out]   server_length   message length; updated to the edited length
 *
 * If the edited message would exceed MAX_BUFFER_SIZE, an error is logged and
 * server_msg / server_length are left unchanged.
 */
void packet_editor::edit_server(char* server_msg, size_t& server_length)
{
    /*-------- DEBUG LOG --------*/
    if (LOG_LV_DEBUG == logger_get_log_level(LOG_CAT_PACKET_EDIT)) {
        // Print exactly server_length bytes: the buffer is not guaranteed to
        // be NUL-terminated, and size_t must not be passed for %d.
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 6,
        "in_function : void packet_editor::edit_server(char* server_msg, size_t& server_length) : "
        "server_msg(%.*s), server_length(%lu)",
        static_cast<int>(server_length), server_msg,
        static_cast<unsigned long>(server_length));
    }
    /*------ DEBUG LOG END ------*/

    // Edit HTTP (response header)
    http_response response(std::string(server_msg, server_length));
    std::list<std::pair<std::string, std::string> >::const_iterator it  = ::http_response_header.begin();
    std::list<std::pair<std::string, std::string> >::const_iterator end = ::http_response_header.end();
    for (; it != end; ++it) {
        const std::string& command = it->first;
        // Set response header field
        if (command == "set") {
            // "Header-Field-Name":"Set-Value"
            std::vector<std::string> set_vector = split(it->second, ":", 2);
            // Skip malformed rules that have no value part.
            if (set_vector.size() == 2) {
                expand_macro(set_vector.at(1));
                // Overwrite or insert.
                response.header(set_vector.at(0), set_vector.at(1));
            }
        }
        // Remove response header field
        else if (command == "unset") {
            // "Header-Field-Name"
            response.header(it->second, "");
        }
        // Add response header field
        else if (command == "add") {
            // "Header-Field-Name":"Set-Value"
            std::vector<std::string> add_vector = split(it->second, ":", 2);
            if (add_vector.size() == 2) {
                field_range current_range = response.header(add_vector.at(0));
                expand_macro(add_vector.at(1));
                // If header field already exists, concatenate values.
                if (current_range.first != current_range.second) {
                    std::string new_value = current_range.first->second;
                    new_value += "," + add_vector.at(1);
                    response.header(add_vector.at(0), new_value);
                // otherwise insert new header field.
                } else {
                    response.header(add_vector.at(0), add_vector.at(1));
                }
            }
        }
        // Replace response header field using regular expression
        else if (command == "replace") {
            // "Header-Field-Name":"From-Value(regex)":"To-Value"
            std::vector<std::string> replace_vector = split(it->second, ":", 3);
            if (replace_vector.size() == 3) {
                field_range current_range = response.header(replace_vector.at(0));
                expand_macro(replace_vector.at(1));
                expand_macro(replace_vector.at(2));
                // NOTE(review): boost::regex throws boost::regex_error on an
                // invalid pattern; it is not caught here.
                boost::regex exp(replace_vector.at(1));
                // NOTE(review): header(name, value) is called while iterating a
                // range obtained from the same response; confirm it does not
                // invalidate current_range's iterators.
                for (; current_range.first != current_range.second; ++current_range.first) {
                    std::string new_value = current_range.first->second;
                    // Replace only if exist
                    if (boost::regex_search(new_value, exp)) {
                        new_value = boost::regex_replace(new_value, exp, replace_vector.at(2));
                        response.header(replace_vector.at(0), new_value);
                    }
                }
            }
        }
    }

    /*
     * Insert other protocol editor.
     */

    std::string edited = response.as_string();
    // New server message is too long (over buffer size): keep the original.
    if (edited.size() > MAX_BUFFER_SIZE) {
        LOGGER_PUT_LOG_ERROR(LOG_CAT_PACKET_EDIT, 2, "Edited message is too long. Drop message.");
    }
    else {
        // Set new server message size.
        server_length = edited.size();
        // Set new server message.
        memcpy(server_msg, edited.c_str(), server_length);
    }

    /*-------- DEBUG LOG --------*/
    if (LOG_LV_DEBUG == logger_get_log_level(LOG_CAT_PACKET_EDIT)) {
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 7,
        "out_function : void packet_editor::edit_server(char* server_msg, size_t& server_length)");
    }
    /*------ DEBUG LOG END ------*/
}

namespace {

/*!
 * Replace every occurrence of macro in source with value.
 * Searching resumes after the inserted text, so a value that itself contains
 * the macro name cannot cause an endless loop.
 */
void replace_macro_all(std::string& source, const std::string& macro, const std::string& value)
{
    if (macro.empty())
        return;
    std::string::size_type pos = source.find(macro);
    while (pos != std::string::npos) {
        source.replace(pos, macro.length(), value);
        pos = source.find(macro, pos + value.length());
    }
}

/*!
 * Address part of an "address:port" string: everything before the last ':'
 * (the whole string when there is no ':').
 */
std::string endpoint_address_part(const std::string& endpoint)
{
    return endpoint.substr(0, endpoint.rfind(':'));
}

/*!
 * Port part of an "address:port" string: everything after the last ':'
 * (empty when there is no ':').
 */
std::string endpoint_port_part(const std::string& endpoint)
{
    const std::string::size_type colon = endpoint.rfind(':');
    if (colon == std::string::npos)
        return std::string();
    return endpoint.substr(colon + 1);
}

} // namespace

/*!
 * Expand macro function.
 *
 * Replaces every occurrence of these macros in source:
 *   %{CLIENT_ADDR} / %{CLIENT_PORT}   session->get_remote_endpoint()
 *   %{SERVER_ADDR} / %{SERVER_PORT}   ::target_endpoint
 *   %{RECV_ADDR}   / %{RECV_PORT}     ::recv_endpoint
 * Endpoints are split at the LAST ':' so that an address which itself
 * contains ':' (e.g. an IPv6 literal) keeps its colons. The client macros are
 * left untouched when there is no session or its endpoint string is empty.
 *
 * @param[in/out]   source  string
 */
void packet_editor::expand_macro(std::string& source) {
    /*-------- DEBUG LOG --------*/
    if (LOG_LV_DEBUG == logger_get_log_level(LOG_CAT_PACKET_EDIT)) {
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 8,
        "in_function : void packet_editor::expand_macro(std::string& source) : "
        "source(%s)", source.c_str());
    }
    /*------ DEBUG LOG END ------*/

    // Fetch the client endpoint only when a client macro is actually used.
    if (this->session &&
        (source.find("%{CLIENT_ADDR}") != std::string::npos ||
         source.find("%{CLIENT_PORT}") != std::string::npos)) {
        std::string endpoint = this->session->get_remote_endpoint();
        if (!endpoint.empty()) {
            replace_macro_all(source, "%{CLIENT_ADDR}", endpoint_address_part(endpoint));
            replace_macro_all(source, "%{CLIENT_PORT}", endpoint_port_part(endpoint));
        }
    }

    replace_macro_all(source, "%{SERVER_ADDR}", endpoint_address_part(::target_endpoint));
    replace_macro_all(source, "%{SERVER_PORT}", endpoint_port_part(::target_endpoint));

    replace_macro_all(source, "%{RECV_ADDR}", endpoint_address_part(::recv_endpoint));
    replace_macro_all(source, "%{RECV_PORT}", endpoint_port_part(::recv_endpoint));

    /*-------- DEBUG LOG --------*/
    if (LOG_LV_DEBUG == logger_get_log_level(LOG_CAT_PACKET_EDIT)) {
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 9,
        "out_function : void packet_editor::expand_macro(std::string& source)");
    }
    /*------ DEBUG LOG END ------*/
}

/*!
 * Split string function.
 * Split string by delimiter and return token vector.
 * If limit is positive, it is the maximum number of fields returned; the last
 * field holds the unsplit remainder of source. A limit <= 0 means no limit.
 * An empty delimiter performs no split (source is returned as one field).
 * The result always contains at least one field.
 *
 * @param[in]   source      string
 * @param[in]   delimiter   delimiter (may be longer than one character)
 * @param[in]   limit       max token
 * @return      token vector
 */
std::vector<std::string> packet_editor::split(const std::string& source, const std::string& delimiter, int limit = 0) {
    /*-------- DEBUG LOG --------*/
    if (LOG_LV_DEBUG == logger_get_log_level(LOG_CAT_PACKET_EDIT)) {
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 10,
        "in_function : std::vector<std::string> packet_editor::split"
        "(const std::string& source, const std::string& delimiter, int limit) : "
        "source(%s), delimiter(%s), limit(%d)", source.c_str(), delimiter.c_str(), limit);
    }
    /*------ DEBUG LOG END ------*/
    std::vector<std::string> words;
    std::string::size_type begin = 0;

    // An empty delimiter would never advance 'begin', so do not split at all.
    if (!delimiter.empty()) {
        const bool unlimited = (limit <= 0);
        // 'limit' fields require at most 'limit - 1' splits.
        int splits_left = unlimited ? 0 : limit - 1;
        while (unlimited || splits_left > 0) {
            const std::string::size_type end = source.find(delimiter, begin);
            if (end == std::string::npos)
                break;
            words.push_back(source.substr(begin, end - begin));
            // Skip the whole delimiter, not just its first character.
            begin = end + delimiter.length();
            if (!unlimited)
                --splits_left;
        }
    }
    // Remainder of the string (the whole string if nothing was split).
    words.push_back(source.substr(begin));
    /*-------- DEBUG LOG --------*/
    if (LOG_LV_DEBUG == logger_get_log_level(LOG_CAT_PACKET_EDIT)) {
        LOGGER_PUT_LOG_DEBUG(LOG_CAT_PACKET_EDIT, 11,
        "out_function : std::vector<std::string> packet_editor::split"
        "(const std::string& source, const std::string& delimiter, int limit)");
    }
    /*------ DEBUG LOG END ------*/
    return words;
}
