#include "maxent.h"

#include "command.h"
#include "fft.h"
#include "mem.h"
#include "opus.h"
#include "parse.h"
#include "phase.h"
#include "utility.h"

/*
 *  Error helpers used by the parse callbacks below.  Each failing macro
 *  writes a description into the file-static buffer 'err_msg' and then,
 *  via RETURN_MAXENT_ERROR, copies it with a command prefix into the
 *  enclosing function's 'error_msg' and returns ERROR.  They may only be
 *  used inside functions that have an 'error_msg' parameter and return
 *  Status.
 *
 *  Multi-statement bodies are wrapped in do { ... } while (0) so that each
 *  macro behaves as a single statement (safe in an un-braced if/else), and
 *  macro arguments are parenthesized.
 */

/* Prefix err_msg with 'maxent' (1D) or 'maxentN' (N-D) and return ERROR. */
#define  RETURN_MAXENT_ERROR \
	 do {   if (ndim == 1) \
	         sprintf(error_msg, "'maxent': %s", err_msg); \
	     else \
	         sprintf(error_msg, "'maxent%d': %s", ndim, err_msg); \
	     return  ERROR;   } while (0)

/* Fail unless x > 0.  x is printed with %f, so it is converted to double
   explicitly; this keeps the format well-defined even for int arguments. */
#define  CHECK_GREATER_THAN_ZERO(x, msg) \
	 do {   if ((x) <= 0) \
	     {   sprintf(err_msg, "%s = %f, must be > 0", (msg), (double) (x)); \
		 RETURN_MAXENT_ERROR;   }   } while (0)

/* Fail if a 'dim' keyword has already been seen (global keywords only). */
#define  CHECK_DIM_NOT_THERE(string) \
	 do {   if (dim >= 0) \
	     {   sprintf(err_msg, "'%s' should be before any 'dim'", (string)); \
		 RETURN_MAXENT_ERROR;   }   } while (0)

/* Fail if no 'dim' keyword has been seen yet (per-dimension keywords). */
#define  CHECK_DIM_THERE(string) \
	 do {   if (dim < 0) \
	     {   sprintf(err_msg, "'%s' found before any 'dim'", (string)); \
		 RETURN_MAXENT_ERROR;   }   } while (0)

/* Unconditionally report a repeated global keyword. */
#define  FOUND_TWICE(string) \
	 do {   sprintf(err_msg, "'%s' found twice", (string)); \
	     RETURN_MAXENT_ERROR;   } while (0)

/* Unconditionally report a keyword repeated within the current 'dim'. */
#define  FOUND_TWICE_FOR_DIM(string) \
	 do {   sprintf(err_msg, "'%s' found twice for 'dim' %d", (string), dim+1); \
	     RETURN_MAXENT_ERROR;   } while (0)

/* Global keyword: must precede every 'dim' and appear at most once. */
#define  CHECK_CONDITIONS(found, string) \
	 do {   CHECK_DIM_NOT_THERE(string); \
	     if (found)  FOUND_TWICE(string); \
	     (found) = TRUE;   } while (0)

/* Per-dimension keyword: must follow a 'dim' and appear at most once in it. */
#define  CHECK_CONDITIONS_FOR_DIM(found, string) \
	 do {   CHECK_DIM_THERE(string); \
	     if (found)  FOUND_TWICE_FOR_DIM(string); \
	     (found) = TRUE;   } while (0)

/*
 *  Module state.  start_maxent refuses a second maxent (ncodes), so all
 *  state lives in file-static variables; per-dimension arrays are indexed
 *  by the 0-based dimension.
 */

/* number of maxent commands set up so far */
static int ncodes = 0;

/* number of dimensions, and the 0-based dimension currently being parsed
   (dim < 0 means no 'dim' keyword has been seen yet) */
static int ndim;
static int dim;

/* products of npoints_in / npoints_out over all dimensions (end_maxent) */
static int tot_npts_in, tot_npts_out;

