#include "find.h"

#include "data.h"
#include "list.h"
#include "ref.h"
#include "sorts.h"

/* curvature threshold used by parabolic_fit(): a dimension whose
   half second-difference is not above this is treated as singular  */
#define  SMALL_VALUE		1.0e-4

/* upper limit on the per-dimension shift (in points) that parabolic_fit()
   may apply to a center; kept below 0.5, presumably so that the fitted
   center still rounds to the integer position it started from
   (TODO confirm); an alternative value is kept commented out below  */
/*
#define  MAX_SHIFT		0.4
*/
#define  MAX_SHIFT		0.499

/* number of dimensions and points per dimension (from Size_info)  */
static int ndim;
static int *npoints;

/* per-dimension referencing, passed to convert_from_point()  */
static Ref_info *ref_info;

/* search parameters, copied from Find_param by find_peaks()  */
static Bool have_high;	/* look for values above 'high'  */
static Bool have_low;	/* look for values below 'low'  */
static float high;
static float low;
static Bool nonadjacent;	/* compare with all 3^ndim - 1 neighbours, not just along axes  */
static Bool parabolic;	/* refine centers with parabolic_fit()  */
static int *buffer;	/* per-dimension distance used by check_buffer()  */
static int *first;	/* per-dimension start of the searched region  */
static int *last;	/* per-dimension end of region; last - first points are analysed  */
static Bool *periodic;	/* per-dimension flag: out-of-range points wrap around  */
static int nexclude;	/* number of entries in 'exclude'  */
static Exclude *exclude;

/* finds are collected in a list, then moved into an array for sorting  */
static List find_list;
static Find_info **find_info;

/* running counts of finds: total, above 'high', below 'low'  */
static int nfind;
static int nhighs;
static int nlows;

/* state used while sorting: next free slot in find_info and the
   reference value that cmp_values() measures distances from  */
static int nsetup;
static float base_value;

/* sizes computed by init_arrays(): number of points to analyse,
   number of points in the 3^ndim neighbourhood and the index of
   its middle point  */
static int analysed_points;
static int nonadjacent_points;
static int nonadjacent_point_zero;

/* working coordinates and the cumulative products used with find_point()  */
static int point[MAX_NDIM];
static int point2[MAX_NDIM];
static int cum_analysed_points[MAX_NDIM];
static int cum_nonadjacent_points[MAX_NDIM];

/* comparison in force for the current point: cmp_high or cmp_low  */
static Bool (*compare_func)(float w, float v);

static void init_arrays()
{
    int i;

    analysed_points = 1;
    for (i = 0; i < ndim; i++)
    {
	cum_analysed_points[i] = analysed_points;
	analysed_points *= last[i] - first[i];
    }

    nonadjacent_points = 1;
    for (i = 0; i < ndim; i++)
    {
	cum_nonadjacent_points[i] = nonadjacent_points;
	nonadjacent_points *= 3;
    }

    nonadjacent_point_zero = nonadjacent_points / 2;
}

/*  Initialise the list that collects finds.  The error message is
    written before the call so that it is already in place if
    CHECK_STATUS returns early.  Returns OK on success.  */
static Status alloc_find_memory(String error_msg)
{
    sprintf(error_msg, "initializing find list");
    CHECK_STATUS(init_list(&find_list));

    return  OK;
}

/*  Check coordinate n of p against [0, npoints[n]).  An in-range
    coordinate is accepted unchanged.  An out-of-range coordinate is
    rejected (FALSE) for a non-periodic dimension; for a periodic one it
    is shifted by one period, in place, and accepted (TRUE).  */
static Bool valid_nth_point(int n, int *p)
{
    int size = npoints[n];

    if ((p[n] >= 0) && (p[n] < size))
	return  TRUE;

    if (!periodic[n])
	return  FALSE;

    /* wrap by a single period only, as a neighbour is one step away */
    if (p[n] < 0)
	p[n] += size;
    else
	p[n] -= size;

    return  TRUE;
}

static Bool valid_point(int *p)
{
    int i;

    for (i = 0; i < ndim; i++)
    {
	if (!valid_nth_point(i, p))
	    return  FALSE;
    }

    return  TRUE;
}

