/* Copyright (c) 1991-2002 Doshita Lab. Speech Group, Kyoto University */
/* Copyright (c) 2000-2002 Speech and Acoustics Processing Lab., NAIST */
/*   All rights reserved   */

/* outprob.c --- top-level routines for computing output probabilities */

/* $Id: outprob.c,v 1.5 2002/09/11 22:01:50 ri Exp $ */

#include <sent/stddefs.h>
#include <sent/speech.h>
#include <sent/htk_hmm.h>
#include <sent/htk_param.h>
#include <sent/hmm.h>
#include <sent/gprune.h>
#include "globalvars.h"


/* Output probability cache, shared by the functions in this file. */
static int statenum;		/* total number of HMM states (width of one cache row) */
static LOGPROB **outprob_cache = NULL; /* outprob cache [t][stateid]; NULL until first allocation */
static int allocframenum;	/* number of frames (rows) currently allocated */
static int allocblock;		/* minimum number of frames added per allocation */
static LOGPROB *last_cache;	/* row outprob_cache[OP_time], kept to avoid 2-D indexing */
#define LOG_UNDEF (LOG_ZERO - 1) /* marks a cache entry that has not been computed yet */

/* cache initialization functions */
/* outprob_cache_init(): call once on startup */
/* outprob_cache_prepare(): call before every input (the number of frames may vary) */
/* outprob_cache_init: one-time setup of the output probability cache.
 * Takes the state count from OP_hmminfo and leaves the cache unallocated;
 * rows are allocated lazily by outprob_cache_extend().  The current frame
 * is invalidated so the first outprob_state() call loads its vector.
 * Always returns TRUE. */
boolean
outprob_cache_init()
{
  allocblock = OUTPROB_CACHE_PERIOD;
  allocframenum = 0;
  outprob_cache = NULL;
  statenum = OP_hmminfo->totalstatenum;
  OP_time = -1;			/* no frame loaded yet */
  return TRUE;
}
boolean
outprob_cache_prepare()
{
  int i, size;
  int s,t;

  /* clear already allocated area */
  for (t = 0; t < allocframenum; t++) {
    for (s = 0; s < statenum; s++) {
      outprob_cache[t][s] = LOG_UNDEF;
    }
  }
  
  return TRUE;
}

/* outprob_cache_extend: make sure cache row `reqframe` exists.
 * The row-pointer table grows to at least reqframe+1 rows, and by at least
 * `allocblock` rows per call so that reallocation is infrequent.  The new
 * rows share one freshly allocated block, and every new entry starts as
 * LOG_UNDEF.  Previously allocated row blocks are never moved or freed, so
 * pointers into existing rows remain valid. */
static void
outprob_cache_extend(int reqframe)
{
  int first = allocframenum;	/* index of the first row to add */
  int last;			/* one past the last row to add */
  int ncell;			/* number of LOGPROB cells in the new block */
  int row, st;
  LOGPROB *block, *cell;

  if (reqframe < first) return;	/* row already available */

  /* grow by at least one allocation block */
  last = (reqframe + 1 > first + allocblock) ? reqframe + 1 : first + allocblock;
  ncell = (last - first) * statenum;

  /* (re)allocate the table of row pointers */
  outprob_cache = (outprob_cache == NULL)
    ? (LOGPROB **)mymalloc(sizeof(LOGPROB *) * last)
    : (LOGPROB **)myrealloc(outprob_cache, sizeof(LOGPROB *) * last);

  /* carve the new rows out of one contiguous block */
  block = (LOGPROB *)mymalloc(sizeof(LOGPROB) * ncell);
  cell = block;
  for (row = first; row < last; row++) {
    outprob_cache[row] = cell;
    for (st = 0; st < statenum; st++) {
      *cell++ = LOG_UNDEF;
    }
  }

  allocframenum = last;
}


/* outprob_state: output log probability of one HMM state at frame t.
 *
 * The state, its id and the parameter set are published through the OP_*
 * globals, which calc_outprob_state() reads.  When t differs from the
 * frame of the previous call, OP_last_time, OP_time, OP_vec and OP_veclen
 * are updated, the cache is grown to cover t, and last_cache is pointed at
 * row t.  A cached value is returned when present; otherwise
 * calc_outprob_state() computes it and the result is stored in the cache.
 */
LOGPROB
outprob_state(
     int t,			/* time frame */
     HTK_HMM_State *stateinfo,	/* state info */
     HTK_Param *param)		/* parameter */
{
  LOGPROB *slot;

  OP_state = stateinfo;
  OP_state_id = stateinfo->id;
  OP_param = param;

  if (t != OP_time) {		/* switched to a different frame */
    OP_last_time = OP_time;
    OP_time = t;
    OP_vec = param->parvec[t];
    OP_veclen = param->veclen;
    outprob_cache_extend(t);
    last_cache = outprob_cache[t];
  }

  slot = &last_cache[OP_state_id];
  if (*slot == LOG_UNDEF) {
    *slot = calc_outprob_state();
  }
  return *slot;
}

/* outprob_cd_max: maximum outprob over the states of a CD state set at
 * frame t.  Returns LOG_ZERO when the set is empty or when no state scores
 * above LOG_ZERO. */
static LOGPROB
outprob_cd_max(int t, CD_State_Set *lset, HTK_Param *param)
{
  LOGPROB best = LOG_ZERO;
  LOGPROB cur;
  int k;

  for (k = 0; k < lset->num; k++) {
    cur = outprob_state(t, lset->s[k], param);
    best = (cur > best) ? cur : best;
  }
  return best;
}

/* outprob_cd_avg: average outprob over the states of a CD state set at
 * frame t.  States whose outprob is not above LOG_ZERO are left out of the
 * average.  If no state qualifies (including an empty set), LOG_ZERO is
 * returned. */
static LOGPROB
outprob_cd_avg(int t, CD_State_Set *lset, HTK_Param *param)
{
  LOGPROB sum, p;
  int i, j;

  sum = 0.0;
  j = 0;
  for (i = 0; i < lset->num; i++) {
    p = outprob_state(t, lset->s[i], param);
    if (p > LOG_ZERO) {
      sum += p;
      j++;
    }
  }
  /* With no contributing state, sum / j would be 0.0 / 0.0 (NaN), which
     would contaminate every score it is later added to. */
  if (j == 0) return LOG_ZERO;
  return (sum / (LOGPROB)j);
}

/* outprob_cd: outprob of a CD state set at frame t.  Uses the average over
 * the set when OP_hmminfo->prefer_cdset_avg is set, and the maximum
 * otherwise. */
LOGPROB
outprob_cd(int t, CD_State_Set *lset, HTK_Param *param)
{
  return OP_hmminfo->prefer_cdset_avg
    ? outprob_cd_avg(t, lset, param)
    : outprob_cd_max(t, lset, param);
}
  


/* outprob: generic output probability of an HMM_STATE at frame t.
 * An ordinary state holds a single HTK state (out.state) and is scored
 * with outprob_state().  A pseudo state holds a CD state set (out.cdset)
 * and is scored with outprob_cd(). */
LOGPROB
outprob(int t, HMM_STATE *hmmstate, HTK_Param *param)
{
  if (!hmmstate->is_pseudo_state) {
    return outprob_state(t, hmmstate->out.state, param);
  }
  return outprob_cd(t, hmmstate->out.cdset, param);
}