/* per-dimension point counts: input, output ('npts'), ignored points,
   running total of 'sample' points given, and required 'sample' total */
static int npoints_in[MAX_NDIM];
static int npoints_out[MAX_NDIM];
static int npoints_ignore[MAX_NDIM];
static int npts_sample[MAX_NDIM];
static int npoints_sample[MAX_NDIM];

/* per-dimension index lists (0-based) */
static int *ignore_list[MAX_NDIM];
static int *ignore_points[MAX_NDIM];
static int *sample_list[MAX_NDIM];

/* per-dimension data type and codes filled in by the phase / fft setup */
static int data_type[MAX_NDIM];
static int phase_code[MAX_NDIM];
static int fft_code[MAX_NDIM];

/* global algorithm parameters; the parse callbacks may override these */
static int max_niter = 40;
static float noise = 100;
static float rate = 1;
static float def = 1;
static Bool positive = FALSE;

/* per-dimension decay per point, and which dimensions have decay / phase */
static float decay[MAX_NDIM];
static Bool have_decay[MAX_NDIM];
static Bool have_phase[MAX_NDIM];

/* number of data sets processed so far, and the optional log file */
static int count;
static FILE *log_file = NULL;

/* "keyword seen" flags; the per-'dim' ones are cleared in dim_parse */
static Bool dim_found[MAX_NDIM];
static Bool sample_found[MAX_NDIM];
static Bool npts_found, ignore_found, decay_found, phase_found, complex_found;

/* ignore routine, set through init_opus in end_maxent */
static Ignore_func ignore_func;

/* scratch message buffer, prefixed by RETURN_MAXENT_ERROR */
static Line err_msg;

/*
 *  Per-data-set callback registered through setup_command.  Logs the data
 *  set number (if a log file is open), applies the ignore routine chosen by
 *  init_opus to 'data', then hands 'data' to do_mem.  'code' is unused.
 */
void do_maxent(int code, float *data)
{
    if (log_file != NULL)
	fprintf(log_file, "working on data number %d\n\n", ++count);

    ignore_func(data);
    do_mem(data);

    if (log_file != NULL)
	fflush(log_file);
}

/*
 *  'dim d': open the settings block for dimension d (1-based in the file,
 *  kept 0-based in 'dim').  Each dimension may be opened only once, and
 *  the per-dimension "keyword seen" flags are cleared for it.
 */
static Status dim_parse(Generic_ptr *var, String error_msg)
{
    dim = *((int *) var[0]);

    if (dim > ndim)
    {
	sprintf(err_msg, "'dim' = %d > 'ndim' = %d", dim, ndim);
	RETURN_MAXENT_ERROR;
    }
    else if (dim < 1)
    {
	sprintf(err_msg, "'dim' = %d < 1", dim);
	RETURN_MAXENT_ERROR;
    }

    dim -= 1;  /* to 0-based index */

    if (dim_found[dim])
	FOUND_TWICE_FOR_DIM("dim");
    dim_found[dim] = TRUE;

    npts_found = FALSE;
    ignore_found = FALSE;
    sample_found[dim] = FALSE;
    decay_found = FALSE;
    phase_found = FALSE;
    complex_found = FALSE;

    return  OK;
}

/* 'iter n': maximum number of iterations (global keyword, at most once). */
static Status iter_parse(Generic_ptr *var, String error_msg)
{
    static Bool seen = FALSE;

    CHECK_CONDITIONS(seen, "iter");
    max_niter = *((int *) var[0]);

    return  OK;
}

/* 'positive': no arguments; sets the 'positive' flag later passed to
   init_mem (global keyword, at most once). */
static Status positive_parse(Generic_ptr *var, String error_msg)
{
    static Bool seen = FALSE;

    CHECK_CONDITIONS(seen, "positive");
    positive = TRUE;

    return  OK;
}

/* 'rate r': rate parameter later passed to init_mem; must be > 0
   (global keyword, at most once). */
