/**********************************************************************
 * sockopt.c                                                August 2005
 *
 * KSSLD: An implementation of SSL/TLS in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <linux/uio.h>
#include <linux/slab.h>
#include <linux/module.h>
#include <net/sock.h>
#include <asm/uaccess.h>
#include <linux/netfilter.h>

#include "kssl.h"
#include "sockopt.h"
#include "daemon.h"
#include "kssl_alloc.h"
#include "log.h"

/*
 * Proposed Setsockopt:
 *   cipher_suites
 */

/*
 * Serialises the sockopt handlers: kssl_ctl_set() holds it across the
 * per-command parse call and kssl_ctl_get() across the per-command
 * write call (taken with down_interruptible(), released with up()).
 */
static DECLARE_MUTEX(kssl_sockopt_mutex);


/*
 * kssl_ctl_cpy_iov_from_lbuf - decode one length-prefixed field
 * @buf: input; a 32-bit big-endian length followed by that many bytes
 * @len: in: bytes available in @buf; out (on success): bytes consumed,
 *       i.e. the 4-byte prefix plus the payload -- also for an empty
 *       payload, so callers always advance past the prefix
 * @iov: destination; any buffer it already holds is released with
 *       kssl_kfree() and replaced by a kssl_kmalloc()ed copy of the
 *       payload (NULL/0 for an empty payload)
 *
 * Returns 0, -EINVAL if @buf is too short or the encoded length is out
 * of range (in which case @iov is untouched), or -ENOMEM (in which case
 * @iov is left empty: iov_base NULL, iov_len 0).
 */
static int
kssl_ctl_cpy_iov_from_lbuf(u8 *buf, size_t *len, struct iovec *iov)
{
	u32 field_len;

	if (*len < sizeof(u32))
		return -EINVAL;

	/* memcpy: buf carries no alignment guarantee */
	memcpy(&field_len, buf, sizeof(u32));
	field_len = ntohl(field_len);

	/* *len >= sizeof(u32) here, so the subtraction cannot wrap. */
	if (field_len > INT_MAX || field_len > *len - sizeof(u32)) {
		KSSL_DEBUG(6, "kssl_ctl_cpy_iov_from_lbuf: "
				"size it too large: %u\n", field_len);
		return -EINVAL;
	}

	/* Drop whatever the destination held before. */
	if (iov->iov_base) {
		kssl_kfree(iov->iov_base);
		iov->iov_base = NULL;
	}
	iov->iov_len = 0;

	if (field_len) {
		iov->iov_base = (u8 *)kssl_kmalloc(field_len, GFP_KERNEL);
		if (!iov->iov_base) {
			KSSL_DEBUG(6, "kssl_ctl_cpy_iov_from_lbuf: "
					"kssl_kmalloc: %u bytes\n", field_len);
			return -ENOMEM;
		}
		memcpy(iov->iov_base, buf + sizeof(u32), field_len);
		iov->iov_len = field_len;
	}

	*len = sizeof(u32) + field_len;
	return 0;
}


/*
 * kssl_ctl_cpy_iov_to_lbuf - encode @iov as one length-prefixed field
 * @buf: output; receives a 32-bit big-endian length then the payload
 * @len: in: space available in @buf; out (on success): bytes written,
 *       i.e. 4 + iov->iov_len -- also for an empty payload, so callers
 *       always advance past the prefix
 * @iov: source data; iov_base may be NULL only when iov_len is 0
 *
 * Returns 0, or -EINVAL if @iov is inconsistent, longer than
 * (~0U >> 1) bytes, or does not fit in *@len bytes.
 */
static int
kssl_ctl_cpy_iov_to_lbuf(u8 *buf, size_t *len, struct iovec *iov)
{
	u32 wire_len;

	if (iov->iov_len && !iov->iov_base)
		return -EINVAL;

	if (*len < sizeof(u32) || iov->iov_len > (u32)(~0U>>1) ||
			iov->iov_len > *len - sizeof(u32))
		return -EINVAL;

	wire_len = htonl((u32)iov->iov_len);
	memcpy(buf, &wire_len, sizeof(u32));

	if (iov->iov_len)
		memcpy(buf + sizeof(u32), iov->iov_base, iov->iov_len);

	*len = sizeof(u32) + iov->iov_len;
	return 0;
}




/*
 * kssl_ctl_cpy_buf_to_lbuf - encode a flat buffer as a length-prefixed
 * field by wrapping it in a temporary iovec and delegating to
 * kssl_ctl_cpy_iov_to_lbuf() (same *@out_len in/out contract).
 * Inputs longer than (~0U >> 1) bytes are rejected with -EINVAL.
 */
static int
kssl_ctl_cpy_buf_to_lbuf(u8 *out_buf, size_t *out_len,
		u8 *in_buf, size_t in_len)
{
	struct iovec wrap;

	if (in_len <= (size_t)(~0U>>1)) {
		wrap.iov_base = in_buf;
		wrap.iov_len = in_len;
		return kssl_ctl_cpy_iov_to_lbuf(out_buf, out_len, &wrap);
	}

	return -EINVAL;
}


/*
 * kssl_ctl_cert_parse - decode the certificate field of a CERT_KEY
 * request into daemon->cert.cert.  *@len is in/out exactly as for
 * kssl_ctl_cpy_iov_from_lbuf().
 */
static int
kssl_ctl_cert_parse(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	struct iovec *cert_iov = &daemon->cert.cert;

	return kssl_ctl_cpy_iov_from_lbuf(buf, len, cert_iov);
}


/*
 * KSSL_CRT_KEY_PARSE_NEXT - decode the next length-prefixed field of
 * the request into @iov.
 *
 * Expects in the calling function: u8 *buf, size_t *len, size_t bytes
 * (offset consumed so far), int status and an "error" label, which is
 * jumped to on failure.  @label names the field in the debug message;
 * the message is prefixed with the calling function's name.
 *
 * Deliberately no semicolon after while (0): the caller's own semicolon
 * completes the statement, so the macro is safe inside if/else.
 */
#define KSSL_CRT_KEY_PARSE_NEXT(iov, label)                                   \
do {                                                                          \
	size_t tmp_len;                                                       \
	tmp_len = *len - bytes;                                               \
	status = kssl_ctl_cpy_iov_from_lbuf(buf+bytes, &tmp_len, (iov));      \
	if (status < 0) {                                                     \
		KSSL_DEBUG(6, "%s: kssl_ctl_cpy_iov_from_lbuf: %s\n",        \
				__FUNCTION__, (label));                       \
		goto error;                                                   \
	}                                                                     \
	bytes += tmp_len;                                                     \
} while (0)


/*
 * kssl_ctl_key_rsa_parse - decode the RSA key components n, e, d, p, q,
 * dmp1, dmq1 and iqmp (in that order, each length-prefixed) from @buf
 * into daemon->key.key.rsa.
 *
 * On success *@len is set to the total bytes consumed and 0 is
 * returned.  On failure the partially decoded key is released with
 * kssl_key_destroy_data() and the negative error is returned.
 */
static int
kssl_ctl_key_rsa_parse(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	struct {
		struct iovec *iov;
		const char *label;
	} field[] = {
		{ &daemon->key.key.rsa.n,    "n" },
		{ &daemon->key.key.rsa.e,    "e" },
		{ &daemon->key.key.rsa.d,    "d" },
		{ &daemon->key.key.rsa.p,    "p" },
		{ &daemon->key.key.rsa.q,    "q" },
		{ &daemon->key.key.rsa.dmp1, "dmp1" },
		{ &daemon->key.key.rsa.dmq1, "dmq1" },
		{ &daemon->key.key.rsa.iqmp, "iqmp" },
	};
	size_t consumed = 0;
	size_t i;

	for (i = 0; i < sizeof(field) / sizeof(field[0]); i++) {
		size_t chunk = *len - consumed;
		int rc;

		rc = kssl_ctl_cpy_iov_from_lbuf(buf + consumed, &chunk,
				field[i].iov);
		if (rc < 0) {
			KSSL_DEBUG(6, "kssl_ctl_key_rsa_parse: "
					"kssl_ctl_cpy_iov_from_lbuf: %s\n",
					field[i].label);
			kssl_key_destroy_data(&daemon->key);
			return rc;
		}
		consumed += chunk;
	}

	*len = consumed;
	return 0;
}


