/*
 * UTF-7 to UTF-8 converter
 *
 * SPDX-FileType: SOURCE
 * SPDX-FileCopyrightText: Michael Bäuerle
 * SPDX-License-Identifier: BSD-2-Clause
 */

#include <assert.h>
#include <string.h>

#include <libbasexx-0/base64_decode.h>

#include "check_nul.h"
#include "iconv_utf-7.h"
#include "nonident.h"
#include "utf8.h"


/*
 * Limit for the length of one chunk of a UTF-7 shift sequence (must be a
 * multiple of 8)
 *
 * Longer shift sequences are processed in multiple chunks of this size (see
 * ucic0_i_get_len_seq()). The Base 64 decoded payload of a complete chunk
 * must contain an integral number of UTF-16 code units:
 *
 *     Base 64:  BBBB BBBB  (eight octets)
 *     Decoded:  DDD  DDD   (six octets)
 *     UTF-16 :  UU UU UU   (three UTF-16BE code units)
 *
 * NOTE(review): Chunks are processed independently, therefore a surrogate
 * pair that is split between two chunks cannot be combined.
 */
#define UCIC0_I_LINE_LEN_MAX  1000U

/* Maximum size of the Base 64 decoded (UTF-16BE) data of one chunk */
#define UCIC0_I_BUF_SIZE  BXX0_BASE64_DECODE_LEN_OUT(UCIC0_I_LINE_LEN_MAX)

/* Number of octets in RFC 2152 Set B (the Base 64 alphabet) */
#define UCIC0_I_SET_B_SIZE  64U


/* ========================================================================== */
/*
 * Check whether 'octet' is a member of RFC 2152 Set B
 *
 * Set B is the Base 64 alphabet: "A"-"Z", "a"-"z", "0"-"9", "+" and "/".
 *
 * Returns zero (false / no error) if 'octet' is part of Set B,
 * nonzero otherwise.
 */
static ucic0_i_bool ucic0_i_check_set_b(const unsigned char octet)
{
    const int c = octet;

    if ( (0x41 <= c && 0x5A >= c) ||    /* "A" - "Z" */
         (0x61 <= c && 0x7A >= c) ||    /* "a" - "z" */
         (0x30 <= c && 0x39 >= c) ||    /* "0" - "9" */
         0x2B == c || 0x2F == c )       /* "+" and "/" */
        return 0;

    return 1;
}


/* ========================================================================== */
/*
 * Get length of the Base 64 encoded part of a UTF-7 shift sequence
 *
 * Counts the leading octets of 'seq' that are members of RFC 2152 Set B.
 * At most 'max' octets of 'seq' are examined. Counting stops when
 * UCIC0_I_LINE_LEN_MAX is reached (longer sequences are processed in
 * multiple chunks by the caller).
 *
 * Returns the number of octets (never more than UCIC0_I_LINE_LEN_MAX).
 */
static size_t ucic0_i_get_len_seq(const size_t max, const unsigned char *seq)
{
    size_t n = 0;

    while (max > n && UCIC0_I_LINE_LEN_MAX > n)
    {
        /* First octet outside of Set B terminates the Base 64 data */
        if (ucic0_i_check_set_b(seq[n]))
            break;
        ++n;
    }

    return n;
}


/* ========================================================================== */
/*
 * Store the replacement sequence for a nonidentical conversion into 'buf'
 *
 * The octets of UCIC0_I_NONIDENT (without its terminating NUL) are copied
 * to 'buf' and the nonidentical conversion counter of 'state' is
 * incremented.
 *
 * Attention:
 * The caller must provide space for at least four octets!
 *
 * Returns the number of octets written to 'buf'.
 */
size_t ucic0_i_encode_nonident(ucic0_i_state *state, unsigned char *buf)
{
    size_t n = sizeof(UCIC0_I_NONIDENT) - 1U;

    /* Replacement must fit into the four octets guaranteed by the caller */
    assert(n <= 4U);

    ++state->nonident;
    memcpy(buf, UCIC0_I_NONIDENT, n);

    return n;
}


/* ========================================================================== */
/*
 * Append 'len' octets from 'data' to the output array of 'state'
 *
 * The data is written completely or not at all: if the remaining output
 * space is too small, nothing is written and 'errno' is set to
 * UCIC0_I_E2BIG. A zero length request is a no-op (the output array is not
 * touched).
 *
 * Returns zero (false / no error) on success, 'errno' will be set otherwise.
 */