static Status rate_parse(Generic_ptr *var, String error_msg)
{
    static Bool seen = FALSE;

    CHECK_CONDITIONS(seen, "rate");

    rate = *((float *) var[0]);
    CHECK_GREATER_THAN_ZERO(rate, "rate");

    return  OK;
}

/* 'def v': 'def' parameter later passed to init_mem; must be > 0
   (global keyword, at most once). */
static Status def_parse(Generic_ptr *var, String error_msg)
{
    static Bool seen = FALSE;

    CHECK_CONDITIONS(seen, "def");

    def = *((float *) var[0]);
    CHECK_GREATER_THAN_ZERO(def, "def");

    return  OK;
}

/* 'noise v': noise level later passed to init_mem; must be > 0
   (global keyword, at most once). */
static Status noise_parse(Generic_ptr *var, String error_msg)
{
    static Bool seen = FALSE;

    CHECK_CONDITIONS(seen, "noise");

    noise = *((float *) var[0]);
    CHECK_GREATER_THAN_ZERO(noise, "noise");

    return  OK;
}

/*
 *  'log file_name': open file_name for writing as the log file.  do_maxent
 *  writes progress lines to it, and it is also passed to init_mem.  Global
 *  keyword that may appear at most once.
 *
 *  Failure is reported through RETURN_MAXENT_ERROR, so the message gets the
 *  same 'maxent' / 'maxentN' prefix as every other error from this file
 *  (previously it always said 'maxent', even for multi-dimensional setups).
 */
static Status log_parse(Generic_ptr *var, String error_msg)
{
    String file_name;
    static Bool found = FALSE;

    CHECK_CONDITIONS(found, "log");

    file_name = (char *) var[0];

    if ((log_file = fopen(file_name, WRITE)) == NULL)
    {
	sprintf(err_msg, "opening '%s' as log file", file_name);
	RETURN_MAXENT_ERROR;
    }

    return  OK;
}

/*
 *  'npts n': number of output points for the current dimension.  n must be
 *  a power of 2 (as judged by ceil_power_of_2) and no smaller than the
 *  number of input points.
 */
static Status npts_parse(Generic_ptr *var, String error_msg)
{
    int n;

    CHECK_CONDITIONS_FOR_DIM(npts_found, "npts");

    n = npoints_out[dim] = *((int *) var[0]);

    if (ceil_power_of_2(n) != n)
    {
	sprintf(err_msg, "'dim' %d: 'npts' = %d, must be power of 2",
								dim+1, n);
	RETURN_MAXENT_ERROR;
    }

    if (n < npoints_in[dim])
    {
	sprintf(err_msg, "'dim' %d: 'npts' = %d, must be >= %d",
						dim+1, n, npoints_in[dim]);
	RETURN_MAXENT_ERROR;
    }

    return  OK;
}

static Status alloc_ignore_memory()
{
    MALLOC(ignore_points[dim], int, npoints_ignore[dim]);

    return  OK;
}

/*
 *  'ignore n p1 ... pn': for the current dimension, leave out the n listed
 *  input points.  Points are 1-based in the file, must be strictly
 *  increasing and lie in [1, npoints_in[dim]]; they are stored 0-based in
 *  ignore_points[dim].  At least one, and fewer than npoints_in[dim],
 *  points must be given.
 *
 *  ignore_points holds 0-based values while the file values are 1-based,
 *  so the ordering check converts the previous point back to 1-based
 *  before comparing.  Comparing the raw values let a repeated point slip
 *  through (setup_lists assumes strictly increasing points) and printed
 *  the previous point in the wrong numbering.
 */
