#include "fit.h"

#include "line_fit.h"
#include "nonlinear_model.h"

/* fit_nonlinear(): give up after this many nonlinear_model() iterations */
#define  MAX_MODEL_ITER  20
/* fit_nonlinear(): stop once this many small-chisq-improvement iterations
   have been counted (the counter is reset whenever chisq increases) */
#define  MAX_CONDITION  4
/* fit_nonlinear(): a chisq decrease below this (times 100/n) is "small" */
#define  CHISQ_STOP_CRITERION  (1.0e-1)

/* init_params_fit(): points closer than this in x are not paired up */
#define  MIN_DELTA_X  (1.0e-2)

/* Run one nonlinear_model() step at the given stage.  Expands in terms of
   the caller's x, y, n, params_fit, nparams, chisq, lambda, f, singular
   and error_msg, plus the file-level work buffers; bails out through
   RETURN_ERROR_MSG if the data are singular.  The do { } while (0)
   wrapper makes "CHECK_NONLINEAR_MODEL(stage);" a single statement, so it
   is safe inside an unbraced if/else. */
#define  CHECK_NONLINEAR_MODEL(stage) \
	 do {   nonlinear_model(x, y, w, n, params_fit, covar, \
				alpha, beta, da, ap, dy_da, \
				piv, row, col, nparams, &chisq, \
				&lambda, f, stage, &singular); \
	     if (singular)  {   RETURN_ERROR_MSG("singular data");   }   } while (0)

/* Work buffers for fit_log_linear(), cached between calls.
   nalloc_log_linear is the number of entries they were allocated for. */
static int nalloc_log_linear = 0;
static int *used_indices = NULL;   /* indices of y values used in the fit */
static float *yy = NULL;           /* log-transformed y values */

/* Work buffers passed to nonlinear_model() by fit_nonlinear(), cached
   between calls.  w has nalloc_nonlinear entries; the rest are sized by
   the number of model parameters (see alloc_nonlinear_memory()). */
static int nalloc_nonlinear = 0;
static float *w = NULL;            /* per-point weights, all set to 1 */
static float **covar = NULL;
static float **alpha = NULL;
static float *beta = NULL;
static float *da = NULL;
static float *ap = NULL;
static float *dy_da = NULL;
static int *piv = NULL;
static int *row = NULL;
static int *col = NULL;

/* Release both fit_log_linear() work buffers (yy and used_indices). */
static void free_log_linear_memory()
{
    FREE(yy, float);
    FREE(used_indices, int);
}

/*
 * Make sure the fit_log_linear() work buffers (used_indices, yy) hold at
 * least n entries.  The buffers are cached and only reallocated when a
 * larger size than the current one is requested.
 */
static Status alloc_log_linear_memory(int n, String error_msg)
{
    sprintf(error_msg, "allocating log linear memory");

    /* current buffers are already big enough (n == nalloc included) */
    if ((nalloc_log_linear > 0) && (n <= nalloc_log_linear))
	return  OK;

    free_log_linear_memory();

    /* forget the old size now: if an allocation below fails, a stale
       size must not let a later call believe the (freed) buffers exist */
    nalloc_log_linear = 0;

    MALLOC(used_indices, int, n);
    MALLOC(yy, float, n);

    nalloc_log_linear = n;

    return  OK;
}

/* Release every fit_nonlinear() work buffer, including the covar and
   alpha matrices (allocated with 3 rows in alloc_nonlinear_memory()). */
static void free_nonlinear_memory()
{
    int nparams = 3;   /* must match the row count used when allocating */

    FREE(col, int);
    FREE(row, int);
    FREE(piv, int);

    FREE2(alpha, float, nparams);
    FREE2(covar, float, nparams);
    FREE(dy_da, float);
    FREE(ap, float);
    FREE(da, float);
    FREE(beta, float);

    FREE(w, float);
}

/*
 * Make sure the fit_nonlinear() work buffers can handle n data points.
 * w gets n entries, all set to 1 (unweighted fit).  The parameter-sized
 * buffers are allocated for 3 parameters, the most any method uses.
 * The buffers are cached and only reallocated when n exceeds the
 * current size.
 */