#ifdef WITH_DH_DSA_SUPPORT
/*
 * kssl_ctl_key_dsa_parse - decode the DSA key components p, q, g,
 * priv_key and pub_key (in that order, each length-prefixed) from @buf
 * into daemon->key.key.dsa.
 *
 * On success *@len is set to the total bytes consumed and 0 is
 * returned.  On failure the partially decoded key is released with
 * kssl_key_destroy_data() and the negative error is returned.  Debug
 * output names this function (the shared parse macro reported every
 * failure as kssl_ctl_key_rsa_parse).
 */
static int
kssl_ctl_key_dsa_parse(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	struct {
		struct iovec *iov;
		const char *label;
	} field[] = {
		{ &daemon->key.key.dsa.p,        "p" },
		{ &daemon->key.key.dsa.q,        "q" },
		{ &daemon->key.key.dsa.g,        "g" },
		{ &daemon->key.key.dsa.priv_key, "priv_key" },
		{ &daemon->key.key.dsa.pub_key,  "pub_key" },
	};
	size_t consumed = 0;
	size_t i;

	for (i = 0; i < sizeof(field) / sizeof(field[0]); i++) {
		size_t chunk = *len - consumed;
		int rc;

		rc = kssl_ctl_cpy_iov_from_lbuf(buf + consumed, &chunk,
				field[i].iov);
		if (rc < 0) {
			KSSL_DEBUG(6, "kssl_ctl_key_dsa_parse: "
					"kssl_ctl_cpy_iov_from_lbuf: %s\n",
					field[i].label);
			kssl_key_destroy_data(&daemon->key);
			return rc;
		}
		consumed += chunk;
	}

	*len = consumed;
	return 0;
}


/*
 * kssl_ctl_key_dh_parse - decode the DH parameters p and g (in that
 * order, each length-prefixed) from @buf into daemon->key.key.dh.
 *
 * On success *@len is set to the total bytes consumed and 0 is
 * returned.  On failure the partially decoded key is released with
 * kssl_key_destroy_data() and the negative error is returned.  Debug
 * output names this function (the shared parse macro reported every
 * failure as kssl_ctl_key_rsa_parse).
 */
static int
kssl_ctl_key_dh_parse(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	struct {
		struct iovec *iov;
		const char *label;
	} field[] = {
		{ &daemon->key.key.dh.p, "p" },
		{ &daemon->key.key.dh.g, "g" },
	};
	size_t consumed = 0;
	size_t i;

	for (i = 0; i < sizeof(field) / sizeof(field[0]); i++) {
		size_t chunk = *len - consumed;
		int rc;

		rc = kssl_ctl_cpy_iov_from_lbuf(buf + consumed, &chunk,
				field[i].iov);
		if (rc < 0) {
			KSSL_DEBUG(6, "kssl_ctl_key_dh_parse: "
					"kssl_ctl_cpy_iov_from_lbuf: %s\n",
					field[i].label);
			kssl_key_destroy_data(&daemon->key);
			return rc;
		}
		consumed += chunk;
	}

	*len = consumed;
	return 0;
}
#endif /* WITH_DH_DSA_SUPPORT */


/*
 * kssl_ctl_key_type_match - exact (not prefix) comparison of a decoded
 * field against the NUL-terminated @name.  An empty field never matches.
 */
static int
kssl_ctl_key_type_match(const struct iovec *iov, const char *name)
{
	size_t name_len = strlen(name);

	return iov->iov_base != NULL && iov->iov_len == name_len &&
		!memcmp(iov->iov_base, name, name_len);
}


/*
 * kssl_ctl_key_type_parse - decode the key-type name that starts a
 * CERT_KEY request and record it in daemon->key.type.
 *
 * The field must equal one of the KSSL_KEY_TYPE_* names exactly (DSA/DH
 * only in builds WITH_DH_DSA_SUPPORT).  On success *@len is set to the
 * bytes consumed.  Returns 0, the decode error, or -EINVAL for an
 * unknown/unsupported type (previously swallowed: the function always
 * returned 0, and a prefix compare let e.g. an empty field match RSA).
 */
static int
kssl_ctl_key_type_parse(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	int status;
	struct iovec iov = { iov_base: NULL, iov_len: 0 };

	status = kssl_ctl_cpy_iov_from_lbuf(buf, len, &iov);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_ctl_key_type_parse: "
				"kssl_ctl_cpy_iov_from_lbuf\n");
		return status;
	}

	status = 0;
	if (kssl_ctl_key_type_match(&iov, KSSL_KEY_TYPE_RSA)) {
		daemon->key.type = kssl_key_type_rsa;
	}
#ifdef WITH_DH_DSA_SUPPORT
	else if (kssl_ctl_key_type_match(&iov, KSSL_KEY_TYPE_DSA)) {
		daemon->key.type = kssl_key_type_dsa;
	}
	else if (kssl_ctl_key_type_match(&iov, KSSL_KEY_TYPE_DH)) {
		daemon->key.type = kssl_key_type_dh;
	}
#endif /* WITH_DH_DSA_SUPPORT */
	else {
		KSSL_DEBUG(6, "kssl_ctl_key_type_parse: "
				"unknown/unsported key type\n");
		status = -EINVAL;
	}

	if (iov.iov_base)
		kssl_kfree(iov.iov_base);

	return status;
}


/*
 * kssl_ctl_cert_key_parse - handle KSSL_CTL_CERT_KEY set requests.
 *
 * @buf holds these length-prefixed fields: the key-type name, the key
 * components for that type, and (for RSA/DSA keys) the certificate.
 * The request is decoded into a scratch daemon.  Only if decoding
 * succeeds is it copied into the daemon selected by
 * ctl->vaddr/ctl->vport, via kssl_daemon_cpy_cert() under that daemon's
 * write lock.
 *
 * Once decoding has succeeded, *@len is set to the number of bytes
 * consumed.  Returns 0, -ENOMEM, -EINVAL for malformed input or an
 * unsupported key type, or -ENOENT if no daemon matches.
 *
 * NOTE(review): every exit releases the scratch daemon with kssl_kfree()
 * only.  It is not visible here whether the key/certificate buffers it
 * holds are handed over by kssl_daemon_cpy_cert() on success, or leak on
 * the error paths after decoding -- confirm.
 */
static int
kssl_ctl_cert_key_parse(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	size_t tmp_len;
	size_t bytes = 0;
	int status = 0;
	kssl_daemon_t *daemon;
	kssl_daemon_t *new_daemon = NULL;

	new_daemon = kssl_daemon_create();
	if (!new_daemon) {
		KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: kssl_daemon_create\n");
		return -ENOMEM;
	}

	tmp_len = *len;
	status = kssl_ctl_key_type_parse(new_daemon, buf, &tmp_len);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: "
				"kssl_ctl_key_type_parse\n");
		goto error;
	}
	bytes += tmp_len;

	switch (new_daemon->key.type) {
		case kssl_key_type_rsa:
			tmp_len = *len - bytes;
			status = kssl_ctl_key_rsa_parse(new_daemon, buf+bytes,
					&tmp_len);
			if (status < 0) {
				KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: "
						"kssl_ctl_key_rsa_parse\n");
				goto error;
			}
			bytes += tmp_len;

			tmp_len = *len - bytes;
			status = kssl_ctl_cert_parse(new_daemon, buf+bytes,
					&tmp_len);
			if (status < 0) {
				KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: "
						"kssl_ctl_cert_parse\n");
				goto error;
			}
			bytes += tmp_len;

			status = 0;
			break;
#ifdef WITH_DH_DSA_SUPPORT
		case kssl_key_type_dsa:
			tmp_len = *len - bytes;
			status = kssl_ctl_key_dsa_parse(new_daemon, buf+bytes,
					&tmp_len);
			if (status < 0) {
				KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: "
						"kssl_ctl_key_dsa_parse\n");
				goto error;
			}
			bytes += tmp_len;

			tmp_len = *len - bytes;
			status = kssl_ctl_cert_parse(new_daemon, buf+bytes,
					&tmp_len);
			if (status < 0) {
				KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: "
						"kssl_ctl_cert_parse\n");
				goto error;
			}
			bytes += tmp_len;

			status = 0;
			break;
		case kssl_key_type_dh:
			/* DH parameters carry no certificate. */
			tmp_len = *len - bytes;
			status = kssl_ctl_key_dh_parse(new_daemon, buf+bytes,
					&tmp_len);
			if (status < 0) {
				KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: "
						"kssl_ctl_key_dh_parse\n");
				goto error;
			}
			bytes += tmp_len;

			status = 0;
			break;
#else /* #if WITH_DH_DSA_SUPPORT */
		case kssl_key_type_dsa:
		case kssl_key_type_dh:
