/*
 *  Copyright (C) 2004 Morten Fjord-Larsen
 *  Copyright (C) 2005 Kouji TAKAO <kouji@netlab.jp>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Library General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
 */

#ifdef HAVE_CONFIG_H
#  include <config.h>
#endif

#include <errno.h>
#include <string.h>
#include <glib/gi18n.h>
#include <mhash.h>

#include "gpass/crypt-stream.h"

/***********************************************************
 *
 * GPassCryptStream
 *
 ***********************************************************/
/*
 * Instance initializer for the abstract base: a fresh stream has neither
 * a FILE nor an mcrypt handle attached until gpass_crypt_stream_open().
 */
static void
gpass_crypt_stream_instance_init(GTypeInstance *instance, gpointer g_class)
{
    GPassCryptStream *stream = GPASS_CRYPT_STREAM(instance);

    stream->mcrypt = 0;
    stream->file = NULL;
}

/*
 * Return the GType of the abstract GPassCryptStream base class, registering
 * it (as a G_TYPE_OBJECT subclass) the first time this is called.
 */
GType
gpass_crypt_stream_get_type(void)
{
    static GType crypt_stream_type = 0;
    static const GTypeInfo crypt_stream_info = {
        sizeof(GPassCryptStreamClass),      /* class_size */
        NULL,                               /* base_init */
        NULL,                               /* base_finalize */
        NULL,                               /* class_init */
        NULL,                               /* class_finalize */
        NULL,                               /* class_data */
        sizeof(GPassCryptStream),           /* instance_size */
        0,                                  /* n_preallocs */
        gpass_crypt_stream_instance_init    /* instance_init */
    };

    if (crypt_stream_type != 0) {
        return crypt_stream_type;
    }
    crypt_stream_type = g_type_register_static(G_TYPE_OBJECT,
                                               "GPassCryptStreamType",
                                               &crypt_stream_info,
                                               G_TYPE_FLAG_ABSTRACT);
    return crypt_stream_type;
}

/*
 * Derive the cipher key from PASSWORD as the SHA-1 digest of the password
 * bytes (the terminating NUL is not hashed).
 *
 * *key_bytes is always set to the SHA-1 digest size.  On success *key
 * receives a newly allocated buffer of *key_bytes bytes, which the caller
 * must g_free(), and NULL is returned.  If mhash cannot be initialised, a
 * new GError is returned and *key is left untouched.
 */
static GError *
gpass_crypt_stream_compute_key(const gchar *password,
                               guchar **key, gsize *key_bytes)
{
    GError *error = NULL;
    MHASH hash;

    *key_bytes = mhash_get_block_size(MHASH_SHA1);
    hash = mhash_init(MHASH_SHA1);
    if (hash != MHASH_FAILED) {
        mhash(hash, password, strlen(password));
        *key = g_new(guchar, *key_bytes);
        /* mhash_deinit() writes the final digest into *key. */
        mhash_deinit(hash, *key);
        return NULL;
    }
    g_set_error(&error, 0, 0, _("failed mhash_init()"));
    return error;
}

/*
 * Open a Blowfish/CBC mcrypt handle and initialise it with KEY.
 *
 * On success the initialised handle is stored in *mcrypt and NULL is
 * returned.  On any failure the partially opened module is closed again,
 * *mcrypt is left untouched and a new GError is returned.
 *
 * NOTE(review): the IV is a hard-coded constant, presumably kept for
 * compatibility with existing files; with a fixed IV, identical plaintext
 * prefixes encrypt identically under the same password.
 */
