/**********************************************************************
 * daemon.c                                                 August 2005
 *
 * KSSLD: An implementation of SSL/TLS in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <linux/slab.h>

#include "types/cipher_suite_t.h"

#include "daemon.h"
#include "socket.h"
#include "kssl_alloc.h"
#include "log.h"

/* Every registered daemon. Initialised by kssl_daemon_init(); entries are
 * linked in by kssl_daemon_list_add() and unlinked by kssl_daemon_destroy().
 * NOTE(review): none of the walkers in this file take a lock around the
 * list itself — callers presumably serialise access; confirm. */
static struct list_head kssl_daemon_list;


/* Cipher suites supported by this implementation, in order of preference.
 *
 * kssl_daemon_new() copies this table into each new daemon as its default
 * list. A daemon's list may be replaced at run time with
 * kssl_daemon_cipher_suite_set_cpy(), but only with suites that also
 * appear here; anything else is rejected with -EINVAL.
 *
 * The list must end with CIPHER_SUITE_END_MARKER.
 * NOTE(review): the EXPORT and NULL suites give weak or no confidentiality;
 * consider whether they belong in the default list. */

static const cipher_suite_t cs_implemented[] =
{
	/* AES */
	TLS_RSA_WITH_AES_256_CBC_SHA,
	TLS_RSA_WITH_AES_128_CBC_SHA,
	
	/* 3DES / DES */
	TLS_RSA_WITH_3DES_EDE_CBC_SHA,
	TLS_RSA_WITH_DES_CBC_SHA,
	
	/* Export-grade */
	TLS_RSA_EXPORT1024_WITH_DES_CBC_SHA,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA,
	
	/* No encryption, MAC only */
	TLS_RSA_WITH_NULL_SHA,
	TLS_RSA_WITH_NULL_MD5,
	
	CIPHER_SUITE_END_MARKER
};


/* Asym methods supported by this implementation, in order of preference.
 *
 * kssl_daemon_new() copies this table into each new daemon as its default
 * list. A daemon's list may be replaced at run time with
 * kssl_daemon_asym_method_set_cpy(), but only with methods that also
 * appear here; anything else is rejected with -EINVAL.
 *
 * The list must end with kssl_asym_method_none, which also serves as the
 * terminator for kssl_daemon_asym_method_implemented(). */

static const kssl_asym_method_t am_implemented[] =
{
	/* kssl_asym_method_aes_kernel, Not implemented */
	kssl_asym_method_aep_user,
	kssl_asym_method_software,

	kssl_asym_method_none
};


/*
 * __kssl_daemon_destroy - release everything a daemon owns, then the daemon
 * @daemon: daemon to free (its only caller here unlinks it from
 *          kssl_daemon_list first)
 *
 * Releases the certificate and key data, closes the attached socket if
 * there is one, frees the cipher-suite and asym-method lists, and finally
 * frees the daemon structure itself.
 */
static void 
__kssl_daemon_destroy(kssl_daemon_t *daemon)
{
	/* Credentials first. */
	kssl_cert_destroy_data(&daemon->cert);
	kssl_key_destroy_data(&daemon->key);

	/* Then the socket, which may never have been opened. */
	if (daemon->sock)
		kssl_socket_close(daemon->sock);

	/* Then the negotiated-algorithm lists. */
	if (daemon->cs != NULL)
		kssl_kfree(daemon->cs);
	if (daemon->am != NULL)
		kssl_kfree(daemon->am);

	/* Finally the container itself. */
	kssl_kfree(daemon);
}


/*
 * kssl_daemon_destroy - unlink a daemon and free it if unreferenced
 * @daemon: daemon to remove
 *
 * Always removes @daemon from kssl_daemon_list. The memory is released
 * only when the users count is exactly 1, i.e. just the reference taken in
 * kssl_daemon_create().
 * NOTE(review): when other users remain nothing here frees the daemon;
 * presumably the last kssl_daemon_put() does — confirm in daemon.h.
 */