#endif /* #if WITH_DH_DSA_SUPPORT */
		case kssl_key_type_none:
		case kssl_key_type_unknown:
		default:
			status = -EINVAL;
			KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: "
					"unknown/unsuported key type\n");
			goto error;
	}

	*len = bytes;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		KSSL_DEBUG(6, "kssl_ctl_cert_key_parse: "
				"kssl_daemon_find\n");
		/* Go through "error" so new_daemon is released (it leaked). */
		status = -ENOENT;
		goto error;
	}
	kssl_daemon_get_write(daemon);
	kssl_daemon_cpy_cert(daemon, new_daemon);
	kssl_daemon_put_write(daemon);

error:
	/* Reached on success too: the scratch daemon is always released. */
	kssl_kfree(new_daemon);
	return status;
}


/*
 * kssl_ctl_r_ip_port_parse - handle KSSL_CTL_R_IP_PORT set requests.
 *
 * @buf holds a 4-byte address followed by a 2-byte port, both in network
 * byte order.  They are stored in host order as daemon->raddr and
 * daemon->rport of the daemon selected by ctl->vaddr/ctl->vport, under
 * its write lock.  On success *@len is set to the 6 bytes consumed.
 *
 * Returns 0, -EINVAL if @buf is too short, or -ENOENT if no daemon
 * matches.
 */
static int
kssl_ctl_r_ip_port_parse(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	const size_t need = sizeof(u32) + sizeof(u16);
	u32 addr_be;
	u16 port_be;
	kssl_daemon_t *d;

	if (*len < need) {
		KSSL_DEBUG(3, "kssl_ctl_r_ip_port_parse: "
				"data too short\n");
		return -EINVAL;
	}

	/* memcpy: buf carries no alignment guarantee */
	memcpy(&addr_be, buf, sizeof(addr_be));
	memcpy(&port_be, buf + sizeof(addr_be), sizeof(port_be));

	d = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (d == NULL) {
		KSSL_DEBUG(3, "kssl_ctl_r_ip_port_parse: "
				"kssl_daemon_find %08x:%04x\n",
				ntohl(ctl->vaddr), ntohs(ctl->vport));
		return -ENOENT;
	}

	kssl_daemon_get_write(d);
	d->raddr = ntohl(addr_be);
	d->rport = ntohs(port_be);
	kssl_daemon_put_write(d);

	*len = need;
	return 0;
}


/*
 * kssl_ctl_ciphers_parse - handle KSSL_CTL_CIPHERS set requests.  The
 * cipher-suite list in @buf is installed on the daemon selected by
 * ctl->vaddr/ctl->vport via kssl_daemon_cipher_suite_set_cpy().
 *
 * Returns 0, -ENOENT if no daemon matches, or the (negative) error of
 * kssl_daemon_cipher_suite_set_cpy(), which used to be masked as -ENOENT.
 *
 * NOTE(review): cipher_suite_ncount()'s result is passed to set_cpy and,
 * plus one sizeof(cipher_suite_t), reported as bytes consumed.  This is
 * only consistent if it is a byte count (kssl_ctl_asym_methods_parse
 * scales its count by the element size) -- confirm.
 */
static int
kssl_ctl_ciphers_parse(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status;
	size_t bytes;
	kssl_daemon_t *daemon;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		KSSL_DEBUG(3, "kssl_ctl_ciphers_parse: "
				"kssl_daemon_find %08x:%04x\n",
				ntohl(ctl->vaddr), ntohs(ctl->vport));
		return -ENOENT;
	}

	bytes = cipher_suite_ncount((cipher_suite_t *)buf, *len);

	status = kssl_daemon_cipher_suite_set_cpy(daemon,
			(cipher_suite_t *)buf, bytes);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_ctl_ciphers_parse: "
				"kssl_daemon_cipher_suite_set_cpy\n");
		return status;
	}

	*len = bytes + sizeof(cipher_suite_t);
	return 0;
}


/*
 * kssl_ctl_daemon_find_or_new - ensure a daemon exists for
 * ctl->vaddr/ctl->vport, using kssl_daemon_find_or_create_and_add().
 *
 * Returns 0 on success, or -ENOENT if that helper returned NULL.
 */
int
kssl_ctl_daemon_find_or_new(kssl_ctl_t *ctl)
{
	kssl_daemon_t *daemon;

	daemon = kssl_daemon_find_or_create_and_add(ntohl(ctl->vaddr),
			ntohs(ctl->vport));
	if (!daemon) {
		/* Trailing newline was missing, merging log lines. */
		KSSL_DEBUG(3, "kssl_ctl_daemon_find_or_new: "
				"kssl_daemon_find_or_create_and_add\n");
		return -ENOENT;
	}

	return 0;
}


/*
 * kssl_ctl_daemon_add - handle KSSL_CTL_DAEMON_ADD set requests: create
 * a daemon for ctl->vaddr/ctl->vport with kssl_daemon_new() and add it
 * to the daemon list.
 *
 * If the daemon already exists nothing is done (and *@len is left
 * unchanged).  Otherwise *@len is set to 0 (no payload consumed).
 * Returns 0, or -ENOMEM if kssl_daemon_new() fails.
 */
static int
kssl_ctl_daemon_add(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	u32 vaddr = ntohl(ctl->vaddr);
	u16 vport = ntohs(ctl->vport);
	kssl_daemon_t *fresh;

	if (kssl_daemon_find(vaddr, vport) != NULL)
		return 0;

	fresh = kssl_daemon_new(vaddr, vport);
	if (fresh == NULL) {
		KSSL_DEBUG(3, "kssl_ctl_daemon_add: "
				"kssl_daemon_new %08x:%04x\n", vaddr, vport);
		return -ENOMEM;
	}
	kssl_daemon_list_add(fresh);

	*len = 0;
	return 0;
}


/*
 * kssl_ctl_daemon_del - handle KSSL_CTL_DAEMON_DEL set requests: destroy
 * the daemon for ctl->vaddr/ctl->vport, if there is one.
 *
 * If a daemon was destroyed, *@len is set to 0; if none matched, *@len
 * is left unchanged.  Always returns 0.
 */
static int
kssl_ctl_daemon_del(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	u32 vaddr = ntohl(ctl->vaddr);
	u16 vport = ntohs(ctl->vport);
	kssl_daemon_t *victim = kssl_daemon_find(vaddr, vport);

	if (victim != NULL) {
		KSSL_DEBUG(3, "Deleting daemon: %08x:%04x\n", vaddr, vport);
		kssl_daemon_destroy(victim);
		*len = 0;
	}

	return 0;
}


/*
 * kssl_ctl_flush_action - kssl_daemon_foreach() callback used by
 * kssl_ctl_flush(): logs and destroys @daemon.  @data is unused.
 * Always returns 1 -- presumably "continue iterating"; confirm against
 * kssl_daemon_foreach().
 */
static int
kssl_ctl_flush_action(kssl_daemon_t *daemon, void *data)
{
	(void)data;

	KSSL_DEBUG(3, "Deleting daemon: %08x:%04x\n",
			daemon->vaddr, daemon->vport);
	kssl_daemon_destroy(daemon);

	return 1;
}

/*
 * kssl_ctl_flush - handle KSSL_CTL_FLUSH set requests: destroy every
 * daemon via kssl_daemon_foreach(kssl_ctl_flush_action).  The request
 * carries no payload, so *@len is set to 0.  Always returns 0.
 */
static int
kssl_ctl_flush(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	*len = 0;
	kssl_daemon_foreach(kssl_ctl_flush_action, NULL);

	return 0;
}


/*
 * kssl_ctl_daemon_mode_parse - handle KSSL_CTL_DAEMON_MODE set requests.
 *
 * @buf holds a 32-bit mode in network byte order.  The daemon for
 * ctl->vaddr/ctl->vport is found or created, and the mode is stored in
 * daemon->mode under its write lock.  On success *@len is set to 4.
 *
 * Returns 0, -EINVAL if @buf is too short, or -ENOENT if the daemon
 * could not be found/created.
 *
 * NOTE(review): the value is not range-checked against
 * kssl_daemon_mode_t's enumerators.
 */