static Status ignore_parse(Generic_ptr *var, String error_msg)
{
    int i, prev;
    int *d = (int *) var[0];

    CHECK_CONDITIONS_FOR_DIM(ignore_found, "ignore");

    npoints_ignore[dim] = *d;

    if (npoints_ignore[dim] < 1)
    {
	sprintf(err_msg, "'dim' %d: 'ignore' has %d points, must be > 0",
						dim+1, npoints_ignore[dim]);
	RETURN_MAXENT_ERROR;
    }

    if (npoints_ignore[dim] >= npoints_in[dim])
    {
	sprintf(err_msg, "'dim' %d: 'ignore' has %d points, must be < %d",
				dim+1, npoints_ignore[dim], npoints_in[dim]);
	RETURN_MAXENT_ERROR;
    }

    if (alloc_ignore_memory() == ERROR)
    {
	sprintf(err_msg, "'dim' %d: allocating memory for 'ignore'", dim+1);
	RETURN_MAXENT_ERROR;
    }

    for (i = 0; i < npoints_ignore[dim]; i++)
    {
	d = (int *) var[i+1];

	if (*d > npoints_in[dim])
	{
	    sprintf(err_msg, "'dim' %d: 'ignore' point #%d = %d, must be <= %d",
				dim+1, i+1, *d, npoints_in[dim]);
	    RETURN_MAXENT_ERROR;
	}

	if (i > 0)
	{
	    prev = ignore_points[dim][i-1] + 1;  /* previous point, 1-based */

	    if (*d <= prev)
	    {
		sprintf(err_msg,
			"'dim' %d: 'ignore' point #%d = %d <= point #%d = %d",
					dim+1, i+1, *d, i, prev);
		RETURN_MAXENT_ERROR;
	    }
	}
	else if (*d < 1)
	{
	    sprintf(err_msg, "'dim' %d: 'ignore' point #%d = %d, must be > 0",
							dim+1, i+1, *d);
	    RETURN_MAXENT_ERROR;
	}

	ignore_points[dim][i] = *d - 1;
    }

    return  OK;
}

/*
 *  'sample n p1 ... pn': for the current dimension, append n sampling
 *  positions (1-based in the file, stored 0-based) to sample_list[dim].
 *  Several 'sample' lines may be given per dimension; across all of them
 *  the positions must be strictly increasing and their count may not
 *  exceed npoints_in[dim].  sample_consistency later checks that the total
 *  equals npoints_in[dim] and that the last position fits within 'npts'.
 *
 *  sample_list holds 0-based values while the file values are 1-based, so
 *  the ordering check converts the previous position back to 1-based
 *  before comparing.  Comparing the raw values let repeated positions
 *  through.
 */
static Status sample_parse(Generic_ptr *var, String error_msg)
{
    int i, j, n, prev;
    int *d = (int *) var[0];

    CHECK_DIM_THERE("sample");

    n = *d;

    if (n < 1)
    {
	sprintf(err_msg, "'dim' %d: a 'sample' has %d points, must be > 0",
								dim+1, n);
	RETURN_MAXENT_ERROR;
    }

    npts_sample[dim] += n;
    if (npts_sample[dim] > npoints_in[dim])
    {
	sprintf(err_msg,
	    "'dim' %d: 'sample' has >= %d points, must be = %d (in total)",
				dim+1, npts_sample[dim], npoints_in[dim]);
	RETURN_MAXENT_ERROR;
    }

    for (i = 0; i < n; i++)
    {
	d = (int *) var[i+1];

	j = i + npts_sample[dim] - n;  /* index across all 'sample' lines */

	if (j > 0)
	{
	    prev = sample_list[dim][j-1] + 1;  /* previous position, 1-based */

	    if (*d <= prev)
	    {
		sprintf(err_msg,
			"'dim' %d: 'sample' point #%d = %d <= point #%d = %d",
					dim+1, j+1, *d, j, prev);
		RETURN_MAXENT_ERROR;
	    }
	}
	else if (*d < 1)
	{
	    sprintf(err_msg, "'dim' %d: 'sample' point #%d = %d, must be > 0",
							dim+1, i+1, *d);
	    RETURN_MAXENT_ERROR;
	}

	sample_list[dim][j] = *d - 1;
    }

    sample_found[dim] = TRUE;

    return  OK;
}

/* 'complex': no arguments; marks the current dimension as COMPLEX_DATA
   (at most once per 'dim'). */