void 
kssl_daemon_destroy(kssl_daemon_t *daemon)
{
	list_del_init(&daemon->list);

	if (atomic_read(&daemon->users) != 1)
		return;

	__kssl_daemon_destroy(daemon);
}


kssl_daemon_t *
kssl_daemon_create(void)
{
	kssl_daemon_t *daemon;

	daemon = (kssl_daemon_t *)kssl_kmalloc(sizeof(kssl_daemon_t), 
			GFP_KERNEL);
	if (!daemon)
		return NULL;

	memset(daemon, 0, sizeof(kssl_daemon_t));

	INIT_LIST_HEAD(&(daemon->list));
	kssl_daemon_get(daemon);
	rwlock_init(&daemon->lock);

	return(daemon);
}


/*
 * kssl_daemon_foreach - apply @action to every registered daemon
 * @action: callback; a positive return is added to the running total, and
 *          zero or a negative value stops the walk immediately
 * @data:   opaque pointer passed through to @action
 *
 * Returns the first non-positive value @action returned. If every call
 * returned a positive value, returns the sum of those values. Uses the
 * _safe iterator, so @action may unlink the daemon it is given.
 */
int
kssl_daemon_foreach(int (*action)(kssl_daemon_t *daemon, void *data),
		void *data)
{
	struct list_head *pos, *next;
	int total = 0;

	list_for_each_safe(pos, next, &kssl_daemon_list) {
		int rc = action(list_entry(pos, kssl_daemon_t, list), data);

		if (rc <= 0)
			return rc;
		total += rc;
	}

	return total;
}


/*
 * kssl_daemon_new - create a daemon for vaddr:vport with default algorithms
 * @vaddr: virtual address the daemon serves
 * @vport: virtual port the daemon serves
 *
 * The new daemon gets copies of cs_implemented and am_implemented as its
 * cipher-suite and asym-method lists. It is NOT added to kssl_daemon_list.
 * Returns the daemon, or NULL on allocation/validation failure (any partial
 * state is torn down).
 */
kssl_daemon_t *
kssl_daemon_new(const u32 vaddr, const u16 vport)
{
	kssl_daemon_t *daemon;

	daemon = kssl_daemon_create();
	if (!daemon)
		return NULL;

	daemon->vaddr = vaddr;
	daemon->vport = vport;

	/*
	 * kssl_daemon_cipher_suite_set_cpy() takes a length in BYTES (it
	 * divides by sizeof(cipher_suite_t) and appends its own end marker).
	 * Pass the byte size of the table without its trailing
	 * CIPHER_SUITE_END_MARKER. The old code passed cipher_suite_count()
	 * here; if that returns an element count, only a prefix of the
	 * default list was installed.
	 */
	if (kssl_daemon_cipher_suite_set_cpy(daemon, cs_implemented,
			sizeof(cs_implemented) - sizeof(cs_implemented[0])) < 0)
		goto err;

	/* asym_method_set_cpy() takes an element count, which ncount returns. */
	if (kssl_daemon_asym_method_set_cpy(daemon, am_implemented,
			kssl_daemon_asym_method_list_ncount(am_implemented,
				sizeof(am_implemented))) < 0)
		goto err;

	return daemon;

err:
	/* Not yet on the list; destroy frees it since users is still 1. */
	kssl_daemon_destroy(daemon);
	return NULL;
}


/*
 * kssl_daemon_find - look up a registered daemon by virtual address/port
 * @vaddr: virtual address to match
 * @vport: virtual port to match
 *
 * Returns the first daemon on kssl_daemon_list whose vaddr and vport both
 * match, or NULL. No reference is taken on the returned daemon.
 */
kssl_daemon_t *
kssl_daemon_find(u32 vaddr, u16 vport)
{
	kssl_daemon_t *daemon;

	/* Read-only walk: nothing is unlinked, so the plain iterator suffices. */
	list_for_each_entry(daemon, &kssl_daemon_list, list) {
		if (daemon->vaddr != vaddr)
			continue;
		if (daemon->vport == vport)
			return daemon;
	}

	return NULL;
}