static ucic0_i_bool ucic0_i_flush_output(ucic0_i_state *state,
                                         const char *data, size_t len)
{
    const size_t index_out = state->outlen_start - *(state->outlen);

    /* Avoid forming a pointer into a potentially absent/full output array */
    if (0U == len)
        return 0;

    if (*(state->outlen) < len)
    {
        /* Not enough space in outarray */
        errno = UCIC0_I_E2BIG;
        return 1;
    }

    memcpy(&state->outarray[index_out], data, len);
    *(state->outlen) -= len;

    return 0;
}


/* ========================================================================== */
/*
 * Append 'len' octets of 'data' to buffer 'buf' of size 'size'
 *
 * 'remain' is the number of unused octets at the end of 'buf' and is
 * decreased by 'len'.
 */
static void ucic0_i_buf_append(unsigned char *buf, const size_t size,
                               size_t *remain, const unsigned char *data,
                               const size_t len)
{
    assert(*remain >= len);
    memcpy(&buf[size - *remain], data, len);
    *remain -= len;
}


/* ========================================================================== */
/*
 * Convert UTF-16BE to UTF-8
 *
 * The UTF-16BE data 'utf16' (length 'len_utf16' octets) is converted to
 * UTF-8 and flushed to the output array of 'state' as a whole. A single
 * trailing octet (incomplete code unit) is discarded according to RFC 2152.
 *
 * Ill-formed UTF-16 is reported as nonidentical conversion (once per
 * offending code unit):
 * - Low surrogate that is not preceded by a high surrogate
 * - High surrogate that is not followed by a low surrogate (this includes a
 *   high surrogate in the last code unit of the data)
 *
 * NOTE(review): The surrogate state is not preserved across calls. A pair
 * split between two chunks (see UCIC0_I_LINE_LEN_MAX) is reported as two
 * nonidentical conversions.
 *
 * Returns zero (false / no error) on success, 'errno' will be set otherwise.
 */
static ucic0_i_bool ucic0_i_process_utf16be(ucic0_i_state *state,
                                            const unsigned char *utf16,
                                            const size_t len_utf16)
{
    /*
     * At most four UTF-8 octets are produced for every two-octet UTF-16BE
     * code unit (a surrogate pair needs four octets in both encodings, a
     * nonidentical conversion is at most four octets)
     */
    unsigned char buf[2U * UCIC0_I_BUF_SIZE] = { 0 };
    size_t        len_buf_remain             = sizeof buf;
    unsigned int  hsurrogate                 = 0;
    size_t        i                          = 0;

    assert((2U * len_utf16) <= len_buf_remain);
    for (; len_utf16 > i + 1U; i += 2U)
    {
        /* Calculate UTF-16 code unit from big-endian octet pair */
        unsigned int  ucp      = (unsigned int)utf16[i] << 8 |
                                 (unsigned int)utf16[i + 1U];
        unsigned char utf8[4]  = { 0 };
        size_t        len_utf8 = 0;

        if (hsurrogate && UCIC0_I_UTF16_SURROGATE_LOW((long int)ucp))
        {
            /* Combine surrogate pair to supplementary plane codepoint */
            const unsigned long int hs   = hsurrogate;
            const unsigned long int ls   = ucp;
            unsigned long int       pair = 0x10000UL;

            pair += (((hs & 0x3FFUL) << 10) | (ls & 0x3FFUL));
            hsurrogate = 0;

            assert(0xFFFFUL < pair);
            len_utf8 = ucic0_i_encode_utf8(utf8, pair);
        }
        else
        {
            if (hsurrogate)
            {
                /*
                 * Pending high surrogate is not followed by a low surrogate:
                 * Report it and clear it before the current code unit is
                 * processed (otherwise it could be combined with a later,
                 * unrelated low surrogate)
                 */
                unsigned char ni[4]  = { 0 };
                const size_t  len_ni = ucic0_i_encode_nonident(state, ni);

                hsurrogate = 0;
                ucic0_i_buf_append(buf, sizeof buf, &len_buf_remain,
                                   ni, len_ni);
            }

            if (UCIC0_I_UTF16_SURROGATE_HIGH((long int)ucp))
            {
                /* Wait for low surrogate in next code unit */
                hsurrogate = ucp;
                continue;
            }

            /* A lone low surrogate leaves 'len_utf8' at zero (nonident) */
            if (!UCIC0_I_UTF16_SURROGATE_LOW((long int)ucp))
                len_utf8 = ucic0_i_encode_utf8(utf8, ucp);
        }

        if (0U == len_utf8)
            len_utf8 = ucic0_i_encode_nonident(state, utf8);

        ucic0_i_buf_append(buf, sizeof buf, &len_buf_remain, utf8, len_utf8);
    }
    /* Discard potential single trailing octet according to RFC 2152 */

    if (hsurrogate)
    {
        /* High surrogate in last code unit is unpaired */
        unsigned char ni[4]  = { 0 };
        const size_t  len_ni = ucic0_i_encode_nonident(state, ni);

        ucic0_i_buf_append(buf, sizeof buf, &len_buf_remain, ni, len_ni);
    }

    if (ucic0_i_flush_output(state, (char*)buf, sizeof buf - len_buf_remain))
        return 1;

    return 0;
}