static Status complex_parse(Generic_ptr *var, String error_msg)
{
    CHECK_CONDITIONS_FOR_DIM(complex_found, "complex");
    data_type[dim] = COMPLEX_DATA;
    return  OK;
}

/*
 *  'decay v': decay for the current dimension (v > 0, at most once per
 *  'dim').  v is converted to a per-point factor
 *
 *      exp(-log(v) / (npoints_in[dim]/2 - 1))
 *
 *  and have_decay[dim] is set.
 *
 *  The conversion divides by npoints_in[dim]/2 - 1, so it needs at least 4
 *  input points.  With only 2 the divisor is 0 and the factor would
 *  silently become inf, 0 or NaN; that case is now reported as an error.
 */
static Status decay_parse(Generic_ptr *var, String error_msg)
{
    float *d = (float *) var[0];
    int nsteps;

    CHECK_CONDITIONS_FOR_DIM(decay_found, "decay");

    decay[dim] = *d;

    CHECK_GREATER_THAN_ZERO(decay[dim], "decay");

    nsteps = npoints_in[dim]/2 - 1;
    if (nsteps < 1)
    {
	sprintf(err_msg, "'dim' %d: 'decay' needs >= 4 points, have %d",
						dim+1, npoints_in[dim]);
	RETURN_MAXENT_ERROR;
    }

    decay[dim] = exp(-log(decay[dim]) / nsteps);  /* decay per point */

    have_decay[dim] = TRUE;

    return  OK;
}

/*
 *  'phase' (two float arguments): phase setting for the current dimension,
 *  handed to init_phase_max with npoints_out[dim] points.  'npts' must come
 *  first; 'phase' and 'phase2' share one once-per-'dim' flag.
 */
static Status phase_parse(Generic_ptr *var, String error_msg)
{
    int npts;

    CHECK_CONDITIONS_FOR_DIM(phase_found, "phasing");

    npts = npoints_out[dim];
    if (!npts)
    {
	sprintf(err_msg, "'dim' %d: must have 'npts' before 'phase'", dim+1);
	RETURN_MAXENT_ERROR;
    }

    have_phase[dim] = TRUE;

    return  init_phase_max(&phase_code[dim], npts, var, error_msg);
}

/*
 *  'phase2' (three float arguments): phase setting for the current
 *  dimension, handed to init_phase2_max with npoints_out[dim] points.
 *  'npts' must come first; 'phase' and 'phase2' share one once-per-'dim'
 *  flag.
 */
static Status phase2_parse(Generic_ptr *var, String error_msg)
{
    int npts;

    CHECK_CONDITIONS_FOR_DIM(phase_found, "phasing");

    npts = npoints_out[dim];
    if (!npts)
    {
	sprintf(err_msg, "'dim' %d: must have 'npts' before 'phase2'", dim+1);
	RETURN_MAXENT_ERROR;
    }

    have_phase[dim] = TRUE;

    return  init_phase2_max(&phase_code[dim], npts, var, error_msg);
}

/*
 *  For every dimension allocate ignore_list and sample_list (npoints_in
 *  ints each) and fill both with the identity 0 .. npoints_in-1.  Clear
 *  the ignore / sample counts and set the required sample total to
 *  npoints_in.  MALLOC presumably returns ERROR from here on failure —
 *  confirm.
 */
static Status alloc_maxent_memory(void)
{
    int d, k, n;

    for (d = 0; d < ndim; d++)
    {
	n = npoints_in[d];

	MALLOC(ignore_list[d], int, n);
	MALLOC(sample_list[d], int, n);

	npts_sample[d] = 0;
	npoints_ignore[d] = 0;
	npoints_sample[d] = n;

	for (k = 0; k < n; k++)
	{
	    ignore_list[d][k] = k;
	    sample_list[d][k] = k;
	}
    }

    return  OK;
}