/*
 * kssl_daemon_find_or_create_and_add - get the daemon for vaddr:vport,
 * creating and registering a default one if none exists
 * @vaddr: virtual address
 * @vport: virtual port
 *
 * Returns the existing or newly registered daemon, or NULL if creation
 * failed (logged at debug level 3).
 */
kssl_daemon_t *
kssl_daemon_find_or_create_and_add(u32 vaddr, u16 vport)
{
	kssl_daemon_t *daemon = kssl_daemon_find(vaddr, vport);

	if (daemon == NULL) {
		daemon = kssl_daemon_new(vaddr, vport);
		if (daemon == NULL) {
			KSSL_DEBUG(3, "kssl_daemon_find_or_create_and_add: "
					"kssl_daemon_new %08x:%04x\n", vaddr, vport);
			return NULL;
		}
		kssl_daemon_list_add(daemon);
	}

	return daemon;
}


/*
 * kssl_daemon_cpy_cert - copy the certificate and key descriptors
 * @dest: daemon receiving the copy
 * @src:  daemon to copy from
 *
 * This is a SHALLOW copy: the buffer pointers inside cert and key are
 * duplicated, not the data they point to. Afterwards both daemons reference
 * the same buffers, and __kssl_daemon_destroy() frees them for each daemon.
 * NOTE(review): callers must ensure only one owner ends up freeing these
 * buffers; confirm at call sites.
 */
void
kssl_daemon_cpy_cert(kssl_daemon_t *dest, kssl_daemon_t *src)
{
	dest->cert = src->cert;
	dest->key = src->key;
}


/*
 * kssl_cert_destroy_data - free the certificate buffer held by @cert
 * @cert: certificate whose data buffer is released; the struct itself is not
 *
 * The descriptor is reset to NULL/0 afterwards. A second call on the same
 * struct, or a later reuse of it, then cannot free the buffer again.
 */
void
kssl_cert_destroy_data(kssl_cert_t *cert)
{
	if (cert->cert.iov_base)
		kssl_kfree(cert->cert.iov_base);

	cert->cert.iov_base = NULL;
	cert->cert.iov_len = 0;
}


/*
 * kssl_free_key_iov - zero and free one key-component buffer
 * @base: buffer to release; NULL is ignored
 * @len:  number of bytes to clear before freeing
 *
 * Key components include private material (RSA d, p, q, ...). The bytes
 * are therefore cleared before the buffer goes back to the allocator,
 * instead of lingering in freed slab memory.
 * NOTE(review): assumes @len never exceeds the size the buffer was
 * allocated with — confirm against the code that fills these fields.
 * Newer kernels provide kfree_sensitive()/memzero_explicit() for this.
 */
static void
kssl_free_key_iov(void *base, size_t len)
{
	if (!base)
		return;
	memset(base, 0, len);
	kssl_kfree(base);
}

/* Scrub, free and reset one iov_base/iov_len key component. */
#define KSSL_KEY_IOV_RELEASE(iov)					\
	do {								\
		kssl_free_key_iov((iov).iov_base, (iov).iov_len);	\
		(iov).iov_base = NULL;					\
		(iov).iov_len = 0;					\
	} while (0)

/*
 * kssl_key_destroy_data - release the component buffers held by @key
 * @key: key whose buffers are freed; the struct itself is not
 *
 * Each non-NULL component is zeroed, freed and its descriptor reset to
 * NULL/0, so calling this twice on the same key is harmless. DSA and DH
 * components are only released when built with WITH_DH_DSA_SUPPORT.
 * Nothing is freed for other key types.
 */