static int
kssl_ctl_daemon_mode_parse(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	u32 vaddr;
	u16 vport;
	u32 wire_mode;
	kssl_daemon_t *daemon;
	kssl_daemon_mode_t mode;

	if (*len < sizeof(u32)) {
		KSSL_DEBUG(3, "kssl_ctl_daemon_mode_parse: "
				"data too short\n");
		return -EINVAL;
	}

	vaddr = ntohl(ctl->vaddr);
	vport = ntohs(ctl->vport);

	/*
	 * Decode through a fixed-width u32: an enum's size is
	 * implementation-defined, so memcpy straight into it is unsafe.
	 */
	memcpy(&wire_mode, buf, sizeof(u32));
	mode = (kssl_daemon_mode_t)ntohl(wire_mode);

	daemon = kssl_daemon_find_or_create_and_add(vaddr, vport);
	if (!daemon) {
		KSSL_DEBUG(3, "kssl_ctl_daemon_mode_parse: "
				"kssl_daemon_find_or_create_and_add\n");
		return -ENOENT;
	}

	KSSL_DEBUG(3, "Daemon: %08x:%04x Updating mode to %u\n",
			vaddr, vport, (unsigned int)mode);

	kssl_daemon_get_write(daemon);
	daemon->mode = mode;
	kssl_daemon_put_write(daemon);

	*len = sizeof(u32);
	return 0;
}


/*
 * kssl_ctl_asym_methods_parse - handle KSSL_CTL_ASYM_METHODS set
 * requests.  The kssl_asym_method_t list in @buf is installed on the
 * daemon selected by ctl->vaddr/ctl->vport via
 * kssl_daemon_asym_method_set_cpy().
 *
 * On success *@len is set to (count + 1) * sizeof(kssl_asym_method_t),
 * i.e. the counted entries plus one further element (presumably a
 * terminator -- confirm with kssl_daemon_asym_method_list_ncount()).
 *
 * Returns 0, -ENOENT if no daemon matches, or the (negative) error of
 * kssl_daemon_asym_method_set_cpy(), which used to be masked as -ENOENT.
 */
static int
kssl_ctl_asym_methods_parse(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status;
	size_t noam;
	kssl_daemon_t *daemon;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		KSSL_DEBUG(3, "kssl_ctl_asym_methods_parse: "
				"kssl_daemon_find %08x:%04x\n",
				ntohl(ctl->vaddr), ntohs(ctl->vport));
		return -ENOENT;
	}

	noam = kssl_daemon_asym_method_list_ncount((kssl_asym_method_t *)buf,
			*len);

	status = kssl_daemon_asym_method_set_cpy(daemon,
			(kssl_asym_method_t *)buf, noam);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_ctl_asym_methods_parse: "
				"kssl_daemon_asym_method_set_cpy\n");
		return status;
	}

	*len = (noam + 1) * sizeof(kssl_asym_method_t);
	return 0;
}


/*
 * kssl_ctl_cert_len - encoded size of the daemon's certificate: a 4-byte
 * length prefix followed by the certificate bytes.
 */
static size_t
kssl_ctl_cert_len(kssl_daemon_t *daemon)
{
	size_t prefix = sizeof(u32);

	return prefix + daemon->cert.cert.iov_len;
}


/*
 * kssl_ctl_key_len_rsa - encoded size of the eight RSA key components
 * (n, e, d, p, q, dmp1, dmq1, iqmp), each a 4-byte length prefix plus
 * its bytes.
 */
static size_t
kssl_ctl_key_len_rsa(kssl_daemon_t *daemon)
{
	const struct iovec *part[] = {
		&daemon->key.key.rsa.n,
		&daemon->key.key.rsa.e,
		&daemon->key.key.rsa.d,
		&daemon->key.key.rsa.p,
		&daemon->key.key.rsa.q,
		&daemon->key.key.rsa.dmp1,
		&daemon->key.key.rsa.dmq1,
		&daemon->key.key.rsa.iqmp,
	};
	size_t total = 0;
	size_t i;

	for (i = 0; i < sizeof(part) / sizeof(part[0]); i++)
		total += part[i]->iov_len + sizeof(u32);

	return total;
}

#ifdef WITH_DH_DSA_SUPPORT
/*
 * kssl_ctl_key_len_dsa - encoded size of the five DSA key components
 * (p, q, g, priv_key, pub_key), each a 4-byte length prefix plus its
 * bytes.
 */
static size_t
kssl_ctl_key_len_dsa(kssl_daemon_t *daemon)
{
	const struct iovec *part[] = {
		&daemon->key.key.dsa.p,
		&daemon->key.key.dsa.q,
		&daemon->key.key.dsa.g,
		&daemon->key.key.dsa.priv_key,
		&daemon->key.key.dsa.pub_key,
	};
	size_t total = 0;
	size_t i;

	for (i = 0; i < sizeof(part) / sizeof(part[0]); i++)
		total += part[i]->iov_len + sizeof(u32);

	return total;
}


/*
 * kssl_ctl_key_len_dh - encoded size of the DH parameters p and g, each
 * a 4-byte length prefix plus its bytes.
 */
static size_t
kssl_ctl_key_len_dh(kssl_daemon_t *daemon)
{
	size_t p_len = sizeof(u32) + daemon->key.key.dh.p.iov_len;
	size_t g_len = sizeof(u32) + daemon->key.key.dh.g.iov_len;

	return p_len + g_len;
}
#endif /* WITH_DH_DSA_SUPPORT */


/*
 * kssl_ctl_cert_key_len - number of bytes kssl_ctl_cert_key_write()
 * needs for the daemon selected by ctl->vaddr/ctl->vport.  That is the
 * length-prefixed key-type name, the key components and, for RSA/DSA
 * only, the certificate.
 *
 * A daemon with kssl_key_type_none, or with a DSA/DH key in a build
 * without WITH_DH_DSA_SUPPORT, counts as one 4-byte length word.
 *
 * Returns 0 with *@len set, -ENOENT if no daemon matches (*@len
 * untouched), or -EINVAL with *@len set to 0 for an unknown key type.
 */
static int
kssl_ctl_cert_key_len(kssl_ctl_t *ctl, size_t *len)
{
	kssl_daemon_t *d;
	size_t total = 0;
	int rc = 0;

	d = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (d == NULL) {
		KSSL_DEBUG(6, "kssl_ctl_cert_key_len: kssl_daemon_find\n");
		return -ENOENT;
	}

	kssl_daemon_get_read(d);
	switch (d->key.type) {
		case kssl_key_type_rsa:
			total = sizeof(u32) + strlen(KSSL_KEY_TYPE_RSA) +
				kssl_ctl_key_len_rsa(d) +
				kssl_ctl_cert_len(d);
			break;
#ifdef WITH_DH_DSA_SUPPORT
		case kssl_key_type_dsa:
			total = sizeof(u32) + strlen(KSSL_KEY_TYPE_DSA) +
				kssl_ctl_key_len_dsa(d) +
				kssl_ctl_cert_len(d);
			break;
		case kssl_key_type_dh:
			total = sizeof(u32) + strlen(KSSL_KEY_TYPE_DH) +
				kssl_ctl_key_len_dh(d);
			break;
#else /* WITH_DH_DSA_SUPPORT */
		case kssl_key_type_dsa:
		case kssl_key_type_dh:
#endif /* WITH_DH_DSA_SUPPORT */
		case kssl_key_type_none:
			total = sizeof(u32);
			break;
		case kssl_key_type_unknown:
		default:
			KSSL_DEBUG(6, "kssl_ctl_cert_key_len: "
					"unknown/unsuported key type\n");
			rc = -EINVAL;
			break;
	}
	kssl_daemon_put_read(d);

	*len = total;
	return rc;
}


/*
 * kssl_ctl_cert_key_len_write - handle KSSL_CTL_CERT_KEY_LEN get
 * requests.  Writes, as a 32-bit network-order value, the total size
 * (kssl_ctl_t header + cert/key payload) a KSSL_CTL_CERT_KEY reply will
 * need.  On success *@len is set to 4.
 *
 * Returns 0, -EINVAL if @buf is too small or the size does not fit in
 * 32 bits, or the error from kssl_ctl_cert_key_len().
 */
static int
kssl_ctl_cert_key_len_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status;
	size_t total;
	u32 wire_len;

	if (*len < sizeof(u32))
		return -EINVAL;

	/*
	 * kssl_ctl_cert_key_len() stores a size_t.  Passing the address of
	 * a u32 here used to overrun the stack slot on 64-bit builds.
	 */
	status = kssl_ctl_cert_key_len(ctl, &total);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_ctl_cert_key_len\n");
		return status;
	}

	total += sizeof(kssl_ctl_t);
	if ((size_t)(u32)total != total)
		return -EINVAL;

	wire_len = htonl((u32)total);
	memcpy(buf, &wire_len, sizeof(u32));

	*len = sizeof(u32);
	return 0;
}


/*
 * kssl_ctl_r_ip_port_write - handle KSSL_CTL_R_IP_PORT get requests.
 * Writes daemon->raddr (4 bytes) then daemon->rport (2 bytes), both in
 * network byte order, for the daemon selected by ctl->vaddr/ctl->vport.
 * On success *@len is set to 6.
 *
 * Returns 0, -EINVAL if @buf is too small, or -ENOENT if no daemon
 * matches.
 */
