/**********************************************************************
 * message.c                                                August 2005
 *
 * KSSLD: An implementation of SSL/TLS in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/version.h>

#include "alert.h"
#include "record.h"
#include "handshake.h"
#include "message.h"
#include "css.h"
#include "app_data.h"
#include "kssl_alloc.h"
#include "log.h"

#include "types/change_cipher_spec_t.h"
#include "types/record_t.h"
#include "types/handshake_t.h"
#include "types/alert_t.h"

/* Forward declarations for file-local helpers defined further below. */
static int kssl_message_split(kssl_record_t **pool, size_t msg_len);
static int kssl_message_process_in_internal(kssl_record_t *cr);

/* Debug-only helper; its definition is compiled out as well. */
#if 0
static void kssl_message_show_head(kssl_record_t *cr, const char *tag);
#endif

static int kssl_message_parse(kssl_record_t *cr, alert_t *alert);


/*
 * kssl_message_create - allocate a zero-filled message of a content type
 * @ct: content type to record in the new message
 *
 * Returns the new message, or NULL if the allocation failed. The caller
 * releases it with kssl_message_destroy().
 */
kssl_message_t *kssl_message_create(content_type_t ct)
{
	kssl_message_t *new_msg;

	KSSL_DEBUG(12, "kssl_message_create: enter: %d\n", ct);

	new_msg = (kssl_message_t *)kssl_kmalloc(sizeof(*new_msg), GFP_KERNEL);
	if (new_msg == NULL)
		return NULL;

	/* Start from a clean slate; only the type is set explicitly */
	memset(new_msg, 0, sizeof(*new_msg));
	new_msg->type = ct;
	return new_msg;
}


/*
 * kssl_message_destroy - free a message created by kssl_message_create()
 * @msg: message to free; must not be NULL (msg->type is read first)
 *
 * Messages of type ct_handshake or ct_ssl2 have their handshake data torn
 * down with kssl_handshake_destroy_data() before the message is freed.
 */
void kssl_message_destroy(kssl_message_t *msg)
{
	KSSL_DEBUG(12, "kssl_message_destroy: enter: %d\n", msg->type);

	switch (msg->type) {
	case ct_handshake:
	case ct_ssl2:
		kssl_handshake_destroy_data(&msg->data.handshake);
		break;
	default:
		/* Other message types need no per-type teardown */
		break;
	}

	kssl_kfree(msg);
}

/*
 * kssl_message_split - detach the first message of a record and process it
 * @pool:    in/out. On entry, the record whose first message has just been
 *           parsed. On success, the record holding the remaining data, or
 *           NULL if the message consumed the whole record.
 * @msg_len: length of that first message, as returned by
 *           kssl_message_parse()
 *
 * The record is split at @msg_len with kssl_record_vec_split(). Any tail
 * inherits the TLS and SSLv2 record headers of the original, so that the
 * messages left in it are parsed against the same record type. The head
 * is then processed straight away. This lets e.g. a ChangeCipherSpec take
 * effect before an attempt is made to parse the next message.
 *
 * Returns 0 on success or a negative error code.
 * N.B: if the split itself fails, *pool is returned unchanged. If
 * processing of the head fails, *pool has already been advanced to the
 * tail, and the head may already have been destroyed by the alert/close
 * handling in kssl_message_process_in_internal().
 */
static int kssl_message_split(kssl_record_t **pool, size_t msg_len)
{
	/*
	 * Must be signed: both callees report failure with a negative
	 * value. With an unsigned (size_t) status, the "< 0" tests below
	 * were always false and errors were silently dropped.
	 */
	int status;
	kssl_record_t *head;
	kssl_record_t *tail;

	KSSL_DEBUG(12, "kssl_message_split enter\n");

	head = *pool;
	tail = NULL;

	status = kssl_record_vec_split(head, &tail, msg_len);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_message_split: split vec failed\n");
		return status;
	}
	*pool = tail;
	if (tail) {
		memcpy(&(tail->record.head), &(head->record.head),
			sizeof(tls_head_t));
		memcpy(&(tail->ssl2_head), &(head->ssl2_head),
			sizeof(ssl2_head_t));
	}

	/*
	 * The record is processed straight away below, but it still
	 * needs to be on some list so that it can be found if its
	 * connection closes.
	 */
	list_move_tail(&(head->list), &kssl_message_in_list);
	head->state = KSSL_CONN_RECORD_IN_STARTED;

	status = kssl_message_process_in_internal(head);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_message_split: "
				"kssl_message_process_in_internal\n");
		return status;
	}

	return 0;
}