/* ========================================================================== */
/*
 * Decode the Base 64 part 'seq' (length 'len_seq') of a UTF-7 shift sequence
 * and convert the resulting UTF-16BE data to UTF-8
 *
 * An ill-formed tail (unused bits not zero) is ignored according to
 * RFC 2152. If the data cannot be decoded completely, one nonidentical
 * conversion is executed for every octet of 'seq' instead.
 *
 * Returns zero (false / no error) on success, 'errno' will be set otherwise.
 */
static ucic0_i_bool ucic0_i_process_base64(ucic0_i_state *state,
                                           const unsigned char *seq,
                                           const size_t len_seq)
{
    unsigned char utf16[UCIC0_I_BUF_SIZE] = { 0 };
    size_t        out_free                = sizeof utf16;
    size_t        in_left                 = len_seq;
    size_t        n                       = 0;

    if ( 0 <= bxx0_base64_decode(utf16, &out_free, seq, &in_left,
                                 BXX0_BASE64_DECODE_FLAG_NOPAD |
                                 BXX0_BASE64_DECODE_FLAG_INVTAIL) &&
         0U == in_left )
    {
        /* Whole sequence decoded */
        return ucic0_i_process_utf16be(state, utf16,
                                       sizeof utf16 - out_free) ? 1 : 0;
    }

    /* Base 64 decoder failed: Replace every input octet */
    for (n = len_seq; 0U < n; --n)
    {
        if (ucic0_i_nonident(state))
            return 1;
    }

    return 0;
}


/* ========================================================================== */
/*
 * Process UTF-7 data in shift state (Base 64 encoded UTF-16BE)
 *
 * Decodes the Set B octets (at most UCIC0_I_LINE_LEN_MAX) at the current
 * input position and consumes them. The shift state is cleared if they are
 * followed by an octet outside of Set B. A following shift reset mark '-'
 * is consumed too. The shift state is kept if the input data is exhausted
 * or if more Set B octets follow (chunk limit reached).
 *
 * Returns zero (false / no error) on success, 'errno' will be set otherwise.
 */
static ucic0_i_bool ucic0_i_process_shift(ucic0_i_state *state)
{
    const unsigned char *seq     = (const unsigned char *)
                   &state->inarray[state->inlen_start - *(state->inlen)];
    const size_t         len_seq = ucic0_i_get_len_seq(*(state->inlen), seq);
    unsigned char        next    = 0;

    if (0U != len_seq && ucic0_i_process_base64(state, seq, len_seq))
        return 1;

    /* Consume Base 64 data */
    *(state->inlen) -= len_seq;

    /* No more input data: Keep shift state */
    if (0U == *(state->inlen))
        return 0;

    /* More Set B data (chunk limit reached): Keep shift state */
    next = (unsigned char)state->inarray[state->inlen_start - *(state->inlen)];
    if (!ucic0_i_check_set_b(next))
        return 0;

    /* End of shift sequence, an explicit shift reset mark '-' is absorbed */
    if (0x2D == next)
        --*(state->inlen);
    state->shift = 0;

    return 0;
}


/* ========================================================================== */
/*
 * Check whether octet 'c' may be directly encoded according to RFC 2152
 *
 * Accepted are TAB, LF, CR and SPACE (Rule 3) and the members of Set D and
 * Set O (Rule 1). Together these are TAB, LF, CR and the range from SPACE
 * (0x20) up to "}" (0x7D), except for the shift mark "+" (0x2B) and the
 * backslash (0x5C). The tilde (0x7E) is not part of Set O.
 *
 * Returns zero (false / no error) if 'c' is directly encoded, nonzero
 * otherwise (the shift mark is always rejected).
 */
static ucic0_i_bool ucic0_i_check_direct(const int c)
{
    /* Control characters TAB, LF and CR (Rule 3) */
    if (0x09 == c || 0x0A == c || 0x0D == c)
        return 0;

    /* Everything outside of SPACE ... "}" is never directly encoded */
    if (0x20 > c || 0x7D < c)
        return 1;

    /* Holes in Set D/O: shift mark "+" and backslash */
    if (0x2B == c || 0x5C == c)
        return 1;

    return 0;
}