static int
kssl_ctl_r_ip_port_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	u32 raddr;
	u16 rport;
	kssl_daemon_t *daemon;

	if (*len < sizeof(u32) + sizeof(u16))
		return -EINVAL;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		/* Debug tag used to name kssl_ctl_cert_key_len. */
		KSSL_DEBUG(6, "kssl_ctl_r_ip_port_write: kssl_daemon_find\n");
		return -ENOENT;
	}

	kssl_daemon_get_read(daemon);
	raddr = htonl(daemon->raddr);
	rport = htons(daemon->rport);
	kssl_daemon_put_read(daemon);

	memcpy(buf, &raddr, sizeof(u32));
	memcpy(buf + sizeof(u32), &rport, sizeof(u16));

	*len = sizeof(u32) + sizeof(u16);
	return 0;
}


/*
 * kssl_ctl_ciphers_len - store cipher_suite_count(daemon->cs) for the
 * daemon selected by ctl->vaddr/ctl->vport in *@len, read under the
 * daemon's read lock.  Callers use the value as a byte length.
 *
 * Returns 0, or -ENOENT if no daemon matches.
 */
static int
kssl_ctl_ciphers_len(kssl_ctl_t *ctl, size_t *len)
{
	kssl_daemon_t *daemon;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		/* Debug tag used to name kssl_ctl_cert_key_len. */
		KSSL_DEBUG(6, "kssl_ctl_ciphers_len: kssl_daemon_find\n");
		return -ENOENT;
	}

	kssl_daemon_get_read(daemon);
	*len = cipher_suite_count(daemon->cs);
	kssl_daemon_put_read(daemon);

	return 0;
}


/*
 * kssl_ctl_ciphers_len_write - handle KSSL_CTL_CIPHERS_LEN get requests.
 * Writes, as a 32-bit network-order value, the total size (kssl_ctl_t
 * header + cipher list) a KSSL_CTL_CIPHERS reply will need.  On success
 * *@len is set to 4.
 *
 * Returns 0, -EINVAL if @buf is too small or the size does not fit in
 * 32 bits, or the error from kssl_ctl_ciphers_len().
 */
static int
kssl_ctl_ciphers_len_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status;
	size_t total;
	u32 wire_len;

	if (*len < sizeof(u32))
		return -EINVAL;

	/*
	 * kssl_ctl_ciphers_len() stores a size_t.  Passing the address of a
	 * u32 here used to overrun the stack slot on 64-bit builds.
	 */
	status = kssl_ctl_ciphers_len(ctl, &total);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_ctl_ciphers_len\n");
		return status;
	}

	total += sizeof(kssl_ctl_t);
	if ((size_t)(u32)total != total)
		return -EINVAL;

	wire_len = htonl((u32)total);
	memcpy(buf, &wire_len, sizeof(u32));

	*len = sizeof(u32);
	return 0;
}


/*
 * kssl_ctl_ciphers_write - handle KSSL_CTL_CIPHERS get requests.  Copies
 * cipher_suite_count(daemon->cs) bytes of daemon->cs into @buf for the
 * daemon selected by ctl->vaddr/ctl->vport.  On success *@len is set to
 * the number of bytes copied.
 *
 * The length is measured and the data copied under a single read lock,
 * so an update done under the daemon's write lock cannot slip in between
 * (previously two lookups and two lock sections were used).
 *
 * Returns 0, -ENOENT if no daemon matches, or -EINVAL if @buf is too
 * small.
 *
 * NOTE(review): the count is used as a byte length here; confirm that
 * cipher_suite_count() returns bytes rather than elements.
 */
static int
kssl_ctl_ciphers_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status = 0;
	size_t out_len;
	kssl_daemon_t *daemon;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		KSSL_DEBUG(6, "kssl_ctl_ciphers_write: kssl_daemon_find\n");
		return -ENOENT;
	}

	kssl_daemon_get_read(daemon);
	out_len = cipher_suite_count(daemon->cs);
	if (out_len > *len)
		status = -EINVAL;
	else
		memcpy(buf, daemon->cs, out_len);
	kssl_daemon_put_read(daemon);

	if (status < 0)
		return status;

	*len = out_len;
	return 0;
}


/*
 * kssl_ctl_asym_methods_len - store the byte size of the daemon's
 * asymmetric-method list (entry count * sizeof(kssl_asym_method_t)) in
 * *@len.  The count is read under the daemon's read lock.
 *
 * Returns 0, or -ENOENT if no daemon matches ctl->vaddr/ctl->vport.
 */
static int
kssl_ctl_asym_methods_len(kssl_ctl_t *ctl, size_t *len)
{
	kssl_daemon_t *d;
	size_t n_methods;

	d = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (d == NULL) {
		KSSL_DEBUG(6, "kssl_ctl_asym_methods_len: "
				"kssl_daemon_find\n");
		return -ENOENT;
	}

	kssl_daemon_get_read(d);
	n_methods = kssl_daemon_asym_method_list_count(d->am);
	kssl_daemon_put_read(d);

	*len = n_methods * sizeof(kssl_asym_method_t);
	return 0;
}


/*
 * kssl_ctl_asym_methods_len_write - handle KSSL_CTL_ASYM_METHODS_LEN get
 * requests.  Writes, as a 32-bit network-order value, the total size
 * (kssl_ctl_t header + method list) a KSSL_CTL_ASYM_METHODS reply will
 * need.  On success *@len is set to 4.
 *
 * Returns 0, -EINVAL if @buf is too small or the size does not fit in
 * 32 bits, or the error from kssl_ctl_asym_methods_len().
 */
static int
kssl_ctl_asym_methods_len_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status;
	size_t total;
	u32 wire_len;

	if (*len < sizeof(u32))
		return -EINVAL;

	/*
	 * kssl_ctl_asym_methods_len() stores a size_t.  Passing the address
	 * of a u32 here used to overrun the stack slot on 64-bit builds.
	 */
	status = kssl_ctl_asym_methods_len(ctl, &total);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_ctl_asym_methods_len\n");
		return status;
	}

	total += sizeof(kssl_ctl_t);
	if ((size_t)(u32)total != total)
		return -EINVAL;

	wire_len = htonl((u32)total);
	memcpy(buf, &wire_len, sizeof(u32));

	*len = sizeof(u32);
	return 0;
}


/*
 * kssl_ctl_asym_methods_write - handle KSSL_CTL_ASYM_METHODS get
 * requests.  Copies the daemon's kssl_asym_method_t list (entry count *
 * element size bytes) into @buf for the daemon selected by
 * ctl->vaddr/ctl->vport.  On success *@len is set to the bytes copied.
 *
 * The size is measured and the data copied under a single read lock, so
 * the list cannot change in between (previously two lookups and two lock
 * sections were used).
 *
 * Returns 0, -ENOENT if no daemon matches, or -EINVAL if @buf is too
 * small.
 */
static int
kssl_ctl_asym_methods_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status = 0;
	size_t out_len;
	kssl_daemon_t *daemon;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		KSSL_DEBUG(6, "kssl_ctl_asym_methods_write: "
				"kssl_daemon_find\n");
		return -ENOENT;
	}

	kssl_daemon_get_read(daemon);
	out_len = kssl_daemon_asym_method_list_count(daemon->am) *
		sizeof(kssl_asym_method_t);
	if (out_len > *len)
		status = -EINVAL;
	else
		memcpy(buf, daemon->am, out_len);
	kssl_daemon_put_read(daemon);

	if (status < 0)
		return status;

	*len = out_len;
	return 0;
}


/*
 * kssl_ctl_daemon_mode_write - handle KSSL_CTL_DAEMON_MODE get requests.
 * Writes daemon->mode as a 32-bit network-order value for the daemon
 * selected by ctl->vaddr/ctl->vport.  On success *@len is set to 4.
 *
 * Returns 0, -EINVAL if @buf is too small, or -ENOENT if no daemon
 * matches.
 */
static int
kssl_ctl_daemon_mode_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	kssl_daemon_t *daemon;
	u32 wire_mode;

	if (*len < sizeof(u32))
		return -EINVAL;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		KSSL_DEBUG(6, "kssl_ctl_daemon_mode_write: "
				"kssl_daemon_find\n");
		return -ENOENT;
	}

	/*
	 * Encode through a fixed-width u32: copying 4 bytes out of an enum
	 * variable relied on the enum's implementation-defined size.
	 */
	kssl_daemon_get_read(daemon);
	wire_mode = htonl((u32)daemon->mode);
	kssl_daemon_put_read(daemon);

	memcpy(buf, &wire_mode, sizeof(u32));

	*len = sizeof(u32);
	return 0;
}