static GError *
setup_blowfish_cipher(guchar *key, gsize key_bytes, MCRYPT *mcrypt)
{
    guchar iv[] = { 5, 23, 1, 123, 12, 3, 54, 94 };
    GError *error = NULL;
    MCRYPT td;
    int iv_size;
    int rc;

    td = mcrypt_module_open(MCRYPT_BLOWFISH, NULL, MCRYPT_CBC, NULL);
    if (td == MCRYPT_FAILED) {
        g_set_error(&error, 0, 0, _("failed mcrypt_module_open()"));
        return error;
    }

    /* The constant IV above is exactly 8 bytes; refuse anything else. */
    iv_size = mcrypt_enc_get_iv_size(td);
    if (iv_size != 8) {
        g_set_error(&error, 0, 0, _("IV size != 8: %d"), iv_size);
        goto fail;
    }

    rc = mcrypt_generic_init(td, key, key_bytes, iv);
    if (rc < 0) {
        g_set_error(&error, 0, 0, _("failed mcrypt_generic_init(): %s"),
                    mcrypt_strerror(rc));
        goto fail;
    }

    *mcrypt = td;
    return NULL;

fail:
    mcrypt_module_close(td);
    return error;
}

/*
 * Attach FILE to SELF and set up a Blowfish/CBC cipher keyed from PASSWORD.
 *
 * On success the stream records FILE (gpass_crypt_stream_close() later
 * fclose()s it), the mcrypt handle and the cipher block size, and NULL is
 * returned.  On failure a new GError is returned and SELF is not modified;
 * FILE remains owned by the caller.
 */
static GError *
gpass_crypt_stream_open(GPassCryptStream *self, FILE *file,
                        const char *password)
{
    MCRYPT mcrypt;
    guchar *key;
    gsize key_bytes;
    GError *error;
    
    error = gpass_crypt_stream_compute_key(password, &key, &key_bytes);
    if (error != NULL) {
        return error;
    }
    error = setup_blowfish_cipher(key, key_bytes, &mcrypt);
    /* The derived key is secret material and is no longer needed here;
     * scrub it so it does not linger in freed heap memory. */
    memset(key, 0, key_bytes);
    g_free(key);
    if (error != NULL) {
        return error;
    }
    self->file = file;
    self->mcrypt = mcrypt;
    self->block_size = mcrypt_enc_get_block_size(mcrypt);
    return NULL;
}

/*
 * Shut the stream down: deinitialise and close the mcrypt handle, fclose()
 * the underlying FILE, and drop the caller's reference to SELF.
 *
 * The fclose() result is not checked, so late write errors on an encrypting
 * stream are not reported from here.
 */
static void
gpass_crypt_stream_close(GPassCryptStream *self)
{
    MCRYPT td = self->mcrypt;
    FILE *fp = self->file;

    mcrypt_generic_deinit(td);
    mcrypt_module_close(td);
    fclose(fp);
    g_object_unref(self);
}

/***********************************************************
 *
 * GPassDecryptStream
 *
 ***********************************************************/
/* Parent class of GPassDecryptStream, captured in
 * gpass_decrypt_stream_class_init() so finalize can chain up to it. */
static GPassCryptStreamClass *parent_decrypt_class = NULL;

/*
 * Instance initializer: no ciphertext buffers are allocated yet and the
 * decrypted read window [ptr, ptr_end) is empty.
 */
static void
gpass_decrypt_stream_instance_init(GTypeInstance *instance, gpointer g_class)
{
    GPassDecryptStream *stream = GPASS_DECRYPT_STREAM(instance);

    stream->ptr = NULL;
    stream->ptr_end = NULL;
    stream->buffer[1] = NULL;
    stream->buffer[0] = NULL;
}

/* GObject finalize: release both block buffers, then chain up. */
static void
gpass_decrypt_stream_instance_finalize(GObject *object)
{
    GPassDecryptStream *stream = GPASS_DECRYPT_STREAM(object);
    GObjectClass *parent = G_OBJECT_CLASS(parent_decrypt_class);

    g_free(stream->buffer[0]);
    g_free(stream->buffer[1]);
    parent->finalize(object);
}

/* Class initializer: install finalize and remember the parent class. */
static void
gpass_decrypt_stream_class_init(gpointer g_class, gpointer g_class_data)
{
    G_OBJECT_CLASS(g_class)->finalize = gpass_decrypt_stream_instance_finalize;
    parent_decrypt_class = g_type_class_peek_parent(g_class);
}