static Status alloc_nonlinear_memory(int n, String error_msg)
{
    static int nparams = 3;
    int i;

    sprintf(error_msg, "allocating nonlinear memory");

    /* current buffers are already big enough (n == nalloc included) */
    if ((nalloc_nonlinear > 0) && (n <= nalloc_nonlinear))
	return  OK;

    free_nonlinear_memory();

    /* forget the old size now: if an allocation below fails, a stale
       size must not let a later call believe the (freed) buffers exist */
    nalloc_nonlinear = 0;

    MALLOC(w, float, n);
    for (i = 0; i < n; i++)
	w[i] = 1; /* weight */

    MALLOC(beta, float, nparams);
    MALLOC(da, float, nparams);
    MALLOC(ap, float, nparams);
    MALLOC(dy_da, float, nparams);
    MALLOC2(covar, float, nparams, nparams);
    MALLOC2(alpha, float, nparams, nparams);

    MALLOC(piv, int, nparams);
    MALLOC(row, int, nparams);
    MALLOC(col, int, nparams);

    nalloc_nonlinear = n;

    return  OK;
}

/*
 * Unweighted straight-line fit of y against x using line_fit().
 * line_fit()'s b is stored in params_fit[0] and its a in params_fit[1].
 * This matches calculate_fit(), which evaluates a LINEAR_FIT as
 * params_fit[0]*x + params_fit[1]; in other words, b is the slope and a
 * the intercept.
 * y_fit receives line_fit()'s fitted values; *chisq is line_fit()'s
 * goodness value divided by n.
 */
static Status fit_linear(int n, float *x, float *y,
	 float *params_fit, float *y_fit, float *chisq, String error_msg)
{
    float *no_sigma = NULL;   /* no per-point errors: unweighted fit */
    float intercept, slope, std_intercept, std_slope, correlation, goodness;

    line_fit(n, x, y, no_sigma, y_fit, &intercept, &slope,
		&std_intercept, &std_slope, &correlation, &goodness);

    *chisq = goodness / n;
    params_fit[1] = intercept;
    params_fit[0] = slope;

    return  OK;
}

/*
 * Fit y = A * exp(-B * x) with a straight-line fit to log(s * y).  Here s
 * is the sign of the y value of largest magnitude.  Points with
 * s * y <= 0 cannot be log-transformed and are left out of the fit.
 *
 * On success:
 *   params_fit[0] = A (carries the sign s), params_fit[1] = B;
 *   y_fit[i]      = fitted value at x[i] for points used in the fit,
 *                   0 (arbitrary) for points that were left out;
 *   *chisq        = mean squared residual (y - y_fit) over the points used.
 * Errors: all y values zero, or fewer than 2 usable points.
 */
static Status fit_log_linear(int n, float *x, float *y,
	 float *params_fit, float *y_fit, float *chisq, String error_msg)
{
    int i, j, imax, nused;
    float ymax, s, sy, d, sum_d2;
    float *sigma = NULL, *xx = NULL, *x_used;
    float a, b, std_a, std_b, corr_ab, goodness;

    CHECK_STATUS(alloc_log_linear_memory(n, error_msg));

    /* the sign of the largest-magnitude point decides which points
       can be log-transformed */
    imax = 0;
    ymax = ABS(y[0]);

    for (i = 1; i < n; i++)
    {
	if (ABS(y[i]) > ymax)
	{
	    imax = i;
	    ymax = ABS(y[i]);
	}
    }

    if (ymax == 0)
	RETURN_ERROR_MSG("y's all zero");

    s = (y[imax] > 0) ? 1 : -1;

    /* collect log(s*y) for the points with the same sign as y[imax] */
    nused = 0;
    for (i = 0; i < n; i++)
    {
	sy = s * y[i];

	if (sy > 0)
	{
	    yy[nused] = (float) log((double) sy);
	    used_indices[nused] = i;
	    nused++;
	}
    }

    if (nused < 2)
	RETURN_ERROR_MSG("not enough valid data points");

    /* line_fit() needs the x values that belong to yy.  When points were
       dropped, x must be compacted the same way.  Passing x itself would
       pair yy[i] with x[i] instead of x[used_indices[i]]. */
    if (nused < n)
    {
	MALLOC(xx, float, nused);

	for (i = 0; i < nused; i++)
	    xx[i] = x[used_indices[i]];

	x_used = xx;
    }
    else
    {
	x_used = x;
    }

    line_fit(nused, x_used, yy, sigma, y_fit, &a, &b, &std_a, &std_b,
						&corr_ab, &goodness);

    FREE(xx, float);

    /* log(s*y) = a + b*x  <=>  y = (s*exp(a)) * exp(-(-b)*x) */
    params_fit[0] = s * (float) exp((double) a);
    params_fit[1] = - b;

    /* y_fit[0..nused-1] holds the fitted log(s*y).  Undo the transform,
       including the sign s, and accumulate the squared residuals. */
    sum_d2 = 0;
    for (i = 0; i < nused; i++)
    {
	y_fit[i] = s * (float) exp((double) y_fit[i]);
	d = y[used_indices[i]] - y_fit[i];
	sum_d2 += d * d;
    }

    *chisq = sum_d2 / nused;

    /* Move the fitted values back to their original positions and set
       every point left out of the fit to 0 (arbitrary).  This includes
       points before the first and after the last used index.  Working
       downwards is safe because used_indices[i] >= i, so y_fit[i] is read
       before anything overwrites it. */
    if (nused < n)
    {
	i = nused - 1;

	for (j = n-1; j >= 0; j--)
	{
	    if ((i >= 0) && (used_indices[i] == j))
	    {
		y_fit[j] = y_fit[i];
		i--;
	    }
	    else
	    {
		y_fit[j] = 0;
	    }
	}
    }

    return  OK;
}

