/*********************************************************************
    DTWK.c (version 0.1)                                          
                                                                    
    Computes the (log) global alignment kernel

    Copyright 2009 Marco Cuturi                                

    The corresponding paper for this code is

    M. Cuturi, J.-P. Vert, O. Birkenes, T. Matsui,

    "A kernel for time series based on global alignments",
    
    and a preprint is available at http://arxiv.org/abs/cs.CV/0610033

    DTWK is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    DTWK is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with DTWK; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

*********************************************************************/

#include <mex.h>
#include <math.h>
#include <stdlib.h>
/* Useful constants */
#define LOG0 -10000          /* log(0) */
#define LOGP(x,y) (((x)>(y))?(x)+log1p(exp((y)-(x))):(y)+log1p(exp((x)-(y))))


double DTWkernelcompute(double *seq1 ,double *seq2, int nX, int nY, int dimvect, double sigma)
  /* Implementation of the global alignment kernel, a kernel inspired by the DTW score */
  /* seq1 is a first sequence represented as a matrix of real elements. Each line i corresponds to the vector of observations at time i. */
  /* seq2 is the second sequence formatted in the same way.
  /* NOTE: seq1 and seq2 may have a different number of lines but need to have the same number of columns. */
  /* sigma stands for the bandwidth of the Gaussian kernel */

{
    register int
    i,j,ii,                /* loop indexes */
    cur, old,           /* to indicate the array to use (0 or 1) */
    curpos, frompos1,frompos2,frompos3;    /* position in an array */
    double aux , aux2;
    int cl;                /* length of a column for the dynamic programming */
    
    double sum=0;
    double grammat[nX*nY];
    double S_=1/(sigma*sigma);
  /*********************************************************/
  /* Computation of the Gram matrix with a Gaussian kernel */
  /*********************************************************/
    for (i=0;i<nX;i++) {
        for (j=0;j<nY;j++) {
            sum=0;
            for (ii=0;ii<dimvect;ii++) {
                sum+=(seq1[i+ii*nX]-seq2[j+ii*nY])*(seq1[i+ii*nX]-seq2[j+ii*nY]);
            }
            grammat[i*nY+j]=exp(-sum*S_);
        }
    }
  /* Initialization of the arrays */
  /* Each array stores two successive columns of the (nX+1)x(nY+1) table used in dynamic programming */
    cl = nY+1;           /* each column stores the positions in the aaY sequence, plus a position at zero */
    double logM[2*cl];
    
  /************************************************/
  /* First iteration : initialization of column 0 */
  /************************************************/
  /* The log=proabilities of each state are initialized for the first column (x=0,y=0..nY) */
    
    for (j=0;j<cl;j++) {
        logM[j]=LOG0;
    }
    logM[0]=0;
  /* Update column order */
    cur = 1;      /* Indexes [0..cl-1] are used to process the next column */
    old = 0;      /* Indexes [cl..2*cl-1] were used for column 0 */
    
    
    double logT=LOG0; /* LOG OF THE FINAL KERNEL VALUE.
  /*  mexPrintf("---------------1\n");
  /************************************************/
  /* Next iterations : processing columns 1 .. nX */
  /************************************************/
    
  /* Main loop to vary the position in aaX : i=1..nX */
    for (i=1;i<=nX;i++) {
        /* Special update for positions (i=1..nX,j=0) */
        curpos = cur*cl;                  /* index of the state (i,0) */
        logM[curpos] = LOG0;
        /* Secondary loop to vary the position in aaY : j=1..nY */
        for (j=1;j<=nY;j++) {
            curpos = cur*cl + j;            /* index of the state (i,j) */
            frompos1 = old*cl + j;            /* index of the state (i-1,j) */
            frompos2 = cur*cl + j-1;          /* index of the state (i,j-1) */
            frompos3 = old*cl + j-1;          /* index of the state (i-1,j-1) */
            
            /* Doing the updates, in two steps*/
            aux= LOGP (logM[frompos1],logM[frompos2] );
            logM[curpos] = LOGP( aux , logM[frompos3] ) + grammat[(i-1)*nY+j-1];
        } 
    /* Update the culumn order */
        cur = 1-cur;
        old = 1-old;
    }
  /* Return the logarithm of the kernel */
return logM[curpos];
}





/*
 * MEX gateway. Expects three right-hand-side arguments:
 *   prhs[0] - first time series, an n1 x d real double matrix (one observation per row)
 *   prhs[1] - second time series, an n2 x d real double matrix (same d)
 *   prhs[2] - sigma, a positive real scalar (bandwidth of the Gaussian local kernel)
 * and produces at most one output: the (log) kernel value returned by
 * DTWkernelcompute, as a double scalar.
 */
void mexFunction(int nlhs, mxArray *plhs[ ], int nrhs, const mxArray *prhs[ ])
{
    double *seq1, *seq2;
    int nX, nY, dimvect;
    double sigma, k;
    int arg;

    if (nrhs != 3) {
        mexErrMsgTxt("Three input arguments required. Two d-dimensional time series, that is two matrices n1 x d and n2 x d, and a real sigma.");
    }

    if (nlhs > 1) {
        mexErrMsgTxt("Too many output arguments.");
    }

    /* mxGetPr only exposes meaningful data for full, real double arrays, and the
       row/column reading below assumes plain 2-D matrices. */
    for (arg = 0; arg < 2; arg++) {
        if (!mxIsDouble(prhs[arg]) || mxIsComplex(prhs[arg]) || mxIsSparse(prhs[arg])
                || mxGetNumberOfDimensions(prhs[arg]) != 2) {
            mexErrMsgTxt("The two time series must be real, full (non-sparse) 2-D double matrices.");
        }
    }

    if (mxGetN(prhs[0]) != mxGetN(prhs[1])) {
        mexErrMsgTxt("The two input time series should describe time series that have the same dimension. They should have the same number of columns.");
    }

    seq1 = mxGetPr(prhs[0]);
    seq2 = mxGetPr(prhs[1]);
    nX = (int)mxGetM(prhs[0]);
    nY = (int)mxGetM(prhs[1]);
    dimvect = (int)mxGetN(prhs[0]);

    /* Make sure sigma really is a single real number before reading it. */
    if (!mxIsNumeric(prhs[2]) || mxIsComplex(prhs[2]) || mxGetNumberOfElements(prhs[2]) != 1) {
        mexErrMsgTxt("Sigma must be a real numeric scalar.");
    }
    sigma = mxGetScalar(prhs[2]);
    /* Written as !(sigma > 0) so that NaN is rejected too. */
    if (!(sigma > 0)) {
        mexErrMsgTxt("Sigma must be positive");
    }

    k = DTWkernelcompute(seq1, seq2, nX, nY, dimvect, sigma);
    plhs[0] = mxCreateDoubleScalar(k);
}