/*
 * Return the GType of GPassDecryptStream (a concrete subclass of
 * GPassCryptStream), registering it on first use.
 */
GType
gpass_decrypt_stream_get_type(void)
{
    static GType decrypt_stream_type = 0;
    static const GTypeInfo decrypt_stream_info = {
        sizeof(GPassDecryptStreamClass),        /* class_size */
        NULL,                                   /* base_init */
        NULL,                                   /* base_finalize */
        gpass_decrypt_stream_class_init,        /* class_init */
        NULL,                                   /* class_finalize */
        NULL,                                   /* class_data */
        sizeof(GPassDecryptStream),             /* instance_size */
        0,                                      /* n_preallocs */
        gpass_decrypt_stream_instance_init      /* instance_init */
    };

    if (decrypt_stream_type != 0) {
        return decrypt_stream_type;
    }
    decrypt_stream_type = g_type_register_static(GPASS_TYPE_CRYPT_STREAM,
                                                 "GPassDecryptStreamType",
                                                 &decrypt_stream_info, 0);
    return decrypt_stream_type;
}

/*
 * Create a decrypting stream that reads ciphertext from FILE using a key
 * derived from PASSWORD.
 *
 * The first ciphertext block is read ahead immediately.  If the file holds
 * less than one full block, the stream starts in the "final" state.
 * On success *self receives the new stream and NULL is returned; release
 * the stream with gpass_decrypt_stream_close(), which also fclose()s FILE.
 * On failure a new GError is returned, *self is untouched and FILE is left
 * open for the caller.
 */
GError *
gpass_decrypt_stream_open(GPassDecryptStream **self, FILE *file,
                          const char *password)
{
    GPassDecryptStream *stream;
    GPassCryptStream *crypt;
    GError *error;
    gint nread;

    stream = g_object_new(GPASS_TYPE_DECRYPT_STREAM, NULL);
    crypt = GPASS_CRYPT_STREAM(stream);
    error = gpass_crypt_stream_open(crypt, file, password);
    if (error != NULL) {
        g_object_unref(stream);
        return error;
    }

    /* Two block-sized buffers: one is decrypted and consumed while the
     * other receives the next ciphertext block read ahead. */
    stream->buffer[0] = g_malloc(crypt->block_size);
    stream->buffer[1] = g_malloc(crypt->block_size);
    stream->buffer_index = 0;

    nread = fread(stream->buffer[0], 1, crypt->block_size, crypt->file);
    stream->final = (nread == crypt->block_size) ? FALSE : TRUE;

    *self = stream;
    return NULL;
}

/*
 * Trim the padding written by append_padding() from the current (last)
 * decrypted block.  The padding is a byte value V in 1..block_size-1
 * repeated V times at the end of the block.  If the trailing bytes do not
 * match that pattern, the block is left as is.
 *
 * NOTE(review): the encrypt side only pads an incomplete final block, so a
 * plaintext whose length is a multiple of the block size and which happens
 * to end in such a pattern (e.g. a lone 0x01) would be truncated here.  This
 * is a limitation of the file format, not something fixable on this side
 * alone.
 */
static void
remove_padding(GPassDecryptStream *self)
{
    GPassCryptStream *crypt = GPASS_CRYPT_STREAM(self);
    guchar *block_end = self->ptr + crypt->block_size;
    guchar pad = block_end[-1];
    guchar *q;

    if (pad >= crypt->block_size) {
        return;
    }
    for (q = block_end - pad; q < block_end; q++) {
        if (*q != pad) {
            return;
        }
    }
    self->ptr_end -= pad;
}

/*
 * Decrypt the block that was read ahead, make it the current read window
 * [ptr, ptr_end), and read the following ciphertext block into the other
 * buffer.  A short read on that look-ahead means the block just decrypted
 * is the last one: its padding is stripped and the stream becomes "final".
 *
 * Returns NULL on success or a new GError if decryption fails.
 */