/* ========================================================================== */
/*
 * Synchronize to next directly encoded octet or start of shift sequence
 *
 * Used when the resynchronization flag of 'state' is set, i.e. after
 * ucic0_i_process_direct() found an octet that is neither directly encoded
 * nor a shift mark. A nonidentical conversion is executed for that octet and
 * for every following octet until one of these is found:
 * - a directly encoded octet or the shift mark '+' (synchronization reached)
 * - an octet reported by ucic0_i_check_nul() (NUL control character)
 * Processing also stops if the input data is exhausted.
 *
 * The resynchronization flag of 'state' is cleared on success.
 *
 * Returns zero (false / no error) on success, 'errno' will be set otherwise.
 */
static ucic0_i_bool ucic0_i_resync(ucic0_i_state *state)
{
    /* The first octet is already known to be invalid: skip the check */
    ucic0_i_bool first = 1;

    while (*(state->inlen))
    {
        const size_t index_in = state->inlen_start - *(state->inlen);

        if (!first)
        {
            int c = (unsigned char)state->inarray[index_in];

            /* Check for directly encoded octet or start of shift sequence */
            if (!ucic0_i_check_direct(c) || (0x2B == c))
                break;
        }

        /* Check for NUL control character */
        if (ucic0_i_check_nul(state, state->inarray[index_in]))
            break;

        first = 0;
        /*
         * NOTE(review): ucic0_i_nonident() presumably consumes one input
         * octet (the loop relies on '*(state->inlen)' decreasing) - confirm
         */
        if (ucic0_i_nonident(state))
            return 1;
    }

    state->resync = 0;
    return 0;
}


/* ========================================================================== */
/*
 * Process UTF-7 data in normal state (US-ASCII)
 *
 * Directly encoded octets are copied to the output array. The special
 * sequence "+-" is converted to a literal '+'. Processing stops when
 * - a shift mark '+' starts a shift sequence (the mark is consumed and the
 *   shift state is set)
 * - an octet is neither directly encoded nor a shift mark (the
 *   resynchronization flag is set, the octet is not consumed)
 * - the input data is exhausted
 *
 * A shift mark as the last input octet is rejected with UCIC0_I_EINVAL,
 * a full output array with UCIC0_I_E2BIG.
 *
 * Returns zero (false / no error) on success, 'errno' will be set otherwise.
 */
static ucic0_i_bool ucic0_i_process_direct(ucic0_i_state *state)
{
    while (0U != *(state->inlen))
    {
        const size_t pos_in = state->inlen_start - *(state->inlen);
        const int    octet  = (unsigned char)state->inarray[pos_in];
        size_t       used   = 1;

        if (0x2B != octet)
        {
            if (ucic0_i_check_direct(octet))
            {
                /* Lost synchronization, octet is not consumed */
                state->resync = 1;
                return 0;
            }
        }
        else if (2U > *(state->inlen))
        {
            /* Invalid shift sequence at end of data */
            errno = UCIC0_I_EINVAL;
            return 1;
        }
        else if ((char)0x2D == state->inarray[pos_in + 1U])
            used = 2;    /* Special sequence "+-" (literal '+') */
        else
        {
            /* Consume shift mark and enter shift state */
            --*(state->inlen);
            state->shift = 1;
            return 0;
        }

        /* Emit octet (literal '+' for "+-") */
        if (0U == *(state->outlen))
        {
            /* Not enough space in outarray */
            errno = UCIC0_I_E2BIG;
            return 1;
        }
        state->outarray[state->outlen_start - *(state->outlen)] = octet;
        --*(state->outlen);

        /* Consume input data */
        assert(*(state->inlen) >= used);
        *(state->inlen) -= used;
    }

    return 0;
}


/* ========================================================================== */
/*
 * Convert the UTF-7 input data of 'state' to UTF-8
 *
 * Each step dispatches to one of these handlers: resynchronization (flag
 * set), shift state (Base 64 data), or normal state (directly encoded
 * octets). Steps repeat until the input data is exhausted, or until the
 * abort flag of 'state' is found set after a step.
 *
 * Returns zero (false / no error) on success, 'errno' will be set otherwise.
 */
ucic0_i_bool ucic0_i_conv_utf7(ucic0_i_state *state)
{
    /* Complete chunks of shift sequences must decode to whole octets */
    assert(0U == UCIC0_I_LINE_LEN_MAX % 8U);

    while (0U != *(state->inlen))
    {
        ucic0_i_bool failed = 0;

        if (state->resync)
        {
            /* Lost synchronization */
            failed = ucic0_i_resync(state);
            if (!failed)
                state->resync = 0;
        }
        else if (state->shift)
            failed = ucic0_i_process_shift(state);
        else
            failed = ucic0_i_process_direct(state);

        if (failed)
            return 1;

        if (state->abort)
            break;
    }

    return 0;
}