/*
 *  Remove the ignored points from the per-dimension index lists (called
 *  from end_maxent after parsing).
 *
 *  On entry ignore_list[i] is the identity 0 .. npoints_in[i]-1 and
 *  ignore_points[i] holds npoints_ignore[i] 0-based indices to drop,
 *  assumed strictly increasing (ignore_parse checks the ordering).  Each
 *  run of kept indices following the j-th ignored point is moved left by
 *  j+1 places, so afterwards the first npoints_in[i] - npoints_ignore[i]
 *  entries of ignore_list[i] are the original indices of the kept points.
 *
 *  npoints_in[i] is then reduced to the number of kept points, and
 *  sample_list[i] is compacted the same way: entry j takes the sample
 *  position of the j-th kept point.  Going in ascending j is safe because
 *  ignore_list[i][j] >= j, so the entry read has not been overwritten yet.
 */
static void setup_lists()
{
    int i, j, k, *d_in, *d_out;

    for (i = 0; i < ndim; i++)
    {
	for (j = 0; j < npoints_ignore[i]; j++)
	{
	    /* run starts just after the j-th ignored point and moves left
	       by the j+1 ignored points seen so far */
	    d_in = ignore_list[i] + ignore_points[i][j] + 1;
	    d_out = d_in - j - 1;

	    /* run length: up to the next ignored point, or to the end */
	    if (j < (npoints_ignore[i] - 1))
		k = ignore_points[i][j+1] - ignore_points[i][j] - 1;
	    else
		k = npoints_in[i] - ignore_points[i][j] - 1;

	    /* d_out < d_in and the regions may overlap; this relies on
	       COPY_VECTOR copying front to back — confirm */
	    COPY_VECTOR(d_out, d_in, k);
	}

	npoints_in[i] -= npoints_ignore[i];

	for (j = 0; j < npoints_in[i]; j++)
	    sample_list[i][j] = sample_list[i][ignore_list[i][j]];
    }
}

/*
 *  After parsing, check every dimension that had 'sample' lines.  Their
 *  total must equal the required count (npoints_sample), and the last
 *  sample position (1-based) may not exceed 'npts' for that dimension.
 */
static Status sample_consistency(String error_msg)
{
    int k, nsamp, last;

    for (k = 0; k < ndim; k++)
    {
	if (sample_found[k])
	{
	    nsamp = npts_sample[k];
	    if (nsamp != npoints_sample[k])
	    {
		sprintf(err_msg,
		    "'dim' %d: 'sample' has %d points in total, must be = %d",
						k+1, nsamp, npoints_in[k]);
		RETURN_MAXENT_ERROR;
	    }

	    last = sample_list[k][nsamp-1] + 1;
	    if (last > npoints_out[k])
	    {
		sprintf(err_msg, "'dim' %d: 'npts' = %d, must be >= %d",
					k+1, npoints_out[k], last);
		RETURN_MAXENT_ERROR;
	    }
	}
    }

    return  OK;
}

/*
 *  Prepare module state before a maxent script is parsed.  The data set
 *  has n dimensions (1 to 3), input types 'type' and input sizes
 *  'npts_in', each of which must be even.  Only one maxent may be set up
 *  per run.  Allocates the index lists and clears the per-dimension flags.
 */
static Status start_maxent(int n, int *type, int *npts_in, String error_msg)
{
    int k;

    ndim = n;  /* set first: RETURN_MAXENT_ERROR uses ndim in its prefix */

    if ((ndim < 1) || (ndim > 3))  /* this should never happen */
    {
	sprintf(err_msg, "number of dims. must be 1, 2, or 3");
	RETURN_MAXENT_ERROR;
    }

    if (ncodes > 0)
    {
	sprintf(err_msg, "can only do one maxent");
	RETURN_MAXENT_ERROR;
    }

    for (k = 0; k < ndim; k++)
    {
	if (npts_in[k] % 2 != 0)
	{
	    sprintf(err_msg, "'dim' %d: must have even number of points", k+1);
	    RETURN_MAXENT_ERROR;
	}

	data_type[k] = type[k];
	npoints_in[k] = npts_in[k];
    }

    if (alloc_maxent_memory() == ERROR)
    {
	sprintf(err_msg, "allocating memory for lists");
	RETURN_MAXENT_ERROR;
    }

    dim = -1;  /* no 'dim' keyword seen yet */

    for (k = 0; k < ndim; k++)
    {
	npoints_out[k] = 0;  /* 0 means 'npts' not given yet */
	dim_found[k] = have_phase[k] = have_decay[k] = FALSE;
    }

    return  OK;
}