static GError *
read_block(GPassDecryptStream *self)
{
    GPassCryptStream *crypt = GPASS_CRYPT_STREAM(self);
    GError *error = NULL;
    gint nread;

    /* Decrypt the pending ciphertext block in place. */
    if (mdecrypt_generic(crypt->mcrypt, self->buffer[self->buffer_index],
                         crypt->block_size) != 0) {
        g_set_error(&error, 0, 0, _("failed mdecrypt_generic()"));
        return error;
    }

    self->ptr = self->buffer[self->buffer_index];
    self->ptr_end = self->ptr + crypt->block_size;

    /* Flip to the other buffer and fill it with the next ciphertext block. */
    self->buffer_index = 1 - self->buffer_index;
    nread = fread(self->buffer[self->buffer_index], 1, crypt->block_size,
                  crypt->file);
    if (nread == crypt->block_size) {
        return NULL;
    }
    remove_padding(self);
    self->final = TRUE;
    return NULL;
}

/*
 * Read up to COUNT decrypted bytes into BUF, storing the number actually
 * copied in *result.  Fewer than COUNT bytes are returned only when the
 * last block of the stream has been consumed.
 *
 * Returns NULL on success, or a new GError if the stream is already at its
 * end ("Premature end of file") or a block fails to decrypt; *result is
 * not set in the error cases.
 */
GError *
gpass_decrypt_stream_read(GPassDecryptStream *self, gchar *buf, gsize count,
                          gsize *result)
{
    GError *error = NULL;
    gchar *dest = buf;
    gsize total = 0;

    if (gpass_decrypt_stream_eof(self)) {
        g_set_error(&error, 0, 0, _("Premature end of file"));
        return error;
    }
    for (;;) {
        gsize avail;

        if (self->ptr == self->ptr_end) {
            error = read_block(self);
            if (error != NULL) {
                return error;
            }
        }
        avail = self->ptr_end - self->ptr;
        if (avail >= count) {
            /* The current window satisfies the rest of the request. */
            memcpy(dest, self->ptr, count);
            self->ptr += count;
            total += count;
            break;
        }
        /* Drain the window; continue with the next block unless this was
         * the last one. */
        memcpy(dest, self->ptr, avail);
        dest += avail;
        total += avail;
        count -= avail;
        self->ptr = self->ptr_end;
        if (self->final) {
            break;
        }
    }
    *result = total;
    return NULL;
}

/*
 * Read one line from the decrypted stream into *line, replacing its
 * previous contents.  The terminating '\n' (LF) is kept in the result.
 * If the stream ends before a '\n' is seen, *line receives the remaining
 * bytes without a terminator.
 *
 * *line must point to an existing GString.  Returns NULL on success, or a
 * new GError if the stream is already at its end ("Premature end of file")
 * or a block fails to decrypt.
 */
GError *
gpass_decrypt_stream_read_line(GPassDecryptStream *self, GString **line)
{
    GError *error = NULL;
    
    if (gpass_decrypt_stream_eof(self)) {
        g_set_error(&error, 0, 0, _("Premature end of file"));
        return error;
    }
    *line = g_string_assign(*line, "");
    while (1) {
        gsize buffer_len;
        guchar *p;

        if (self->ptr == self->ptr_end) {
            error = read_block(self);
            if (error != NULL) {
                return error;
            }
        }
        buffer_len = self->ptr_end - self->ptr;
        p = memchr(self->ptr, '\n', buffer_len);
        if (p == NULL) {
            /* No newline in this window: take all of it and keep going,
             * unless this was the last block. */
            *line = g_string_append_len(*line, (const gchar *) self->ptr,
                                        buffer_len);
            self->ptr = self->ptr_end;
            if (self->final) {
                break;
            }
        }
        else {
            /* Include the '\n' (LF) itself in the returned line. */
            gsize len = (gsize) (p - self->ptr) + 1;

            *line = g_string_append_len(*line, (const gchar *) self->ptr, len);
            self->ptr += len;
            break;
        }
    }
    return NULL;
}