void
kssl_key_destroy_data(kssl_key_t *key)
{
	switch (key->type) {
		case kssl_key_type_rsa:
			KSSL_KEY_IOV_RELEASE(key->key.rsa.n);
			KSSL_KEY_IOV_RELEASE(key->key.rsa.e);
			KSSL_KEY_IOV_RELEASE(key->key.rsa.d);
			KSSL_KEY_IOV_RELEASE(key->key.rsa.p);
			KSSL_KEY_IOV_RELEASE(key->key.rsa.q);
			KSSL_KEY_IOV_RELEASE(key->key.rsa.dmp1);
			KSSL_KEY_IOV_RELEASE(key->key.rsa.dmq1);
			KSSL_KEY_IOV_RELEASE(key->key.rsa.iqmp);
			break;
#ifdef WITH_DH_DSA_SUPPORT
		case kssl_key_type_dsa:
			KSSL_KEY_IOV_RELEASE(key->key.dsa.p);
			KSSL_KEY_IOV_RELEASE(key->key.dsa.q);
			KSSL_KEY_IOV_RELEASE(key->key.dsa.g);
			KSSL_KEY_IOV_RELEASE(key->key.dsa.priv_key);
			KSSL_KEY_IOV_RELEASE(key->key.dsa.pub_key);
			break;
		case kssl_key_type_dh:
			KSSL_KEY_IOV_RELEASE(key->key.dh.p);
			KSSL_KEY_IOV_RELEASE(key->key.dh.g);
			break;
#else /* WITH_DH_DSA_SUPPORT */
		case kssl_key_type_dsa:
		case kssl_key_type_dh:
#endif /* WITH_DH_DSA_SUPPORT */
		case kssl_key_type_none:
		case kssl_key_type_unknown:
		default:
			break;
	}
}

#undef KSSL_KEY_IOV_RELEASE


/* Helpers that lay out an RSA private key for the asym code.
 * These will need to change if/when the asym RSA code is changed or replaced. */

/*
 * __reverse_key - copy @len bytes from @src to @dest, reversing the order of
 * the 32-bit words (bytes inside each word keep their order)
 *
 * Word k of @dest (bytes 4k..4k+3) receives the k-th word counted from the
 * END of @src. The asym code wants words least significant first, so the
 * source is presumably stored most-significant word first — confirm.
 *
 * Precondition: @len must be a multiple of 4. Otherwise the last step reads
 * before @src and writes past @dest + @len.
 */
static inline void __reverse_key(u8 *dest, const u8 *src, size_t len)
{
	const u8 *from = src + len;
	size_t off;

	for (off = 0; off < len; off += 4) {
		from -= 4;
		memmove(dest + off, from, 4);
	}
}

/*
 * kssl_key_get_private_buf_rsa - serialise n and d for the asym code
 * @rsa_key: key whose modulus (n) and private exponent (d) are exported
 *
 * Layout of the returned buffer (2 + len(n) + len(d) bytes):
 *   [0]            number of 32-bit words in n, low byte
 *   [1]            number of 32-bit words in n, high byte
 *   [2 ..]         n with its 32-bit words in reverse order
 *   [2 + len(n)..] d, likewise word-reversed
 *
 * Returns a kssl_kmalloc()ed buffer that the caller owns (it contains the
 * private exponent). Returns NULL if n and d differ in length, if the length
 * is not a whole number of 32-bit words or does not fit the 16-bit word
 * count, if a component buffer is missing, or if allocation fails.
 */
static u8 *
kssl_key_get_private_buf_rsa(kssl_key_rsa_t *rsa_key)
{
	size_t half_len;
	size_t len;
	u8 *buf;

	if (rsa_key->n.iov_len != rsa_key->d.iov_len)
		return NULL;

	half_len = rsa_key->n.iov_len;

	/* __reverse_key() moves whole 32-bit words, indexing from the end of
	 * its source; any other length makes it read before the start of the
	 * key and write past the end of buf. */
	if (half_len % 4)
		return NULL;

	/* The header can only express 0xffff words. */
	if ((half_len >> 2) > 0xffff)
		return NULL;

	if (half_len && (!rsa_key->n.iov_base || !rsa_key->d.iov_base))
		return NULL;

	/* size_t, not u16: a 16-bit total silently wrapped for n + d above
	 * 65535 bytes, under-allocating buf before the copies below. */
	len = half_len * 2;

	buf = (u8 *)kssl_kmalloc(len + 2, GFP_KERNEL);
	if (!buf)
		return NULL;

	buf[0] = (half_len >> 2) & 0xff;
	buf[1] = ((half_len >> 2) >> 8) & 0xff;

	/* Asym has word order from least to most significant */
	__reverse_key(buf + 2, rsa_key->n.iov_base, half_len);
	__reverse_key(buf + 2 + half_len, rsa_key->d.iov_base, half_len);

	return buf;
}

