#include <stdarg.h>

#include "tc_private/tc_handle.h"
#include "tc_private/tc_private.h"

#include "log/log.h"

#include "tc_tpm2.h"
#include "tpm2_common.h"

#include "tc_type.h"
#include "tc_errcode.h"

/*
 * Per-call argument bundle for the decrypt command.
 *
 * tpm2_decrypt_init() fills one of these from its variadic arguments and
 * stores it in api_ctx->data; tpm2_decrypt() consumes it and
 * tpm2_decrypt_free() releases it. The TC_BUFFER pointers are stored as
 * passed in (no deep copy is made here).
 */
struct tpm2_decrypt_ctx
{
    /* Opaque library handle; tpm2_decrypt() casts it to TC_HANDLE_CTX*. */
    TC_HANDLE        handle;
    /*
     * Key selector. Values above MAX_OBJECT_NODE_COUNT are passed to the TPM
     * command unchanged as the key handle; smaller values are used as an
     * index into the handle's tc_object->node_info table.
     */
    uint32_t         key_index;
    /*
     * Optional key authorization value (may be NULL). When present it is
     * copied into the hmac field of the TPM2_RS_PW session.
     */
    TC_BUFFER       *key_auth_msg;
    /* TC_RSA selects the RSA decrypt path; any other value selects the
     * EncryptDecrypt2/EncryptDecrypt path. */
    TC_ALG           alg_decrypt;
    /* Input ciphertext (field name spelling is historical). */
    TC_BUFFER       *ciphter_text;
    /* Output: tpm2_decrypt() stores a malloc'ed copy of the result in
     * ->buffer and its length in ->size. */
    TC_BUFFER       *plain_text;
};

TC_RC tpm2_decrypt_init(struct api_ctx_st *api_ctx, int num, ...)
{
    TC_RC rc = TC_SUCCESS;
    struct tpm2_decrypt_ctx* dctx = (struct tpm2_decrypt_ctx*)malloc(sizeof(struct tpm2_decrypt_ctx));

    va_list ap;
    va_start(ap, num);
    dctx->handle = va_arg(ap, TC_HANDLE);
    dctx->key_index = va_arg(ap, uint32_t);
    dctx->key_auth_msg = va_arg(ap, TC_BUFFER*);
    dctx->alg_decrypt = va_arg(ap, TC_ALG);
    dctx->ciphter_text = va_arg(ap, TC_BUFFER*);
    dctx->plain_text = va_arg(ap, TC_BUFFER*);
    va_end(ap);

    api_ctx->data = (HANDLE_DATA*)dctx;
    return rc;
}

/*
 * Release the decrypt context attached to api_ctx and reset the API context
 * to its idle state.
 *
 * Only the context structure itself is freed; the buffers it points to are
 * not touched here.
 *
 * @param api_ctx  API context whose data pointer is freed and cleared.
 * @return Always TC_SUCCESS.
 */
TC_RC tpm2_decrypt_free(struct api_ctx_st *api_ctx)
{
    api_ctx->cmd_code = API_NULL;

    /* Clear the pointer after freeing so a repeated call is harmless. */
    free(api_ctx->data);
    api_ctx->data = NULL;

    return TC_SUCCESS;
}