/*
 * TRUE once every decrypted byte has been consumed: the last block has been
 * reached, the read window is empty, and the underlying FILE reports EOF.
 */
gboolean
gpass_decrypt_stream_eof(GPassDecryptStream *self)
{
    FILE *fp = GPASS_CRYPT_STREAM(self)->file;

    return self->final && self->ptr == self->ptr_end && feof(fp);
}

/*
 * Close a decrypting stream: releases the cipher, fclose()s the file and
 * drops the caller's reference to SELF.
 */
void
gpass_decrypt_stream_close(GPassDecryptStream *self)
{
    GPassCryptStream *crypt = GPASS_CRYPT_STREAM(self);

    gpass_crypt_stream_close(crypt);
}

/***********************************************************
 *
 * GPassEncryptStream
 *
 ***********************************************************/
/* Parent class of GPassEncryptStream, captured in
 * gpass_encrypt_stream_class_init() so finalize can chain up to it. */
static GPassCryptStreamClass *parent_encrypt_class = NULL;

/*
 * Instance initializer: no staging buffer yet and an empty
 * [ptr, ptr_end) window.
 */
static void
gpass_encrypt_stream_instance_init(GTypeInstance *instance, gpointer g_class)
{
    GPassEncryptStream *stream = GPASS_ENCRYPT_STREAM(instance);

    stream->ptr_end = NULL;
    stream->ptr = NULL;
    stream->buffer = NULL;
}

/* GObject finalize: release the staging buffer, then chain up. */
static void
gpass_encrypt_stream_instance_finalize(GObject *object)
{
    GPassEncryptStream *stream = GPASS_ENCRYPT_STREAM(object);
    GObjectClass *parent = G_OBJECT_CLASS(parent_encrypt_class);

    g_free(stream->buffer);
    parent->finalize(object);
}

/* Class initializer: install finalize and remember the parent class. */
static void
gpass_encrypt_stream_class_init(gpointer g_class, gpointer g_class_data)
{
    G_OBJECT_CLASS(g_class)->finalize = gpass_encrypt_stream_instance_finalize;
    parent_encrypt_class = g_type_class_peek_parent(g_class);
}

/*
 * Return the GType of GPassEncryptStream (a concrete subclass of
 * GPassCryptStream), registering it on first use.
 */
GType
gpass_encrypt_stream_get_type(void)
{
    static GType encrypt_stream_type = 0;
    static const GTypeInfo encrypt_stream_info = {
        sizeof(GPassEncryptStreamClass),        /* class_size */
        NULL,                                   /* base_init */
        NULL,                                   /* base_finalize */
        gpass_encrypt_stream_class_init,        /* class_init */
        NULL,                                   /* class_finalize */
        NULL,                                   /* class_data */
        sizeof(GPassEncryptStream),             /* instance_size */
        0,                                      /* n_preallocs */
        gpass_encrypt_stream_instance_init      /* instance_init */
    };

    if (encrypt_stream_type != 0) {
        return encrypt_stream_type;
    }
    encrypt_stream_type = g_type_register_static(GPASS_TYPE_CRYPT_STREAM,
                                                 "GPassEncryptStreamType",
                                                 &encrypt_stream_info, 0);
    return encrypt_stream_type;
}

/*
 * Create an encrypting stream that writes ciphertext to FILE using a key
 * derived from PASSWORD.
 *
 * On success *self receives the new stream, whose staging buffer is one
 * empty cipher block, and NULL is returned.  Finish the stream with
 * gpass_encrypt_stream_close(), which also fclose()s FILE.  On failure a
 * new GError is returned, *self is untouched and FILE stays open.
 */