/*  Comparison used for highs: TRUE if w is strictly greater than v.  */
static Bool cmp_high(float w, float v)
{
    return  (w > v) ? TRUE : FALSE;
}

/*  Comparison used for lows: TRUE if w is strictly less than v.  */
static Bool cmp_low(float w, float v)
{
    return  (w < v) ? TRUE : FALSE;
}

void print_point(float v, int *p)
{
    int i;

    printf("v = %2.1f, p = ", v);
    for (i = 0; i < ndim; i++)
	printf(" %3d", p[i]);

    printf("\n");
}

/*  Note that the two routines below do not need to reset array 'point'
    to its original value if 'flag' is FALSE, because find_extremes()
    recomputes 'point' before examining the next position.  */

/*  Decide whether the value v at 'point' is an extreme relative to every
    point of the surrounding 3^ndim cube (diagonals included).
    'point' is first moved to the lowest corner of the cube, then each
    cube index except the middle one is converted to coordinates in
    'point2'.  Neighbours rejected by valid_point() are ignored.
    *flag is set to FALSE as soon as compare_func finds a neighbour that
    beats v (in which case 'point' is left at the corner), otherwise to
    TRUE after 'point' has been restored.  Returns OK unless reading a
    data value fails.  */
static Status check_nonadjacent(float v, Bool *flag, String error_msg)
{
    int dim, n;
    float w;

    for (dim = 0; dim < ndim; dim++)  /* step back to the cube corner */
	--point[dim];

    for (n = 0; n < nonadjacent_points; n++)
    {
	if (n != nonadjacent_point_zero)  /* middle is the point itself */
	{
	    find_point(ndim, n, point2, cum_nonadjacent_points, point,
							npoints, FALSE);

	    if (!valid_point(point2))
		continue;

	    CHECK_STATUS(data_value(&w, point2, error_msg));

	    if ((*compare_func)(w, v))
	    {
		*flag = FALSE;
		return  OK;
	    }
	}
    }

    for (dim = 0; dim < ndim; dim++)  /* back to the original point */
	++point[dim];

    *flag = TRUE;

    return  OK;
}

/*  Decide whether the value v at 'point' is an extreme relative to its
    two neighbours along each axis (no diagonals).  For every dimension
    the +1 neighbour is examined first, then the -1 neighbour; neighbours
    rejected by valid_nth_point() are ignored.
    *flag is set to FALSE as soon as compare_func finds a neighbour that
    beats v (in which case 'point' is left modified), otherwise to TRUE
    with 'point' unchanged.  Returns OK unless reading a data value
    fails.  */
static Status check_adjacent(float v, Bool *flag, String error_msg)
{
    static const int step[2] = { 1, -1 };
    int dim, s, centre;
    float w;

    for (dim = 0; dim < ndim; dim++)
    {
	centre = point[dim];

	for (s = 0; s < 2; s++)
	{
	    point[dim] = centre + step[s];

	    if (!valid_nth_point(dim, point))
		continue;

	    CHECK_STATUS(data_value(&w, point, error_msg));

	    if ((*compare_func)(w, v))
	    {
		*flag = FALSE;
		return  OK;
	    }
	}

	point[dim] = centre;
    }

    *flag = TRUE;

    return  OK;
}

/*  Refine the center of find f, located at 'point', by fitting a
    parabola independently along each dimension through the value v at
    the point and the values u (at -1) and w (at +1) of its neighbours.

    For each dimension:
      - the center starts at the integer position;
      - if either neighbour is rejected by valid_nth_point() the
	dimension is left at the integer position and does not
	contribute to the magnitude;
      - otherwise d = |2v - u - w| / 2 measures the curvature; if it
	exceeds SMALL_VALUE the center is shifted by
	|w - u| / (4 d), capped at MAX_SHIFT, towards the neighbour
	that is larger (v > 0) or smaller (v <= 0), and the running
	squared magnitude is multiplied by v / d;
      - if d is too small the fit is marked singular.

    f->magnitude is 0 for a singular fit, otherwise the square root of
    the absolute accumulated product, positive for a high and negative
    for a low.

    'point' is restored before moving on to the next dimension, so every
    data_value() call sees valid coordinates in the other dimensions.
    Returns OK unless reading a data value fails.  */