#if 0
/*
 * kssl_message_show_head - debug dump of a parsed message's header fields
 * @cr:  record whose cr->msg has been filled in by kssl_message_parse()
 * @tag: prefix for the log line
 *
 * Compiled out. The alert and handshake fields are read through
 * cr->msg->data.*, which is where kssl_message_parse_tls() and
 * kssl_message_parse_ssl2() store them. Requires cr->msg != NULL for
 * alert and handshake records.
 */
static void kssl_message_show_head(kssl_record_t *cr,
		const char *tag)
{
	switch (cr->record.head.type) {
	case ct_change_cipher_spec:
		printk(KERN_DEBUG "%s: change cipher spec\n", tag);
		break;
	case ct_alert:
		printk(KERN_DEBUG "%s: alert: level=%d desc=%d\n",
				tag, cr->msg->data.alert.level,
				cr->msg->data.alert.description);
		break;
	case ct_handshake:
		printk(KERN_DEBUG "%s: handshake: type=%d length=%d\n",
				tag, cr->msg->data.handshake.msg_type,
				cr->msg->data.handshake.length);
		break;
	case ct_application_data:
		printk(KERN_DEBUG "%s: application data\n", tag);
		break;
	case ct_ssl2:
		printk(KERN_DEBUG "%s: ssl2 format message\n", tag);
		break;
	case ct_last:
	default:
		printk(KERN_DEBUG "%s: unknown\n", tag);
		break;
	}
}
#endif


/*
 * kssl_message_parse_tls - decode the message at cr->offset of a TLS record
 * @cr: record to parse; cr->msg must already be allocated
 *
 * Checks that the record holds enough bytes for a message of its content
 * type and decodes the fixed-size header into cr->msg->data. Handshake
 * messages also have their body parsed with kssl_handshake_body_tls().
 *
 * Returns the number of bytes the message occupies starting at
 * cr->offset (application data takes the remainder of the record), or
 * -EINVAL if the record is too short, malformed or of unknown type.
 */
static int 
kssl_message_parse_tls(kssl_record_t *cr) 
{
	/*
	 * NOTE(review): assumed to be at least as large as each of
	 * CHANGE_CIPHER_SPEC_NLEN, ALERT_NLEN and HANDSHAKE_HEAD_NLEN,
	 * all of which are copied into it below -- confirm.
	 */
	u8 buf[4];

	KSSL_DEBUG(12, "kssl_message_parse_tls enter head.type=%d\n",
			cr->record.head.type);

	/* Verify Content Type */
	switch (cr->record.head.type) {
		case ct_change_cipher_spec:
			if (cr->content_len < cr->offset + 
					CHANGE_CIPHER_SPEC_NLEN) {
				KSSL_DEBUG(3, "kssl_message_parse_tls: "
						"head too short for "
						"change cipher spec: "
						"%u bytes\n", cr->content_len);
				return -EINVAL;
			}
			kssl_record_vec_cpy(cr, buf, 
					CHANGE_CIPHER_SPEC_NLEN, cr->offset);
			change_cipher_spec_from_buf(&(cr->msg->data.css), buf);
			if (cr->msg->data.css.type != ccst_change_cipher_spec) {
				KSSL_DEBUG(3, "kssl_message_parse_tls: "
						"invalid type of "
						"change cipher spec: %u\n",
						cr->msg->data.css.type);
				return -EINVAL;
			}
			KSSL_NOTICE(3, "ISSL009: Recv(ssl): CHANGE_CIPHER_SPEC (from client)\n");
			return CHANGE_CIPHER_SPEC_NLEN;
		case ct_alert:
			if (cr->content_len < cr->offset + ALERT_NLEN) {
				KSSL_DEBUG(3, "kssl_message_parse_tls: "
						"head too short for alert: "
						"%u bytes\n",
						cr->content_len);
				return -EINVAL;
			}
			kssl_record_vec_cpy(cr, buf, ALERT_NLEN,
					cr->offset);
			alert_from_buf(&(cr->msg->data.alert), buf);
			return ALERT_NLEN;
		case ct_handshake:
			if (cr->content_len < cr->offset + 
					HANDSHAKE_HEAD_NLEN) {
				KSSL_DEBUG(3, "kssl_message_parse_tls: "
						"head too short for "
						"handshake: %u bytes\n",
						cr->content_len);
				return -EINVAL;
			}
			kssl_record_vec_cpy(cr, buf, HANDSHAKE_HEAD_NLEN,
					cr->offset);
			handshake_head_from_buf(&(cr->msg->data.handshake), 
					buf);
			/* The body length comes from the peer: make sure the
			 * whole body is present before parsing it */
			if (cr->content_len < cr->offset + 
					HANDSHAKE_HEAD_NLEN + 
					cr->msg->data.handshake.length) {
				KSSL_DEBUG(3, "kssl_message_parse_tls: "
						"body too short for "
						"handshake: %u < %u\n",
						cr->content_len, 
						HANDSHAKE_HEAD_NLEN + 
						cr->msg->data.handshake.length);
				return -EINVAL;
			}
			if (kssl_handshake_body_tls(cr) < 0) {
				KSSL_DEBUG(3, "kssl_message_parse_tls: "
						"kssl_handshake_body_tls\n");
				return -EINVAL;
			}
			return(HANDSHAKE_HEAD_NLEN + 
					cr->msg->data.handshake.length);
		case ct_application_data:
			/* Application data is not framed: consume the rest */
			return cr->content_len - cr->offset;
		case ct_last:
		default:
			KSSL_DEBUG(3, "kssl_message_parse_tls: "
				"unknown message type: %u\n",
				cr->record.head.type);
			return -EINVAL;
	}

	/* Not reached */
	return -EINVAL;
}