static void nonlinear_func(float x, int nparams, float *params, float *y, float *dy_da)
{
    float a = params[0], b = params[1], c = 0;
    float s = (float) exp((float) (-b*x));
    float t = a * s;

    if (nparams == 3)
    {
	c = params[2];
	dy_da[2] = 1;
    }

    *y = t + c;
    dy_da[0] = s;
    dy_da[1] = - x * t;
}

static void nonlinear2_func(float x, float *params, float *y, float *dy_da)
{
    nonlinear_func(x, 2, params, y, dy_da);
}

static void nonlinear3_func(float x, float *params, float *y, float *dy_da)
{
    nonlinear_func(x, 3, params, y, dy_da);
}

/*
 * Starting guess for the nonlinear fit y = a*exp(-b*x) (+ c).
 * A greedy scan picks the two points of largest |y|. A point is skipped
 * when its x lies within MIN_DELTA_X of one already chosen. The guess is
 * the exponential through those two points.  If there is no usable
 * second point, b starts at 0 and a at the y value of the largest point.
 * For NONLINEAR3_FIT, c starts at 0.
 */
static void init_params_fit(int method, int n,
			float *x, float *y, float *params_fit)
{
    int i, n1, n2;
    float y1, y2, abs_y;

    /* n1/y1: index and |y| of the largest point so far; n2/y2: runner-up */
    n1 = n2 = -1;
    y1 = y2 = -1;

    for (i = 0; i < n; i++)
    {
	if (n1 >= 0)
	{
	    if (ABS(x[i] - x[n1]) < MIN_DELTA_X)
		continue;
	}

	if (n2 >= 0)
	{
	    if (ABS(x[i] - x[n2]) < MIN_DELTA_X)
		continue;
	}

	abs_y = ABS(y[i]);

	if (abs_y > y1)
	{
	    n2 = n1;
	    y2 = y1;
	    n1 = i;
	    y1 = abs_y;
	}
	else if (abs_y > y2)
	{
	    n2 = i;
	    y2 = abs_y;
	}
    }

    if ((n2 < 0) || (y2 <= 0))
    {
	/* A decay rate cannot be estimated from one point, or from a second
	   point with y == 0, because log(y1/y2) would be infinite or NaN.
	   Keep the sign of the data, as the two-point branch does. */
	params_fit[0] = (n1 >= 0) ? y[n1] : 0;
	params_fit[1] = 0;
    }
    else
    {
	/* |y1|/|y2| = exp(b * (x[n2] - x[n1])).  The amplitude takes the sign
	   of y[n1]; a sign change between the two points is ignored
	   (unlikely in practice). */
	params_fit[1] = (float) log((double) (y1/y2)) / (x[n2] - x[n1]);
	params_fit[0] = y[n1] / ((float) exp((double) (-params_fit[1]*x[n1])));
    }

    if (method == NONLINEAR3_FIT)
	params_fit[2] = 0;
}