static Status parabolic_fit(Find_info *f, String error_msg)
{
    int i, p;
    float u, w, d, v;
    double magn2;
    Bool singular;

    v = f->value;
    magn2 = v * v;

    singular = FALSE;
    for (i = 0; i < ndim; i++)
    {
	f->center[i] = p = point[i];

	point[i] = p + 1;

	if (!valid_nth_point(i, point))
	{
	    point[i] = p;  /* do not leave an out-of-range coordinate behind */
	    continue;
	}

	CHECK_STATUS(data_value(&w, point, error_msg));

	point[i] = p - 1;

	if (!valid_nth_point(i, point))
	{
	    point[i] = p;  /* do not leave an out-of-range coordinate behind */
	    continue;
	}

	CHECK_STATUS(data_value(&u, point, error_msg));

	point[i] = p;  /* both neighbours read; restore before the fit */

	d = 0.5 * ABS(2*v - u - w);

	if (d > SMALL_VALUE)  /* true unless u = v = w (unusual) */
	{
	    magn2 *= v / d;

	    d = 0.25 * ABS(w - u) / d;
	    d = MIN(d, MAX_SHIFT); /* probably not needed */

	    /* shift towards the neighbour on the side of the extreme */
	    if (v > 0)
	    {
		if (w < u)
		    d = -d;
	    }
	    else
	    {
		if (w > u)
		    d = -d;
	    }

	    f->center[i] += d;
	}
	else
	{
	    singular = TRUE;
	}
    }

    if (singular)
    {
	f->magnitude = 0;
    }
    else
    {
	magn2 = ABS(magn2);
	magn2 = sqrt(magn2);

/*  same test as find_extremes() uses to classify a point as a high;
    'high' is only meaningful when have_high is set  */
	if (have_high && (v > high))  /* have high */
	    f->magnitude = magn2;
	else /* have low */
	    f->magnitude = - magn2;
    }

    return  OK;
}

/*  Create a Find_info for the value v at the current 'point' and add it
    to find_list.  The record stores the value, a copy of the integer
    position and a center, which is either the parabolic fit (when
    'parabolic' is set) or a copy of the position.
    error_msg is filled in before each step that can fail.
    NOTE(review): if a later MALLOC or insert_list() fails, memory
    allocated earlier in this routine appears not to be released --
    confirm against the MALLOC / CHECK_STATUS macro definitions.  */
static Status record_find(float v, String error_msg)
{
    Find_info *found;

    sprintf(error_msg, "allocating find memory");
    MALLOC(found, Find_info, 1);
    MALLOC(found->position, int, ndim);
    MALLOC(found->center, float, ndim);

    sprintf(error_msg, "inserting in find list");
    CHECK_STATUS(insert_list(&find_list, (Generic_ptr) found));

    found->value = v;
    COPY_VECTOR(found->position, point, ndim);

    if (!parabolic)
    {
	/* no fit requested: center is the integer position */
	COPY_VECTOR(found->center, found->position, ndim);

	return  OK;
    }

    CHECK_STATUS(parabolic_fit(found, error_msg));

    return  OK;
}

/*  Test whether the value v at the current 'point' is a local extreme,
    using the neighbourhood selected by 'nonadjacent'.  If it is, the
    find is recorded and both nfind and the caller's counter *n are
    incremented.  Returns OK unless a called routine fails.  */
static Status check_find(float v, int *n, String error_msg)
{
    Bool is_extreme;

    if (nonadjacent)
    {
	CHECK_STATUS(check_nonadjacent(v, &is_extreme, error_msg));
    }
    else
    {
	CHECK_STATUS(check_adjacent(v, &is_extreme, error_msg));
    }

    if (!is_extreme)
	return  OK;

    CHECK_STATUS(record_find(v, error_msg));

    ++nfind;
    ++(*n);

    return  OK;
}