/*
 *  Finish setup after the script has been parsed.  Steps:
 *    - check that every dimension had a 'dim' block containing 'npts';
 *    - check the 'sample' lists;
 *    - drop ignored points (setup_lists);
 *    - initialize the FFTs, init_opus and init_mem.
 *  Finally reports the output to the caller: REAL_DATA with npoints_out
 *  points in each dimension.
 */
static Status end_maxent(int *type, int *npts_out, String error_msg)
{
    int k;
    float scale;
    Transform opus, tropus;

    for (k = 0; k < ndim; k++)
    {
	if (!dim_found[k])
	{
	    sprintf(err_msg, "'dim' %d: not found", k+1);
	    RETURN_MAXENT_ERROR;
	}
	else if (npoints_out[k] == 0)
	{
	    sprintf(err_msg, "'dim' %d: 'npts' not found", k+1);
	    RETURN_MAXENT_ERROR;
	}
    }

    CHECK_STATUS(sample_consistency(error_msg));

    /* real (non-complex) data has been allowed since 18 Jun 96 */

    setup_lists();

    for (k = 0; k < ndim; k++)
    {
	CHECK_STATUS(init_fft_max(&fft_code[k], 2*npoints_out[k],
						data_type[k], error_msg));
    }

    if (init_opus(ndim, npoints_in, npoints_out,
		npoints_ignore, ignore_list, &ignore_func,
		npoints_sample, sample_list,
		fft_code, have_phase, phase_code, have_decay, decay,
		&opus, &tropus, error_msg) == ERROR)
	return  ERROR;

    /* totals use npoints_in as already reduced by setup_lists */
    tot_npts_in = tot_npts_out = 1;
    for (k = 0; k < ndim; k++)
    {
	tot_npts_in *= npoints_in[k];
	tot_npts_out *= npoints_out[k];
    }

    /* scale = tot_npts_out / 2^ndim */
    scale = tot_npts_out;
    for (k = 0; k < ndim; k++)
	scale *= 0.5;

    if (init_mem(tot_npts_in, tot_npts_out, max_niter, positive, noise,
		rate, def, scale, log_file, opus, tropus, error_msg) == ERROR)
	return  ERROR;

    for (k = 0; k < ndim; k++)
    {
	npts_out[k] = npoints_out[k];
	type[k] = REAL_DATA;
    }

    count = 0;
    ncodes++;

    return  OK;
}

/*
 *  Argument type lists for the keywords below.  PARSE_FREE marks a
 *  variable-length list: ignore_parse and sample_parse read var[0] as the
 *  count and var[1..count] as the values.
 */
static int parse_int[] = { PARSE_INT };
static int parse_float[] = { PARSE_FLOAT };
static int parse_intfree[] = { PARSE_INT | PARSE_FREE };
static int parse_string[] = { PARSE_STRING };
static int parse_float2[] = { PARSE_FLOAT, PARSE_FLOAT };
static int parse_float3[] = { PARSE_FLOAT, PARSE_FLOAT, PARSE_FLOAT };

/*
 *  Keyword table for maxent scripts: keyword, number of arguments,
 *  argument types, callback.  Terminated by a NULL entry.  The first group
 *  are global keywords (must precede any 'dim'); the rest apply to the
 *  most recent 'dim'.  Entry order is left unchanged in case the parser's
 *  lookup depends on it (e.g. "phase" vs "phase2") — unverified.
 */