/*
 * Cursor passed through kssl_daemon_list_for_each() to
 * kssl_ctl_daemons_write_func() while serialising the daemon list into
 * a caller-supplied buffer.
 */
typedef struct {
	u8 *buf;	/* output buffer */
	size_t offset;	/* bytes written to buf so far */
	size_t len;	/* capacity of buf in bytes */
} kssl_ctl_daemons_write_flim_flam_t;


/*
 * kssl_ctl_daemons_write_func - kssl_daemon_list_for_each() callback:
 * appends @daemon's vaddr (4 bytes) and vport (2 bytes), both in network
 * byte order, at the cursor in @data (a
 * kssl_ctl_daemons_write_flim_flam_t) and advances it.
 *
 * Returns 0 without writing when the next 6-byte record would not fit,
 * otherwise 1.
 */
static int
kssl_ctl_daemons_write_func(kssl_daemon_t *daemon, void *data)
{
	kssl_ctl_daemons_write_flim_flam_t *ff = data;
	const size_t rec_len = sizeof(u32) + sizeof(u16);
	u8 *dst;
	u32 addr;
	u16 port;

	if (ff->offset + rec_len > ff->len)
		return 0;

	dst = ff->buf + ff->offset;

	kssl_daemon_get_read(daemon);
	addr = htonl(daemon->vaddr);
	port = htons(daemon->vport);
	memcpy(dst, &addr, sizeof(u32));
	memcpy(dst + sizeof(u32), &port, sizeof(u16));
	kssl_daemon_put_read(daemon);

	ff->offset += rec_len;
	return 1;
}

/*
 * kssl_ctl_daemons_write - handle KSSL_CTL_DAEMONS get requests.  Fills
 * @buf with one 6-byte (vaddr, vport) record per daemon, in network byte
 * order, for as many daemons as fit in *@len bytes.
 *
 * On success *@len is set to the number of bytes actually written.  It
 * used to be left at the buffer capacity, so callers copied back
 * trailing bytes that were never written.  Returns 0 or the error of
 * kssl_daemon_list_for_each().
 */
static int
kssl_ctl_daemons_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status;
	kssl_ctl_daemons_write_flim_flam_t flim_flam;

	flim_flam.buf = buf;
	flim_flam.offset = 0;
	flim_flam.len = *len;

	status = kssl_daemon_list_for_each(kssl_ctl_daemons_write_func,
			&flim_flam);
	if (status < 0) {
		KSSL_DEBUG(3, "kssl_ctl_daemons_write: "
				"kssl_ctl_daemons_write_func\n");
		return status;
	}

	*len = flim_flam.offset;
	return 0;
}


/*
 * kssl_ctl_daemons_len_write_func - kssl_daemon_list_for_each() callback:
 * adds the 6-byte record size of one daemon (4-byte address + 2-byte
 * port) to the u32 total pointed to by @data.  Always returns 1.
 */
static int
kssl_ctl_daemons_len_write_func(kssl_daemon_t *daemon, void *data)
{
	u32 *total = data;

	*total = *total + sizeof(u32) + sizeof(u16);
	return 1;
}


/*
 * kssl_ctl_daemons_len - store in *@len the bytes needed to list every
 * daemon (6 bytes each, see kssl_ctl_daemons_write_func()).
 *
 * kssl_ctl_daemons_len_write_func() accumulates into a u32, so the
 * count is gathered in a u32 and then widened.  Handing it the size_t
 * directly only updated part of it, which is wrong on 64-bit big-endian
 * hosts and violates strict aliasing.
 *
 * Returns 0, or the error of kssl_daemon_list_for_each() (with *@len
 * left at 0).
 */
static int
kssl_ctl_daemons_len(kssl_ctl_t *ctl, size_t *len)
{
	int status;
	u32 total = 0;

	*len = 0;
	status = kssl_daemon_list_for_each(kssl_ctl_daemons_len_write_func,
			&total);
	if (status < 0) {
		KSSL_DEBUG(3, "kssl_ctl_daemons_len: "
				"kssl_ctl_daemons_len_write_func\n");
		return status;
	}

	*len = total;
	return 0;
}


/*
 * kssl_ctl_daemons_len_write - handle KSSL_CTL_DAEMONS_LEN get requests.
 * Writes, as a 32-bit network-order value, the total size (kssl_ctl_t
 * header + 6 bytes per daemon) a KSSL_CTL_DAEMONS reply will need.  On
 * success *@len is set to 4.
 *
 * Returns 0, -EINVAL if @buf is too small (now checked, like the other
 * *_len_write handlers), or the error of kssl_daemon_list_for_each().
 */
static int
kssl_ctl_daemons_len_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status;
	u32 out_len;

	if (*len < sizeof(u32))
		return -EINVAL;

	out_len = 0;
	status = kssl_daemon_list_for_each(kssl_ctl_daemons_len_write_func,
			&out_len);
	if (status < 0) {
		KSSL_DEBUG(3, "kssl_ctl_daemons_len_write: "
				"kssl_ctl_daemons_len_write_func\n");
		return status;
	}
	out_len += sizeof(kssl_ctl_t);
	out_len = htonl(out_len);

	memcpy(buf, &out_len, sizeof(u32));

	*len = sizeof(u32);
	return 0;
}


/*
 * kssl_ctl_cert_write - encode daemon->cert.cert as a length-prefixed
 * field.  *@len is in/out exactly as for kssl_ctl_cpy_iov_to_lbuf().
 */
static int
kssl_ctl_cert_write(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	struct iovec *cert_iov = &daemon->cert.cert;

	return kssl_ctl_cpy_iov_to_lbuf(buf, len, cert_iov);
}


/*
 * KSSL_CRT_KEY_WRITE_NEXT - encode @iov as the next length-prefixed
 * field of the reply.
 *
 * Expects in the calling function: u8 *buf, size_t *len, size_t bytes
 * (offset written so far) and int status.  On failure it returns the
 * error from the calling function.
 *
 * Deliberately no semicolon after while (0): the caller's own semicolon
 * completes the statement, so the macro is safe inside if/else.
 */
#define KSSL_CRT_KEY_WRITE_NEXT(iov)                                          \
do {                                                                          \
	size_t tmp_len;                                                       \
	tmp_len = *len - bytes;                                               \
	status = kssl_ctl_cpy_iov_to_lbuf(buf+bytes, &tmp_len, (iov));        \
	if (status < 0)                                                       \
		return status;                                                \
	bytes += tmp_len;                                                     \
} while (0)


/*
 * kssl_ctl_key_rsa_write - encode the RSA key components n, e, d, p, q,
 * dmp1, dmq1 and iqmp (in that order) as length-prefixed fields.
 * On success *@len is set to the total bytes written.  Returns 0, or the
 * first error of kssl_ctl_cpy_iov_to_lbuf().
 */
static int
kssl_ctl_key_rsa_write(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	struct iovec *part[] = {
		&daemon->key.key.rsa.n,
		&daemon->key.key.rsa.e,
		&daemon->key.key.rsa.d,
		&daemon->key.key.rsa.p,
		&daemon->key.key.rsa.q,
		&daemon->key.key.rsa.dmp1,
		&daemon->key.key.rsa.dmq1,
		&daemon->key.key.rsa.iqmp,
	};
	size_t used = 0;
	size_t i;

	for (i = 0; i < sizeof(part) / sizeof(part[0]); i++) {
		size_t room = *len - used;
		int rc = kssl_ctl_cpy_iov_to_lbuf(buf + used, &room, part[i]);

		if (rc < 0)
			return rc;
		used += room;
	}

	*len = used;
	return 0;
}


#ifdef WITH_DH_DSA_SUPPORT
/*
 * kssl_ctl_key_dsa_write - encode the DSA key components p, q, g,
 * priv_key and pub_key (in that order) as length-prefixed fields.
 * On success *@len is set to the total bytes written.  Returns 0, or the
 * first error of kssl_ctl_cpy_iov_to_lbuf().
 */
static int
kssl_ctl_key_dsa_write(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	struct iovec *part[] = {
		&daemon->key.key.dsa.p,
		&daemon->key.key.dsa.q,
		&daemon->key.key.dsa.g,
		&daemon->key.key.dsa.priv_key,
		&daemon->key.key.dsa.pub_key,
	};
	size_t used = 0;
	size_t i;

	for (i = 0; i < sizeof(part) / sizeof(part[0]); i++) {
		size_t room = *len - used;
		int rc = kssl_ctl_cpy_iov_to_lbuf(buf + used, &room, part[i]);

		if (rc < 0)
			return rc;
		used += room;
	}

	*len = used;
	return 0;
}