/*  Convert the current point's coordinates in dimensions dim1 and dim2
    with convert_from_point() (REF_PPM) and report whether the two
    converted values differ by less than delta.  One is added to each
    coordinate before conversion, presumably because the conversion
    expects 1-based points (TODO confirm).  */
static Bool within_exclusion(int dim1, int dim2, float delta)
{
    float pos1, pos2;

    pos1 = point[dim1] + 1;
    pos2 = point[dim2] + 1;

    convert_from_point(REF_PPM, npoints[dim1], ref_info+dim1, &pos1);
    convert_from_point(REF_PPM, npoints[dim2], ref_info+dim2, &pos2);

    return  (ABS(pos1 - pos2) < delta) ? TRUE : FALSE;
}

static Bool exclude_point()
{
    int i;

    for (i = 0; i < nexclude; i++)
    {
	if (within_exclusion(exclude[i].dim1, exclude[i].dim2,
							exclude[i].delta))
	    return  TRUE;
    }

    return  FALSE;
}

/*  Walk every point of the analysed region.  Points matched by an
    exclusion rule are skipped.  A value above 'high' (when have_high)
    is tested as a high; failing that, a value below 'low' (when
    have_low) is tested as a low; any other value is ignored.
    Progress is printed roughly every 5% of the points.
    Returns OK unless reading data or recording a find fails.  */
static Status find_extremes(String error_msg)
{
    int n, report_every;
    int *counter;
    float v;

    /* report about 20 times, but at least on every point */
    report_every = analysed_points / 20;
    if (report_every < 1)
	report_every = 1;

    for (n = 0; n < analysed_points; n++)
    {
	if ((n % report_every) == 0)
	    printf("\t...finding extremes (%1.0f%% done)\n",
						(100.0*n)/analysed_points);

	find_point(ndim, n, point, cum_analysed_points, first, npoints, FALSE);

	if (exclude_point())
	    continue;

	CHECK_STATUS(data_value(&v, point, error_msg));

	/* the high test takes precedence over the low test */
	if (have_high && (v > high))
	{
	    compare_func = cmp_high;
	    counter = &nhighs;
	}
	else if (have_low && (v < low))
	{
	    compare_func = cmp_low;
	    counter = &nlows;
	}
	else
	{
	    continue;
	}

	CHECK_STATUS(check_find(v, counter, error_msg));
    }

    return  OK;
}

/*  Sort comparison for two Find_info records: orders them by the
    absolute distance of their value from base_value.  Returns -1 if
    the first is strictly closer, otherwise 1 (ties also give 1; 0 is
    never returned).  */
static int cmp_values(Generic_ptr p1, Generic_ptr p2)
{
    Find_info *a = (Find_info *) p1;
    Find_info *b = (Find_info *) p2;
    float dist_a, dist_b;

    dist_a = a->value - base_value;
    dist_a = ABS(dist_a);

    dist_b = b->value - base_value;
    dist_b = ABS(dist_b);

    return  (dist_a < dist_b) ? -1 : 1;
}

/*  Callback passed to destroy_list(): store list element p in the next
    free slot of find_info and advance nsetup.  */
static void setup_array(Generic_ptr p)
{
    find_info[nsetup] = (Find_info *) p;
    nsetup++;
}

/*  Move the finds from find_list into a newly allocated array
    find_info (nfind entries) and sort it with heap_sort() using
    cmp_values.  The reference value for the comparison is the midpoint
    of 'high' and 'low' when both thresholds are in use, otherwise 0.
    The list is destroyed in the process; its elements are handed to
    setup_array().  Returns OK unless the allocation fails.  */
static Status sort_extremes(String error_msg)
{
    sprintf(error_msg, "allocating sorting memory");

    MALLOC(find_info, Find_info *, nfind);

    base_value = (have_high && have_low) ? (high + low) / 2 : 0;

    nsetup = 0;
    destroy_list(&find_list, setup_array);

    heap_sort((Generic_ptr *) find_info, nfind, FALSE, cmp_values);

    return  OK;
}

static Bool cannot_ignore_buffer()
{
    int i;

    for (i = 0; i < ndim; i++)
    {
	if (buffer[i] > 0)
	    return  TRUE;
    }

    return  FALSE;
}