static Parse_line maxent_table[] =
{
    /* global keywords */
    { "iter",		1,	parse_int,		iter_parse },
    { "positive",	0,	NULL,			positive_parse },
    { "rate",		1,	parse_float,		rate_parse },
    { "def",		1,	parse_float,		def_parse },
    { "noise",		1,	parse_float,		noise_parse },
    { "log",		1,	parse_string,		log_parse },
    /* start of a per-dimension block */
    { "dim",		1,	parse_int,		dim_parse },
    /* per-dimension keywords */
    { "npts",		1,	parse_int,		npts_parse },
    { "ignore",		1,	parse_intfree,		ignore_parse },
    { "sample",		1,	parse_intfree,		sample_parse },
    { "complex",	0,	NULL,			complex_parse },
    { "decay",		1,	parse_float,		decay_parse },
    { "phase",		2,	parse_float2,		phase_parse },
    { "phase2",		3,	parse_float3,		phase2_parse },
    { NULL,		0,	NULL,			NULL }
};

/*
 *  Set up maxent for an n-dimensional data set from the script in 'file'.
 *    - start_maxent validates the input description (type, npts_in) and
 *      resets module state.
 *    - parse_file runs the script through maxent_table.
 *    - end_maxent checks the result, initializes the transforms and
 *      init_mem, and sets 'type' (REAL_DATA) and 'npts_out' for each
 *      dimension.
 *  Returns OK.  If a step fails, CHECK_STATUS returns early (presumably
 *  with ERROR) and error_msg is set.
 */
Status setup_maxents(int n, int *type, int *npts_in, int *npts_out,
						String file, String error_msg)
{
    CHECK_STATUS(start_maxent(n, type, npts_in, error_msg));
    CHECK_STATUS(parse_file(file, maxent_table, TRUE, error_msg));
    CHECK_STATUS(end_maxent(type, npts_out, error_msg));

    return  OK;
}

/*
 *  Same as setup_maxents, but the script is read from the already-open
 *  stream 'fp' via parse_subfile.  "end_maxent" is passed to parse_subfile
 *  as the terminator (presumably the line that ends the embedded script).
 *  Returns OK.  If a step fails, CHECK_STATUS returns early (presumably
 *  with ERROR) and error_msg is set.
 */
Status setup_maxents_com(int n, int *type, int *npts_in, int *npts_out,
						FILE *fp, String error_msg)
{
    CHECK_STATUS(start_maxent(n, type, npts_in, error_msg));
    CHECK_STATUS(parse_subfile(fp, maxent_table, TRUE, "end_maxent",error_msg));
    CHECK_STATUS(end_maxent(type, npts_out, error_msg));

    return  OK;
}

/*
 *  Command entry for 1D 'maxent'.  param[0] is the script file name.
 *  Registers do_maxent via setup_command, configures maxent from the
 *  script, and reports the resulting output type and size to end_command.
 */
Status init_maxent(Generic_ptr *param, String error_msg)
{
    char *script = (char *) param[0];
    int kind, n_in, n_out;

    if (setup_command(&kind, &n_in, ncodes, "maxent",
					do_maxent, error_msg) == ERROR)
	return  ERROR;

    if (setup_maxents(1, &kind, &n_in, &n_out, script, error_msg) == ERROR)
	return  ERROR;

    CHECK_STATUS(end_command(kind, n_out, "maxent", error_msg));

    return  OK;
}

/*
 *  Command entry for 1D 'maxent_com'.  param[0] is an open stream from
 *  which the maxent script is read (see setup_maxents_com).  Registers
 *  do_maxent via setup_command and reports the resulting output type and
 *  size to end_command.
 */
Status init_maxent_com(Generic_ptr *param, String error_msg)
{
    FILE *stream = (FILE *) param[0];
    int kind, n_in, n_out;

    if (setup_command(&kind, &n_in, ncodes, "maxent_com",
					do_maxent, error_msg) == ERROR)
	return  ERROR;

    if (setup_maxents_com(1, &kind, &n_in, &n_out, stream,
							error_msg) == ERROR)
	return  ERROR;

    CHECK_STATUS(end_command(kind, n_out, "maxent_com", error_msg));

    return  OK;
}