/*
 * kssl_message_parse_ssl2 - decode the message in an SSLv2-format record
 * @cr: record to parse; cr->msg must already be allocated
 *
 * Reads the message-type byte at cr->offset and takes the message length
 * from the SSLv2 record header (cr->ssl2_head.len). The message is treated
 * as a handshake message: kssl_handshake_body_ssl2() is relied on to
 * reject anything that is not.
 *
 * Returns the message length, or -EINVAL if the type byte is missing or
 * the body does not parse.
 */
static int 
kssl_message_parse_ssl2(kssl_record_t *cr) 
{
	u8 buf;

	KSSL_DEBUG(12, "kssl_message_parse_ssl2: enter\n");

	/*
	 * The type byte comes from the network. Make sure it is present
	 * before copying it out, as kssl_message_parse_tls() does for every
	 * fixed-size header it reads.
	 */
	if (cr->content_len < cr->offset + 1) {
		KSSL_DEBUG(3, "kssl_message_parse_ssl2: "
				"head too short for message type: "
				"%u bytes\n", cr->content_len);
		return -EINVAL;
	}

	kssl_record_vec_cpy(cr, &buf, 1, cr->offset);

	cr->msg->data.handshake.length = cr->ssl2_head.len;
	cr->msg->data.handshake.msg_type = buf;

	/* Just assume that it is a handshake,
	 * the parsing of the handshake body will
	 * fail if it isn't */
	if (kssl_handshake_body_ssl2(cr) < 0)
		return -EINVAL;

	return cr->msg->data.handshake.length;
}


/*
 * kssl_message_parse - parse the next message of a record and vet it
 * @cr:    record to parse; cr->msg is (re)allocated here
 * @alert: filled in when the message must be rejected
 *
 * The message is parsed with the SSLv2 or TLS parser according to the
 * record's content type, then checked with kssl_message_mask_check(). If
 * it is accepted, kssl_message_update_mask_in() is called and the message
 * length is returned.
 *
 * On any rejection after allocation, cr->msg is destroyed and reset to
 * NULL, and a negative value is returned. A parse error reports fatal
 * illegal_parameter. A mask rejection reports fatal unexpected_message,
 * or level 0 (send nothing) when the connection has no mask at all. If
 * the message cannot be allocated, -EINVAL is returned and @alert is
 * left untouched.
 */
static int 
kssl_message_parse(kssl_record_t *cr, alert_t *alert) 
{
	int len;

	KSSL_DEBUG(12, "kssl_message_parse: enter\n");

	cr->msg = kssl_message_create(cr->record.head.type);
	if (cr->msg == NULL)
		return -EINVAL;

	len = (cr->record.head.type == ct_ssl2) ?
		kssl_message_parse_ssl2(cr) : kssl_message_parse_tls(cr);
	if (len < 0) {
		alert->level = al_fatal;
		alert->description = ad_illegal_parameter;
		goto err_free_msg;
	}

	if (!kssl_message_mask_check(cr)) {
		if (!cr->conn->msg_mask) {
			/* No mask means nothing is accepted any more, and an
			 * alert has already gone out: suppress another one */
			alert->level = 0;
		} else {
			alert->level = al_fatal;
			alert->description = ad_unexpected_message;
		}
		len = -EINVAL;
		goto err_free_msg;
	}

	kssl_message_update_mask_in(cr);
	return len;

err_free_msg:
	kssl_message_destroy(cr->msg);
	cr->msg = NULL;
	return len;
}


/*
 * kssl_message_in - parse and process every message in an inbound record
 * @cr:    in/out. On entry, the record to handle; must be non-NULL (its
 *         offset is set before anything else). Advanced past each message
 *         as the record is split.
 * @alert: filled in by kssl_message_parse() when a message is rejected
 *
 * Returns 0 once the record has been fully consumed (*cr is then NULL),
 * or the first negative value from parsing or splitting. In that case
 * *cr refers to whatever record was current at that point.
 */