/*  Return TRUE if positions p1 and p2 in dimension dim are no more than
    buffer[dim] points apart, either directly or, for a periodic
    dimension, going the other way round through the boundary.  */
static Bool within_buffer(int dim, int p1, int p2)
{
    int lower = (p2 < p1) ? p2 : p1;
    int upper = (p2 < p1) ? p1 : p2;

    if ((upper - lower) <= buffer[dim])
	return  TRUE;

    /* distance measured across the periodic boundary */
    if (periodic[dim] && ((lower + npoints[dim] - upper) <= buffer[dim]))
	return  TRUE;

    return  FALSE;
}

/*  Return TRUE only if the integer positions of f1 and f2 are within
    the buffer distance of each other in every dimension.  */
static Bool within_buffers(Find_info *f1, Find_info *f2)
{
    int dim = 0;

    while (dim < ndim)
    {
	if (!within_buffer(dim, f1->position[dim], f2->position[dim]))
	    return  FALSE;

	dim++;
    }

    return  TRUE;
}

static void check_buffer()
{
    int i, j, k;

    nhighs = nlows = 0;

    k = 0;
    for (i = 0; i < nfind; i++)
    {
	for (j = 0; j < k; j++)  /* not very efficient */
	{
	    if (within_buffers(find_info[i], find_info[j]))
	    {
		FREE(find_info[i], Find_info);
		break;
	    }
	}

	if (j == k)
	{
	    if (find_info[i]->value > high)
		nhighs++;
	    else
		nlows++;

	    find_info[k++] = find_info[i];
	}
    }

    nfind = k;
}

/*  Run the search: initialise the find list, scan for extremes, and if
    any were found sort them and, when some dimension has a positive
    buffer, discard finds that are too close to another one.  The counts
    are printed after the scan and again after the buffer check.
    Returns OK unless a called routine fails.  */
static Status calculate_peaks(String error_msg)
{
    CHECK_STATUS(alloc_find_memory(error_msg));

    nfind = 0;
    nhighs = 0;
    nlows = 0;

    CHECK_STATUS(find_extremes(error_msg));

    printf("number of extremes = %d (highs = %d, lows = %d)\n",
						nfind, nhighs, nlows);

    if (nfind <= 0)
	return  OK;

    CHECK_STATUS(sort_extremes(error_msg));

    if (cannot_ignore_buffer())
    {
	check_buffer();

	printf("after buffer, number of extremes = %d (highs = %d, lows = %d)\n",
						nfind, nhighs, nlows);
    }

    return  OK;
}

/*  Entry point: find the extremes of the data described by size_info.

    size_info	number of dimensions and points per dimension
    ref		per-dimension referencing (used for the exclusion rules)
    find_param	thresholds, search region, buffers, periodicity, exclusions
		and the nonadjacent / parabolic options
    p_nfind	receives the number of finds
    p_find_info	receives the array of finds; NULL when none were found
    error_msg	filled in with a description if the search fails

    The parameters are copied into file-scope variables (arrays are
    referenced, not copied) for use by the routines above.
    Returns OK on success; on failure the outputs are not written.  */
Status find_peaks(Size_info *size_info, Ref_info *ref, Find_param *find_param,
		int *p_nfind, Find_info ***p_find_info, String error_msg)
{
    ndim = size_info->ndim;
    npoints = size_info->npoints;

    ref_info = ref;

    have_high = find_param->have_high;
    have_low = find_param->have_low;
    high = find_param->high;
    low = find_param->low;
    nonadjacent = find_param->nonadjacent;
    parabolic = find_param->parabolic;
    first = find_param->first;
    last = find_param->last;
    buffer = find_param->buffer;
    periodic = find_param->periodic;
    nexclude = find_param->nexclude;
    exclude = find_param->exclude;

    /* the array is only allocated when something is found, so make sure
       a pointer left over from an earlier call is never handed back  */
    find_info = NULL;

    init_arrays();

    CHECK_STATUS(calculate_peaks(error_msg));

    *p_nfind = nfind;
    *p_find_info = find_info;

    return  OK;
}