/*
 * Iterative fit of y = a*exp(-b*x) (NONLINEAR2_FIT) or
 * y = a*exp(-b*x) + c (NONLINEAR3_FIT) via nonlinear_model().  The
 * starting values come from init_params_fit().  Iteration stops after
 * MAX_CONDITION small-improvement steps (converged) or after
 * MAX_MODEL_ITER steps.  The second case is reported as an error.
 * On success params_fit holds the fitted parameters, y_fit the fitted
 * curve at each x, and *p_chisq the final chisq from nonlinear_model().
 */
static Status fit_nonlinear(int method, int n, float *x, float *y,
	 float *params_fit, float *y_fit, float *p_chisq, String error_msg)
{
    int i, cond, iter, nparams = get_method_nparams(method);
    float lambda, chisq, old_chisq, chisq_stop_criterion, chisq_scale;
    Bool singular;
    Nonlinear_model_func f;

    CHECK_STATUS(alloc_nonlinear_memory(n, error_msg));

    init_params_fit(method, n, x, y, params_fit);

    if (method == NONLINEAR2_FIT)
	f = nonlinear2_func;
    else
	f = nonlinear3_func;

    chisq_scale = 100.0 / n;
    chisq_stop_criterion = chisq_scale * CHISQ_STOP_CRITERION;

    CHECK_NONLINEAR_MODEL(INITIAL_STAGE);

    for (iter = cond = 0; (iter < MAX_MODEL_ITER) &&
				(cond < MAX_CONDITION); iter++)
    {
	old_chisq = chisq;

	CHECK_NONLINEAR_MODEL(GENERAL_STAGE);

	if (chisq > old_chisq)
	    cond = 0;
	else if ((old_chisq - chisq) < chisq_stop_criterion)
	    cond++;
    }

    /* Test the convergence counter, not the iteration count.  The stopping
       criterion can be met on the final permitted iteration, and then
       iter == MAX_MODEL_ITER even though the fit converged. */
    if (cond < MAX_CONDITION)
	RETURN_ERROR_MSG("fit did not converge");

    for (i = 0; i < n; i++)
	y_fit[i] = calculate_fit(method, x[i], params_fit);

    *p_chisq = chisq;

    return  OK;
}

/*
 * Fit the ndata points (x, y) with the given method.  The results go to
 * params_fit (get_method_nparams(method) entries), y_fit (ndata entries)
 * and *chisq.
 * NO_FIT, or fewer points than the method has parameters, returns OK
 * without touching the outputs.  A method outside [0, NFIT_METHODS) is
 * an error.
 */
Status fit_data(int method, int ndata, float *x, float *y,
	 float *params_fit, float *y_fit, float *chisq, String error_msg)
{
    /* previously such methods fell through to fit_nonlinear() */
    if ((method < 0) || (method >= NFIT_METHODS))
	RETURN_ERROR_MSG("unknown fit method");

    /* NO_FIT has zero parameters and must not reach fit_nonlinear() */
    if (method == NO_FIT)
	return  OK;

    if (ndata < get_method_nparams(method))
	return  OK;

    if (method == LINEAR_FIT)
	return fit_linear(ndata, x, y, params_fit, y_fit, chisq, error_msg);
    else if (method == LOG_LINEAR_FIT)
	return fit_log_linear(ndata, x, y, params_fit, y_fit, chisq, error_msg);
    else
	return fit_nonlinear(method, ndata, x, y, params_fit, y_fit,
							chisq, error_msg);
}

int get_method_nparams(int method)
{
    if ((method < 0) || (method >= NFIT_METHODS))
	return 0;

    if (method == NO_FIT)
	return 0;
    else if (method == NONLINEAR3_FIT)
	return 3;
    else
	return 2;
}

/*
 * Evaluate the fitted curve at x:
 *   LINEAR_FIT:      params_fit[0]*x + params_fit[1]
 *   NONLINEAR3_FIT:  params_fit[0]*exp(-params_fit[1]*x) + params_fit[2]
 *   any other:       params_fit[0]*exp(-params_fit[1]*x)
 * The exponential forms are computed in double and rounded to float.
 */
float calculate_fit(int method, float x, float *params_fit)
{
    float p0 = params_fit[0];
    float p1 = params_fit[1];
    float offset;
    double value;

    if (method == LINEAR_FIT)
	return p0*x + p1;

    offset = (method == NONLINEAR3_FIT) ? params_fit[2] : 0;
    value = p0 * exp((double) (-p1*x)) + offset;

    return (float) value;
}