int 
kssl_message_in(kssl_record_t **cr, alert_t *alert) 
{
	int rc;

	KSSL_DEBUG(12, "kssl_message_in: enter\n");

	/* Skip the record header before the first message */
	(*cr)->offset = KSSL_CONN_HEAD_NLEN(*cr);

	do {
		rc = kssl_message_parse(*cr, alert);
		if (rc < 0) {
			KSSL_DEBUG(6, "kssl_message_in: "
					"kssl_message_parse\n");
			return rc;
		}

		/* rc now holds the parsed message length */
		rc = kssl_message_split(cr, rc);
		if (rc < 0) {
			KSSL_DEBUG(6, "kssl_message_in: "
					"kssl_message_split\n");
			return rc;
		}
	} while (*cr);

	return 0;
}


/*
 * __kssl_message_process_in_internal - dispatch a parsed record by type
 * @cr:    record whose message has been parsed
 * @alert: filled in by the per-type handler, or here for unknown types
 *
 * Returns the per-type handler's result. Unknown content types return
 * -EINVAL with @alert set to fatal illegal_parameter.
 */
static inline int __kssl_message_process_in_internal(kssl_record_t *cr,
		alert_t *alert) 
{
	KSSL_DEBUG(12, "__kssl_message_process_in_internal: enter "
			"type=%d content_len=%d offset=%d\n", 
			cr->record.head.type, cr->content_len,
			cr->offset);

	/* Count this record against the inbound sequence number whenever
	 * sec_param_in_act is set (presumably: once inbound security
	 * parameters are active -- confirm) */
	if (cr->conn->sec_param_in_act)
		cr->conn->conn_state.in_seq++;

	switch (cr->record.head.type) {
	case ct_change_cipher_spec:
		return kssl_change_cipher_spec_process(cr, alert);
	case ct_alert:
		return kssl_alert_process(cr, alert);
	case ct_application_data:
		return kssl_application_data_process(cr, alert);
	case ct_handshake:
	case ct_ssl2:
		/*
		 * SSLv2-format records go to the handshake code as well.
		 * Per the original design, non-ClientHello handshake
		 * messages in SSLv2 records are rejected while parsing in
		 * kssl_message_in(), before this point is reached.
		 */
		return kssl_handshake_process(cr, alert);
	case ct_last:
	default:
		break;
	}

	/* Unknown content type */
	alert->level = al_fatal;
	alert->description = ad_illegal_parameter;
	return -EINVAL;
}


/*
 * kssl_message_process_in_internal - process a record, alerting on failure
 * @cr: record to process
 *
 * Returns the dispatcher's result. On failure, an alert is sent with
 * kssl_alert_send(cr, ..., 1). Per the original note, that call disposes
 * of cr itself. Only if sending fails is the connection closed and cr
 * destroyed here. Either way cr must not be used by the caller after a
 * negative return.
 */
static int kssl_message_process_in_internal(kssl_record_t *cr)
{
	alert_t alert = { al_warning, ad_close };
	int rc;

	KSSL_DEBUG(12, "kssl_message_process_in_internal: enter "
			"type=%d content_len=%d\n", cr->record.head.type,
			cr->content_len);

	rc = __kssl_message_process_in_internal(cr, &alert);
	if (rc >= 0)
		return rc;

	KSSL_DEBUG(6, "kssl_message_process_in_internal: "
			"__kssl_message_process_in_internal: "
			"closing connection\n");

	if (kssl_alert_send(cr, &alert, 1) < 0) {
		/* The alert could not be sent, so tear down by hand */
		kssl_conn_close(cr->conn, KSSL_CONN_SSL_CLOSE);
		kssl_record_destroy(cr);
		KSSL_DEBUG(6, "kssl_message_process_in_internal: "
				"kssl_alert_send failed\n");
	}

	return rc;
}


/*
 * kssl_message_process_in - entry point that is not expected to be reached
 * @cr: record that was routed here
 *
 * Inbound records are processed directly from kssl_message_split() via
 * kssl_message_process_in_internal(), so arriving here indicates a bug.
 * The event is logged and a close alert is sent with kssl_alert_send(cr,
 * ..., 1), which per the original note disposes of cr. If sending fails,
 * the connection is closed and cr destroyed here instead.
 *
 * Always returns -EEXIST.
 */
int kssl_message_process_in(kssl_record_t *cr)
{
	alert_t alert = { al_warning, ad_close };

	/* Name the function actually entered, not its _internal sibling */
	KSSL_DEBUG(1, "BUG! kssl_message_process_in enter:\n");

	if (kssl_alert_send(cr, &alert, 1) < 0) {
		kssl_conn_close(cr->conn, KSSL_CONN_SSL_CLOSE);
		kssl_record_destroy(cr);
		KSSL_DEBUG(3, "kssl_message_process_in: "
				"kssl_alert_send failed\n");
	}

	return -EEXIST;
}