/*
 * kssl_ctl_key_dh_write - encode the DH parameters p and g (in that
 * order) as length-prefixed fields.  On success *@len is set to the
 * total bytes written.  Returns 0, or the first error of
 * kssl_ctl_cpy_iov_to_lbuf().
 */
static int
kssl_ctl_key_dh_write(kssl_daemon_t *daemon, u8 *buf, size_t *len)
{
	struct iovec *part[] = {
		&daemon->key.key.dh.p,
		&daemon->key.key.dh.g,
	};
	size_t used = 0;
	size_t i;

	for (i = 0; i < sizeof(part) / sizeof(part[0]); i++) {
		size_t room = *len - used;
		int rc = kssl_ctl_cpy_iov_to_lbuf(buf + used, &room, part[i]);

		if (rc < 0)
			return rc;
		used += room;
	}

	*len = used;
	return 0;
}
#endif /* WITH_DH_DSA_SUPPORT */


/*
 * kssl_ctl_cert_key_write - handle KSSL_CTL_CERT_KEY get requests.
 *
 * For the daemon selected by ctl->vaddr/ctl->vport, writes these
 * length-prefixed fields: the key-type name, the key components and, for
 * RSA/DSA keys, the certificate.  This is the layout
 * kssl_ctl_cert_key_parse() reads.  The daemon's read lock is held
 * throughout and released on every path.  Previously error returns
 * leaked it, and a failure encoding the type name was ignored.
 *
 * *@len is set to the bytes written (also on failure, where callers
 * ignore it).  Returns 0, -ENOENT if no daemon matches, -EINVAL for a
 * missing/unsupported key, or an encoding error.
 */
static int
kssl_ctl_cert_key_write(kssl_ctl_t *ctl, u8 *buf, size_t *len)
{
	int status;
	size_t tmp_len;
	size_t bytes = 0;
	kssl_daemon_t *daemon;
	const char *type_name;
	int (*key_write)(kssl_daemon_t *daemon, u8 *buf, size_t *len);
	int with_cert;

	daemon = kssl_daemon_find(ntohl(ctl->vaddr), ntohs(ctl->vport));
	if (!daemon) {
		KSSL_DEBUG(6, "kssl_ctl_cert_key_write: kssl_daemon_find\n");
		return -ENOENT;
	}
	kssl_daemon_get_read(daemon);

	/* Select the per-type encoder; the framing below is shared. */
	switch (daemon->key.type) {
		case kssl_key_type_rsa:
			type_name = KSSL_KEY_TYPE_RSA;
			key_write = kssl_ctl_key_rsa_write;
			with_cert = 1;
			break;
#ifdef WITH_DH_DSA_SUPPORT
		case kssl_key_type_dsa:
			type_name = KSSL_KEY_TYPE_DSA;
			key_write = kssl_ctl_key_dsa_write;
			with_cert = 1;
			break;
		case kssl_key_type_dh:
			/* DH parameters carry no certificate. */
			type_name = KSSL_KEY_TYPE_DH;
			key_write = kssl_ctl_key_dh_write;
			with_cert = 0;
			break;
#else /* WITH_DH_DSA_SUPPORT */
		case kssl_key_type_dsa:
		case kssl_key_type_dh:
#endif /* WITH_DH_DSA_SUPPORT */
		case kssl_key_type_none:
		case kssl_key_type_unknown:
		default:
			status = -EINVAL;
			goto out;
	}

	tmp_len = *len - bytes;
	status = kssl_ctl_cpy_buf_to_lbuf(buf + bytes, &tmp_len,
			(u8 *)type_name, strlen(type_name));
	if (status < 0)
		goto out;
	bytes += tmp_len;

	tmp_len = *len - bytes;
	status = key_write(daemon, buf + bytes, &tmp_len);
	if (status < 0)
		goto out;
	bytes += tmp_len;

	if (with_cert) {
		tmp_len = *len - bytes;
		status = kssl_ctl_cert_write(daemon, buf + bytes, &tmp_len);
		if (status < 0)
			goto out;
		bytes += tmp_len;
	}

	status = 0;
out:
	kssl_daemon_put_read(daemon);
	*len = bytes;
	return status;
}


static int
kssl_ctl_set(struct sock *sk, int cmd, void *ubuf, size_t len)
{
	u8 *kbuf = NULL;
	kssl_ctl_t ctl;
	int status = 0;
	size_t tmp_len;
	int (*parse_func)(kssl_ctl_t *ctl, u8 *buf, size_t *len);
	char *parse_func_tag;

	/* Sanity check on input data */
	if (len < sizeof(kssl_ctl_t)) {
		KSSL_DEBUG(3, "kssl_ctl_set: len %u < %u\n",
				len, sizeof(kssl_ctl_t));
		return -EINVAL;
	}
	else if (len > 40960) { /* Aribtary value */
		KSSL_DEBUG(3, "kssl_ctl_set: len %u > 40960\n", len);
		return -EINVAL;
	}

	kbuf = (u8 *)kssl_kmalloc(len, GFP_KERNEL);
	if(!kbuf) {
		KSSL_DEBUG(6, "kssl_ctl_set: kssl_kmalloc\n");
		return -ENOMEM;
	}

	if (copy_from_user(kbuf, ubuf, len)) {
		status = -EFAULT;
		goto leave_free;
	}
	memcpy(&ctl, kbuf, sizeof(kssl_ctl_t));

	if (ntohl(ctl.version) != KSSL_CTL_VERSION) {
		 KSSL_DEBUG(3, "kssl_ctl_set: invalid version "
				"%u %u.%u.%u (%08x) != %u %u.%u.%u (%08x)\n",
				(ntohl(ctl.version) >> 24) & 0xff,
				(ntohl(ctl.version) >> 16) & 0xff,
				(ntohl(ctl.version) >> 8) & 0xff,
				ntohl(ctl.version) & 0xff,
				ntohl(ctl.version),
				KSSL_CTL_VERSION_MAGIC,
				KSSL_CTL_VERSION_MAJOR,
				KSSL_CTL_VERSION_MINOR,
				KSSL_CTL_VERSION_PATCH,
				KSSL_CTL_VERSION);
		status = -EINVAL;
		goto leave_free;
	}

	MOD_INC_USE_COUNT;
	if (down_interruptible(&kssl_sockopt_mutex)) {
		status = -ERESTARTSYS;
		goto leave_dec;
	}

	switch(cmd) {
		case KSSL_CTL_CERT_KEY:
			parse_func = kssl_ctl_cert_key_parse;
			parse_func_tag = "kssl_ctl_cert_key_parse";
			break;
		case KSSL_CTL_R_IP_PORT:
			parse_func = kssl_ctl_r_ip_port_parse;
			parse_func_tag = "kssl_ctl_r_ip_port_parse";
			break;
		case KSSL_CTL_CIPHERS:
			parse_func = kssl_ctl_ciphers_parse;
			parse_func_tag = "kssl_ctl_ciphers_parse";
			break;
		case KSSL_CTL_DAEMON_ADD:
			parse_func = kssl_ctl_daemon_add;
			parse_func_tag = "kssl_ctl_daemon_add";
			break;
		case KSSL_CTL_DAEMON_DEL:
			parse_func = kssl_ctl_daemon_del;
			parse_func_tag = "kssl_ctl_daemon_del";
			break;
		case KSSL_CTL_DAEMON_MODE:
			parse_func = kssl_ctl_daemon_mode_parse;
			parse_func_tag = "kssl_ctl_daemon_mode_parse";
			break;
		case KSSL_CTL_ASYM_METHODS:
			parse_func = kssl_ctl_asym_methods_parse;
			parse_func_tag = "kssl_ctl_asym_methods_parse";
			break;
		case KSSL_CTL_FLUSH:
			parse_func = kssl_ctl_flush;
			parse_func_tag = "kssl_ctl_flush";
			break;
		case KSSL_CTL_DAEMONS:
		case KSSL_CTL_CERT_KEY_LEN:
		case KSSL_CTL_CIPHERS_LEN:
		case KSSL_CTL_ASYM_METHODS_LEN:
		default:
			 KSSL_DEBUG(3, "kssl_ctl_set: "
					"unknown/unsported request\n");
			status = -EINVAL;
			goto leave_up;
	}

	tmp_len = len - sizeof(kssl_ctl_t);
	status = parse_func(&ctl, kbuf+sizeof(kssl_ctl_t), &tmp_len);
	if (status < 0) {
		 KSSL_DEBUG(6, "kssl_ctl_set: %s\n", parse_func_tag);
		goto leave_up;
	}

	status = 0;
leave_up:
	up(&kssl_sockopt_mutex);
leave_dec:
	MOD_DEC_USE_COUNT;
leave_free:
	if (kbuf)
		kssl_kfree(kbuf);
	return status;
}