/*
 * kssl_key_get_private_buf - serialise a private key for the asym code
 * @key: key to export
 *
 * XXX: Only RSA is supported. Every other key type (DSA, DH, none, unknown)
 * yields NULL. For RSA, returns the result of
 * kssl_key_get_private_buf_rsa(), which may also be NULL.
 */
u8 *
kssl_key_get_private_buf(kssl_key_t *key)
{
	if (key->type == kssl_key_type_rsa)
		return kssl_key_get_private_buf_rsa(&key->key.rsa);

	return NULL;
}


static int
kssl_daemon_cipher_suite_implemented(const cipher_suite_t *cs)
{
	return cipher_suite_nfind((cipher_suite_t *)cs_implemented, cs,
			sizeof(cs_implemented)) ? 1 : 0;
}


/*
 * kssl_daemon_cipher_suite_implemented_list - are all suites in a list
 * implemented?
 * @cs_list: suites to check
 * @bytes:   size of @cs_list in bytes; a trailing partial element is ignored
 *
 * Returns 1 if every whole element is in cs_implemented (vacuously true for
 * an empty list), 0 at the first one that is not.
 */
static int
kssl_daemon_cipher_suite_implemented_list(const cipher_suite_t *cs_list, 
		size_t bytes)
{
	size_t count = bytes / sizeof(cipher_suite_t);
	size_t idx;

	for (idx = 0; idx < count; idx++) {
		if (!kssl_daemon_cipher_suite_implemented(&cs_list[idx]))
			return 0;
	}

	return 1;
}


/*
 * kssl_daemon_cipher_suite_find_list - pick the daemon's preferred suite
 * that also appears in a client-supplied list
 * @daemon:  daemon whose (preference-ordered) suite list is walked
 * @cs_list: candidate suites
 * @bytes:   size of @cs_list in bytes
 *
 * Walks daemon->cs in order under the daemon read lock and returns a pointer
 * to the first entry that cipher_suite_nfind() finds in @cs_list, or NULL.
 * NOTE(review): the returned pointer points INTO daemon->cs and is used
 * after the read lock is dropped. A concurrent
 * kssl_daemon_cipher_suite_set_cpy() frees that list, so callers must
 * copy the value promptly or otherwise exclude writers.
 */
cipher_suite_t *
kssl_daemon_cipher_suite_find_list(kssl_daemon_t *daemon, 
		cipher_suite_t *cs_list, size_t bytes)
{
	cipher_suite_t end_marker = CIPHER_SUITE_END_MARKER;
	cipher_suite_t *cur;
	cipher_suite_t *match = NULL;

	kssl_daemon_get_read(daemon);
	for (cur = (cipher_suite_t *)daemon->cs;
			cipher_suite_cmp(cur, &end_marker); cur++) {
		if (cipher_suite_nfind(cs_list, cur, bytes)) {
			match = cur;
			break;
		}
	}
	kssl_daemon_put_read(daemon);

	return match;
}


/*
 * kssl_daemon_cipher_suite_find - does the daemon accept suite @cs?
 * @daemon: daemon to query
 * @cs:     single suite to look for
 *
 * Treats @cs as a one-element list for kssl_daemon_cipher_suite_find_list().
 * Returns the matching entry in daemon->cs, or NULL.
 */