GError *
gpass_encrypt_stream_open(GPassEncryptStream **self, FILE *file,
                          const char *password)
{
    GPassEncryptStream *stream = g_object_new(GPASS_TYPE_ENCRYPT_STREAM, NULL);
    GPassCryptStream *crypt = GPASS_CRYPT_STREAM(stream);
    GError *error = gpass_crypt_stream_open(crypt, file, password);

    if (error != NULL) {
        g_object_unref(stream);
        return error;
    }
    stream->buffer = g_malloc(crypt->block_size);
    stream->ptr = stream->buffer;
    stream->ptr_end = stream->ptr + crypt->block_size;
    *self = stream;
    return NULL;
}

/*
 * Encrypt the full staging buffer in place and append the resulting
 * ciphertext block to the file.
 *
 * Returns NULL on success, or a new GError if encryption fails or the
 * write fails.  For a write failure, the error code is the errno captured
 * right after fwrite().
 */
static GError *
write_block(GPassEncryptStream *self)
{
    GPassCryptStream *crypt = GPASS_CRYPT_STREAM(self);
    size_t written;
    int saved_errno;
    GError *error = NULL;

    if (mcrypt_generic(crypt->mcrypt, self->buffer, crypt->block_size) != 0) {
        g_set_error(&error, 0, 0, _("failed mcrypt_generic()"));
        return error;
    }
    written = fwrite(self->buffer, crypt->block_size, 1, crypt->file);
    /* Capture errno before any other library call can overwrite it. */
    saved_errno = errno;
    if (written != 1 && ferror(crypt->file)) {
        /* Never use the strerror text as the format string itself. */
        g_set_error(&error, 0, saved_errno, "%s", g_strerror(saved_errno));
        return error;
    }
    return NULL;
}

/*
 * Append COUNT bytes from BUF to the stream.  Data is staged in a single
 * cipher-block buffer.  Each time the buffer fills, it is encrypted and
 * written out.  A trailing partial block stays staged until more data
 * arrives or the stream is closed.
 *
 * Returns NULL on success, or the GError from the first failing block
 * write.
 */
GError *
gpass_encrypt_stream_write(GPassEncryptStream *self,
                           const gchar *buf, gsize count)
{
    const gchar *src = buf;
    gsize remaining = count;

    while (remaining > 0) {
        gsize space = self->ptr_end - self->ptr;
        GError *error;

        if (remaining < space) {
            /* Not enough to complete the block: stash it and wait. */
            memcpy(self->ptr, src, remaining);
            self->ptr += remaining;
            break;
        }
        /* Top the block up, flush it, and start a fresh one. */
        memcpy(self->ptr, src, space);
        error = write_block(self);
        if (error != NULL) {
            return error;
        }
        src += space;
        remaining -= space;
        self->ptr = self->buffer;
    }
    return NULL;
}

/*
 * Fill the unused tail of the staging block, [ptr, ptr_end), with copies of
 * its own length (e.g. three free bytes become 03 03 03), so the decrypt
 * side can strip it again.  ptr itself is not advanced.
 */
static void
append_padding(GPassEncryptStream *self)
{
    gsize gap = self->ptr_end - self->ptr;

    memset(self->ptr, (guchar) gap, gap);
}

/*
 * Finish an encrypting stream.  A partially filled final block is padded,
 * encrypted and written.  The stream is then closed, which releases the
 * cipher, fclose()s the file and drops the caller's reference to SELF.
 *
 * This function cannot return an error, so a failure to write the final
 * block is logged with g_warning() instead of being silently dropped.
 */
void
gpass_encrypt_stream_close(GPassEncryptStream *self)
{
    if (self->ptr != self->buffer) {
        GError *error;

        append_padding(self);
        error = write_block(self);
        if (error != NULL) {
            /* Report and free the error rather than leaking it. */
            g_warning("%s", error->message);
            g_error_free(error);
        }
    }
    gpass_crypt_stream_close(GPASS_CRYPT_STREAM(self));
}