static int
kssl_ctl_get(struct sock *sk, int cmd, void *ubuf, int *len)
{
	u8 *kbuf = NULL;
	kssl_ctl_t ctl;
	int status = 0;
	size_t out_len;
	size_t tmp_len;
	int (*get_func)(kssl_ctl_t *ctl, u8 *buf, size_t *len);
	const char *get_func_tag = NULL;

	/* Sanity check on input data */
	if (*len < sizeof(kssl_ctl_t)) {
		 KSSL_DEBUG(3, "kssl_ctl_get: len %u < %u\n",
				*len, sizeof(kssl_ctl_t));
		return -EINVAL;
	}
	else if (*len > 40960) { /* Aribtary value */
		 KSSL_DEBUG(3, "kssl_ctl_get: len %u > 40960\n", *len);
		return -EINVAL;
	}

	kbuf = (u8 *)kssl_kmalloc(*len, GFP_KERNEL);
	if(!kbuf) {
		 KSSL_DEBUG(6, "kssl_ctl_get: kssl_kmalloc\n");
		return -ENOMEM;
	}
	memset(kbuf, 0, *len);

	if (copy_from_user(kbuf, ubuf, *len)) {
		status = -EFAULT;
		goto leave_free;
	}
	memcpy(&ctl, kbuf, sizeof(kssl_ctl_t));

	if (ntohl(ctl.version) != KSSL_CTL_VERSION) {
		 KSSL_DEBUG(3, "kssl_ctl_get: invalid version "
				"%u %u.%u.%u (%08x) != %u %u.%u.%u (%08x)\n",
				(ntohl(ctl.version) >> 24) & 0xff,
				(ntohl(ctl.version) >> 16) & 0xff,
				(ntohl(ctl.version) >> 8) & 0xff,
				ntohl(ctl.version) & 0xff,
				ntohl(ctl.version),
				KSSL_CTL_VERSION_MAGIC,
				KSSL_CTL_VERSION_MAJOR,
				KSSL_CTL_VERSION_MINOR,
				KSSL_CTL_VERSION_PATCH,
				KSSL_CTL_VERSION);
		status = -EINVAL;
		goto leave_free;
	}

	/* More specific length check */
	out_len = 0;
	
	switch(cmd) {
		case KSSL_CTL_CERT_KEY:
			status = kssl_ctl_cert_key_len(&ctl, &out_len);
			if (status < 0) {
				 KSSL_DEBUG(6, "kssl_ctl_get: "
					"kssl_ctl_cert_key_len\n");
			}
			get_func = kssl_ctl_cert_key_write;
			get_func_tag = "kssl_ctl_cert_key_write";
			break;
		case KSSL_CTL_CERT_KEY_LEN:
			out_len = sizeof(u32);
			get_func = kssl_ctl_cert_key_len_write;
			get_func_tag = "kssl_ctl_cert_key_len_write";
			break;
		case KSSL_CTL_R_IP_PORT:
			out_len = sizeof(u32) + sizeof(u16);
			get_func = kssl_ctl_r_ip_port_write;
			get_func_tag = "kssl_ctl_r_ip_port_write";
			break;
		case KSSL_CTL_CIPHERS_LEN:
			out_len = sizeof(u32);
			get_func = kssl_ctl_ciphers_len_write;
			get_func_tag = "kssl_ctl_ciphers_len_write";
			break;
		case KSSL_CTL_CIPHERS:
			status = kssl_ctl_ciphers_len(&ctl, &out_len);
			if (status < 0) {
				 KSSL_DEBUG(6, "kssl_ctl_get: "
					"kssl_ctl_ciphers_len\n");
			}
			get_func = kssl_ctl_ciphers_write;
			get_func_tag = "kssl_ctl_ciphers_write";
			break;
		case KSSL_CTL_DAEMON_MODE:
			out_len = sizeof(u32);
			get_func = kssl_ctl_daemon_mode_write;
			get_func_tag = "kssl_ctl_daemon_mode_write";
			break;
		case KSSL_CTL_ASYM_METHODS_LEN:
			out_len = sizeof(u32);
			get_func = kssl_ctl_asym_methods_len_write;
			get_func_tag = "kssl_ctl_asym_methods_len_write";
			break;
		case KSSL_CTL_ASYM_METHODS:
			status = kssl_ctl_asym_methods_len(&ctl, &out_len);
			if (status < 0) {
				 KSSL_DEBUG(6, "kssl_ctl_get: "
					"kssl_ctl_asym_methods_len\n");
			}
			get_func = kssl_ctl_asym_methods_write;
			get_func_tag = "kssl_ctl_asym_methods_write";
			break;
		case KSSL_CTL_DAEMONS_LEN:
			out_len = sizeof(u32);
			get_func = kssl_ctl_daemons_len_write;
			get_func_tag = "kssl_ctl_daemons_len_write";
			break;
		case KSSL_CTL_DAEMONS:
			status = kssl_ctl_daemons_len(&ctl, &out_len);
			if (status < 0) {
				 KSSL_DEBUG(6, "kssl_ctl_get: "
					"kssl_ctl_daemons_len\n");
			}
			get_func = kssl_ctl_daemons_write;
			get_func_tag = "kssl_ctl_daemons_write";
			break;
		default:
			 KSSL_DEBUG(6, "kssl_ctl_get: "
					"unknown/unsported request\n");
			status = -EINVAL;
			goto leave_free;
	}

	if (*len < sizeof(kssl_ctl_t) + out_len) {
		 KSSL_DEBUG(3, "kssl_ctl_get: len %u < %u\n",
				*len, sizeof(kssl_ctl_t) + out_len);
		return -EINVAL;
	}

	MOD_INC_USE_COUNT;
	if (down_interruptible(&kssl_sockopt_mutex)) {
		status = -ERESTARTSYS;
		goto leave_dec;
	}

	tmp_len = *len-sizeof(kssl_ctl_t);
	status = get_func(&ctl, kbuf+sizeof(kssl_ctl_t), &tmp_len);
	if (status < 0) {
		 KSSL_DEBUG(6, "kssl_ctl_get: %s\n", get_func_tag);
		goto leave_up;
	}

	*len = tmp_len + sizeof(kssl_ctl_t);
	if (copy_to_user(ubuf, kbuf, *len)) {
		goto leave_up;
		status = -EFAULT;
	}

	status = 0;
leave_up:
	up(&kssl_sockopt_mutex);
leave_dec:
	MOD_DEC_USE_COUNT;
leave_free:
	if (kbuf)
		kssl_kfree(kbuf);
	return status;
}

/*
 * Netfilter sockopt registration for PF_INET: option numbers from
 * KSSL_CTL_BASE to KSSL_CTL_LAST+1 are routed to kssl_ctl_set() (set)
 * and kssl_ctl_get() (get).  Presumably the upper bound is exclusive,
 * as netfilter treats optmax -- confirm for the target kernel.
 *
 * Positional initialiser: the field order must match struct
 * nf_sockopt_ops of the kernel this is built against.
 */
static struct nf_sockopt_ops kssl_sockopts = {
	{ NULL, NULL }, PF_INET,
	KSSL_CTL_BASE, KSSL_CTL_LAST+1, kssl_ctl_set,
	KSSL_CTL_BASE, KSSL_CTL_LAST+1, kssl_ctl_get
};


/*
 * kssl_ctl_init - register the KSSL sockopt handlers with netfilter.
 * Returns 0, or the non-zero status of nf_register_sockopt().
 */
int __init
kssl_ctl_init(void)
{
	int status;

	status = nf_register_sockopt(&kssl_sockopts);
	if (status) {
		/* Message used to say "could register" and lacked '\n'. */
		KSSL_DEBUG(6, "kssl_ctl_init: could not register sockopt\n");
		return status;
	}

	return 0;
}


/*
 * kssl_ctl_cleanup - unregister the sockopt handlers that
 * kssl_ctl_init() registered.
 */
void __exit
kssl_ctl_cleanup(void)
{
	struct nf_sockopt_ops *ops = &kssl_sockopts;

	nf_unregister_sockopt(ops);
}