cipher_suite_t *
kssl_daemon_cipher_suite_find(kssl_daemon_t *daemon, cipher_suite_t *cs)
{
	const size_t one_suite = sizeof(*cs);

	return kssl_daemon_cipher_suite_find_list(daemon, cs, one_suite);
}


/*
 * kssl_daemon_cipher_suite_set_cpy - replace the daemon's suite list with a copy
 * @daemon:  daemon to update
 * @cs_list: new suites in preference order (no end marker needed)
 * @bytes:   size of @cs_list in BYTES
 *
 * Returns 0 on success. Returns -EINVAL if any suite is not in
 * cs_implemented, or -ENOMEM if the copy cannot be allocated. On success
 * the copy gets a CIPHER_SUITE_END_MARKER appended. It replaces (and frees)
 * the previous list under the daemon write lock.
 */
int
kssl_daemon_cipher_suite_set_cpy(kssl_daemon_t *daemon, 
		const cipher_suite_t *cs_list, size_t bytes)
{
	cipher_suite_t terminator = CIPHER_SUITE_END_MARKER;
	cipher_suite_t *copy;
	size_t whole;

	if (kssl_daemon_cipher_suite_implemented_list(cs_list, bytes) == 0)
		return -EINVAL;

	/* Room for the caller's bytes plus one terminator element. */
	copy = kssl_kmalloc(bytes + sizeof(cipher_suite_t), GFP_KERNEL);
	if (copy == NULL)
		return -ENOMEM;

	whole = bytes / sizeof(cipher_suite_t);
	memcpy(copy, cs_list, bytes);
	memcpy(&copy[whole], &terminator, sizeof(cipher_suite_t));

	kssl_daemon_get_write(daemon);
	if (daemon->cs != NULL)
		kssl_kfree(daemon->cs);
	daemon->cs = copy;
	kssl_daemon_put_write(daemon);

	return 0;
}


/*
 * kssl_daemon_asym_method_implemented - is @am listed in am_implemented?
 * @am: method to look for
 *
 * Scans up to the kssl_asym_method_none terminator. Returns 1 if found,
 * otherwise 0 (including for kssl_asym_method_none itself).
 */
static int
kssl_daemon_asym_method_implemented(const kssl_asym_method_t am)
{
	size_t i = 0;

	while (am_implemented[i] != kssl_asym_method_none) {
		if (am_implemented[i] == am)
			return 1;
		i++;
	}

	return 0;
}


/*
 * kssl_daemon_asym_method_implemented_list - are all methods implemented?
 * @am_list: methods to check
 * @noam:    number of elements in @am_list
 *
 * Returns 1 if each of the @noam entries is in am_implemented (vacuously
 * true when @noam is 0), 0 at the first one that is not.
 */
static int
kssl_daemon_asym_method_implemented_list(const kssl_asym_method_t *am_list, 
		size_t noam)
{
	size_t i;

	for (i = 0; i < noam; i++) {
		if (!kssl_daemon_asym_method_implemented(am_list[i]))
			return 0;
	}

	return 1;
}


/*
 * kssl_daemon_asym_method_list_count - count entries before the terminator
 * @am_list: list ending with kssl_asym_method_none (NULL gives 0)
 *
 * Same as kssl_daemon_asym_method_list_ncount() with no practical byte limit.
 */
size_t
kssl_daemon_asym_method_list_count(const kssl_asym_method_t *am_list)
{
	const size_t unbounded = ~0UL;

	return kssl_daemon_asym_method_list_ncount(am_list, unbounded);
}


/*
 * kssl_daemon_asym_method_list_ncount - count entries, bounded by size
 * @am_list: list of methods, normally ending with kssl_asym_method_none
 * @bytes:   maximum number of bytes of @am_list that may be examined
 *
 * Returns the number of elements before the first kssl_asym_method_none,
 * but never more than bytes / sizeof(kssl_asym_method_t). Returns 0 for a
 * NULL list.
 *
 * The byte bound is checked BEFORE each element is read. The old loop
 * dereferenced the element first, so on a list with no terminator inside
 * @bytes it read one element past the limit. Return values are unchanged.
 */