TC_RC tpm2_decrypt(API_CTX *ctx)
{
    TC_RC rc = TC_SUCCESS;

    struct tpm2_decrypt_ctx* dctx = (struct tpm2_decrypt_ctx*)ctx->data;
    TC_HANDLE_CTX* tc_handle_ctx = (TC_HANDLE_CTX*)(dctx->handle);


    if (dctx->alg_decrypt == TC_RSA) {
        TPMT_RSA_DECRYPT inScheme;
        TPM2B_DATA label;
        TPM2B_PUBLIC_KEY_RSA decrypt_messsage = TPM2B_TYPE_INIT(TPM2B_PUBLIC_KEY_RSA, buffer);
        TPM2B_PUBLIC_KEY_RSA cipher_messsage = TPM2B_TYPE_INIT(TPM2B_PUBLIC_KEY_RSA, buffer);
        TSS2L_SYS_AUTH_RESPONSE sessionsDataout;
        TSS2L_SYS_AUTH_COMMAND sessionsData = {
            .auths    = {{.sessionHandle = TPM2_RS_PW}},
            .count    = 1
        };
        if (dctx->ciphter_text->size > TPM2_MAX_RSA_KEY_BYTES) {
            log_error("The length of the data to be decrypted exceeds the limit\n");
            return TC_DECRYPT_BUFFER_OVERSIZE;
        }
        cipher_messsage.size = dctx->ciphter_text->size;
        memcpy (cipher_messsage.buffer,
                dctx->ciphter_text->buffer,
                dctx->ciphter_text->size);

        if (dctx->key_auth_msg != NULL) {
            if (dctx->key_auth_msg->size > sizeof(TPMU_HA)) {
                log_error("Key authorization authentication password exceeds limit\n");
                return TC_AUTH_HMAC_OVERSIZE;
            }
            sessionsData.auths[0].hmac.size = dctx->key_auth_msg->size;
            memcpy(sessionsData.auths[0].hmac.buffer,
                   dctx->key_auth_msg->buffer,
                   dctx->key_auth_msg->size);
        }

        inScheme.scheme = TPM2_ALG_RSAES;
        label.size = 0;

        if (dctx->key_index > MAX_OBJECT_NODE_COUNT) {
            rc = Tss2_Sys_RSA_Decrypt((TSS2_SYS_CONTEXT*)tc_handle_ctx->handle.tc_context,
                                       dctx->key_index,
                                       &sessionsData,
                                       &cipher_messsage,
                                       &inScheme,
                                       &label,
                                       &decrypt_messsage,
                                       &sessionsDataout);
        }else{
            if (dctx->key_index > tc_handle_ctx->handle.tc_object->count) {
                log_error("Invalid object index\n");
                return TC_OBJECT_INDEX;
            }
            rc = Tss2_Sys_RSA_Decrypt((TSS2_SYS_CONTEXT*)tc_handle_ctx->handle.tc_context,
                                       tc_handle_ctx->handle.tc_object->node_info[dctx->key_index]->obj_handle,
                                       &sessionsData,
                                       &cipher_messsage,
                                       &inScheme,
                                       &label,
                                       &decrypt_messsage,
                                       &sessionsDataout);
        }
        dctx->plain_text->buffer = (uint8_t*)malloc(decrypt_messsage.size);
        memcpy(dctx->plain_text->buffer, decrypt_messsage.buffer, decrypt_messsage.size);
        dctx->plain_text->size = decrypt_messsage.size;
    }else{
        TPM2B_MAX_BUFFER out_data = TPM2B_TYPE_INIT(TPM2B_MAX_BUFFER, buffer);
        TPM2B_MAX_BUFFER in_data = TPM2B_TYPE_INIT(TPM2B_MAX_BUFFER, buffer);
        TPM2B_IV iv_out = TPM2B_TYPE_INIT(TPM2B_IV, buffer);
        TSS2L_SYS_AUTH_RESPONSE sessionsDataout;
        TSS2L_SYS_AUTH_COMMAND sessionsData = {
            .auths    = {{.sessionHandle = TPM2_RS_PW}},
            .count    = 1
        };
        TPM2B_IV iv_in = {
            .size = TPM2_MAX_SYM_BLOCK_SIZE,
            .buffer = { 0 }
        };
        if (dctx->ciphter_text->size > TPM2_MAX_DIGEST_BUFFER) {
            log_error("The length of the data to be decrypted exceeds the limit\n");
            return TC_DECRYPT_BUFFER_OVERSIZE;
        }
        in_data.size = dctx->ciphter_text->size;
        memcpy (in_data.buffer,
                dctx->ciphter_text->buffer,
                dctx->ciphter_text->size);

        if (dctx->key_auth_msg != NULL) {
            if (dctx->key_auth_msg->size > sizeof(TPMU_HA)) {
                log_error("Key authorization authentication password exceeds limit\n");
                return TC_AUTH_HMAC_OVERSIZE;
            }
            sessionsData.auths[0].hmac.size = dctx->key_auth_msg->size;
            memcpy(sessionsData.auths[0].hmac.buffer,
                   dctx->key_auth_msg->buffer,
                   dctx->key_auth_msg->size);
        }

        if (dctx->key_index > MAX_OBJECT_NODE_COUNT) {
            rc = Tss2_Sys_EncryptDecrypt2((TSS2_SYS_CONTEXT*)tc_handle_ctx->handle.tc_context,
                                        dctx->key_index,
                                        &sessionsData,
                                        &in_data,
                                        TPM2_YES,
                                        TPM2_ALG_NULL,
                                        &iv_in,
                                        &out_data,
                                        &iv_out,
                                        &sessionsDataout);
            if (rc != TSS2_RC_SUCCESS) {
                rc = Tss2_Sys_EncryptDecrypt((TSS2_SYS_CONTEXT*)tc_handle_ctx->handle.tc_context,
                                          dctx->key_index,
                                          &sessionsData,
                                          TPM2_YES,
                                          TPM2_ALG_NULL,
                                          &iv_in,
                                          &in_data,
                                          &out_data,
                                          &iv_out,
                                          &sessionsDataout);
            }
        }else{
            if (dctx->key_index > tc_handle_ctx->handle.tc_object->count) {
                log_error("Invalid object index\n");
                return TC_OBJECT_INDEX;
            }
            rc = Tss2_Sys_EncryptDecrypt2((TSS2_SYS_CONTEXT*)tc_handle_ctx->handle.tc_context,
                                        tc_handle_ctx->handle.tc_object->node_info[dctx->key_index]->obj_handle,
                                        &sessionsData,
                                        &in_data,
                                        TPM2_YES,
                                        TPM2_ALG_NULL,
                                        &iv_in,
                                        &out_data,
                                        &iv_out,
                                        &sessionsDataout);
            if (rc != TSS2_RC_SUCCESS) {
                rc = Tss2_Sys_EncryptDecrypt((TSS2_SYS_CONTEXT*)tc_handle_ctx->handle.tc_context,
                                          tc_handle_ctx->handle.tc_object->node_info[dctx->key_index]->obj_handle,
                                          &sessionsData,
                                          TPM2_YES,
                                          TPM2_ALG_NULL,
                                          &iv_in,
                                          &in_data,
                                          &out_data,
                                          &iv_out,
                                          &sessionsDataout);
            }
        }
        dctx->plain_text->buffer = (uint8_t*)malloc(out_data.size);
        memcpy(dctx->plain_text->buffer, out_data.buffer, out_data.size);
        dctx->plain_text->size = out_data.size;
    }

    if (rc != TSS2_RC_SUCCESS) {
        log_error("Failed to run api_decrypt:0x%0x\n", rc);
        rc = TC_COMMAND_DECRYPT;
    }
    return rc;
}