size_t
kssl_daemon_asym_method_list_ncount(const kssl_asym_method_t *am_list,
		size_t bytes)
{
	size_t i = 0;

	if (!am_list)
		return 0;

	while (i * sizeof(kssl_asym_method_t) < bytes &&
			am_list[i] != kssl_asym_method_none)
		i++;

	return i;
}


/*
 * kssl_daemon_asym_method_set_cpy - replace the daemon's asym list with a copy
 * @daemon:  daemon to update
 * @am_list: new methods in preference order (no terminator needed)
 * @noam:    number of elements in @am_list
 *
 * Returns 0 on success. Returns -EINVAL if any method is not in
 * am_implemented, or -ENOMEM if the copy cannot be allocated. On success
 * the copy is terminated with kssl_asym_method_none. It replaces (and
 * frees) the previous list under the daemon write lock.
 */
int
kssl_daemon_asym_method_set_cpy(kssl_daemon_t *daemon, 
		const kssl_asym_method_t *am_list, size_t noam)
{
	kssl_asym_method_t *copy;
	const size_t payload = noam * sizeof(kssl_asym_method_t);

	if (kssl_daemon_asym_method_implemented_list(am_list, noam) == 0)
		return -EINVAL;

	/* One extra slot for the terminator. */
	copy = kssl_kmalloc(payload + sizeof(kssl_asym_method_t), GFP_KERNEL);
	if (copy == NULL)
		return -ENOMEM;

	memcpy(copy, am_list, payload);
	copy[noam] = kssl_asym_method_none;

	kssl_daemon_get_write(daemon);
	if (daemon->am != NULL)
		kssl_kfree(daemon->am);
	daemon->am = copy;
	kssl_daemon_put_write(daemon);

	return 0;
}



/*
 * kssl_daemon_list_add - register @daemon on kssl_daemon_list
 * @daemon: daemon to link in
 *
 * list_add() inserts right after the head. The newest daemon is therefore
 * the first one seen by the list walkers in this file.
 */
void 
kssl_daemon_list_add(kssl_daemon_t *daemon) 
{
	struct list_head *head = &kssl_daemon_list;

	list_add(&daemon->list, head);
}


/*
 * kssl_daemon_list_for_each - call @func on registered daemons until told to stop
 * @func: callback; negative aborts with that value, zero stops the walk,
 *        positive continues
 * @data: opaque pointer passed through to @func
 *
 * Returns a negative value from @func unchanged. Otherwise returns the
 * number of daemons for which @func returned a positive value before the
 * walk ended. Uses the _safe iterator, so @func may unlink its daemon.
 */
int
kssl_daemon_list_for_each(int (*func)(kssl_daemon_t *daemon, void *data), 
		void *data)
{
	struct list_head *pos, *next;
	int visited = 0;

	list_for_each_safe(pos, next, &kssl_daemon_list) {
		int rc = func(list_entry(pos, kssl_daemon_t, list), data);

		if (rc < 0)
			return rc;
		if (rc == 0)
			break;
		visited++;
	}

	return visited;
}

/*
 * kssl_daemon_init - module-init hook: start with an empty daemon list
 *
 * Always returns 0.
 */
int __init 
kssl_daemon_init(void) 
{
	struct list_head *head = &kssl_daemon_list;

	INIT_LIST_HEAD(head);
	return 0;
}

/*
 * kssl_daemon_cleanup - module-exit hook: tear down every registered daemon
 *
 * kssl_daemon_destroy() always unlinks the daemon it is given, so the list
 * shrinks by one on each pass and the loop terminates. Daemons whose users
 * count is not 1 are unlinked but not freed here.
 */
void __exit 
kssl_daemon_cleanup(void)
{
	kssl_daemon_t *victim;

	for (;;) {
		if (list_empty(&kssl_daemon_list))
			break;
		victim = list_entry(kssl_daemon_list.next, kssl_daemon_t, list);
		kssl_daemon_destroy(victim);
	}
}

