/*
 ***** BEGIN LICENSE BLOCK *****
 * Version: MPL 1.1/GPL 2.0/LGPL 2.1
 *
 * The contents of this file are subject to the Mozilla Public License Version
 * 1.1 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 * http://www.mozilla.org/MPL/
 *
 * Software distributed under the License is distributed on an "AS IS" basis,
 * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
 * for the specific language governing rights and limitations under the
 * License.
 *
 * The Original Code is Global Alignment Kernel, (C) 2010, Marco Cuturi
 *
 * The Initial Developers of the Original Code is
 *
 * Marco Cuturi   mcuturi@i.kyoto-u.ac.jp
 *
 * Portions created by the Initial Developers are
 * Copyright (C) 2011 the Initial Developers. All Rights Reserved.
 *
 *
 * Alternatively, the contents of this file may be used under the terms of
 * either the GNU General Public License Version 2 or later (the "GPL"), or
 * the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
 * in which case the provisions of the GPL or the LGPL are applicable instead
 * of those above. If you wish to allow use of your version of this file only
 * under the terms of either the GPL or the LGPL, and not to allow others to
 * use your version of this file under the terms of the MPL, indicate your
 * decision by deleting the provisions above and replace them with the notice
 * and other provisions required by the GPL or the LGPL. If you do not delete
 * the provisions above, a recipient may use your version of this file under
 * the terms of any one of the MPL, the GPL or the LGPL.
 *
 ***** END LICENSE BLOCK *****
 *
 * REVISIONS:
 * This is v1.03 of Global Alignment Kernel, September 12th 2011.
 * Added a log1p function for Windows platforms. Enable the log1p definition found in the section marked "compile on Windows platforms" (just below the includes) if you are compiling this file with MATLAB on Windows.
 * 
 * Previous versions:
 * v1.02 Changed some C syntax that was not compiled properly on Windows platforms, June 8th
 * v1.01 of Global Alignment Kernel, May 12th 2011 (updated comments fields)
 * v1.0 of Global Alignment Kernel, March 25th 2011.

 */



#include <mex.h>
#include <math.h>
#include <stdlib.h>
/* Useful constants */
#define LOG0 -10000          /* log(0) */
#define LOGP(x, y) (((x)>(y))?(x)+log1p(exp((y)-(x))):(y)+log1p(exp((x)-(y))))




/*---------------  Uncomment this part if you want to compile on Windows platforms <BEGIN> ---------------  
 *------ ( log1p is not defined on windows compilers, so we redefined it here below ) --------------------/
double log1p (const double x)
{
volatile double y;
y = 1 + x;
return log(y)-((y-1)-x)/y ;} /* cancels errors with IEEE arithmetic */
/*---------------  Uncomment this part if you want to compile on Windows platforms <END> ---------------  */






double logGAK(double *seq1 , double *seq2, int nX, int nY, int dimvect, double sigma, int triangular)
/* Implementation of the (Triangular) global alignment kernel.
 *
 * See details about the matlab wrapper mexFunction below for more information on the inputs that need to be called from Matlab
 *
 * /* seq1 is a first sequence represented as a matrix of real elements. Each line i corresponds to the vector of observations at time i.
 * seq2 is the second sequence formatted in the same way.
 * nX, nY and dimvect provide the number of lines of seq1 and seq2.
 * sigma stands for the bandwidth of the \phi_\sigma distance used kernel
 * lambda is an additional factor that can be used with the Geometrically divisible Gaussian Kernel
 * triangular is a parameter which parameterizes the triangular kernel
 * kerneltype selects either the Gaussian Kernel or its geometrically divisible equivalent */

{
    int i, j, ii, cur, old, curpos, frompos1, frompos2, frompos3;    
    double aux , aux2;
    int cl = nY+1;                /* length of a column for the dynamic programming */
    
    
    double sum=0;
    double gram, Sig;    
    /* logM is the array that will stores two successive columns of the (nX+1) x (nY+1) table used to compute the final kernel value*/
    double * logM = malloc(2*cl * sizeof(double));        
    
    int trimax = (nX>nY) ? nX-1 : nY-1; /* Maximum of abs(i-j) when 1<=i<=nX and 1<=j<=nY */
    
    double *logTriangularCoefficients = malloc((trimax+1) * sizeof(double)); 
    if (triangular>0) {
        /* initialize */
        for (i=0;i<=trimax;i++){
            logTriangularCoefficients[i]=LOG0; /* Set all to zero */
        }
        
        for (i=0;i<((trimax<triangular) ? trimax+1 : triangular);i++) {
            logTriangularCoefficients[i]=log(1-i/triangular);
        }
    }
    else
        for (i=0;i<=trimax;i++){
        logTriangularCoefficients[i]=0; /* 1 for all if triangular==0, that is a log value of 0 */
        }
    Sig=-1/(2*sigma*sigma);
    
    
    
    /****************************************************/
    /* First iteration : initialization of columns to 0 */
    /****************************************************/
    /* The left most column is all zeros... */
    for (j=1;j<cl;j++) {
        logM[j]=LOG0;
    }
    /* ... except for the lower-left cell which is initialized with a value of 1, i.e. a log value of 0. */
    logM[0]=0;
    
    /* Cur and Old keep track of which column is the current one and which one is the already computed one.*/
    cur = 1;      /* Indexes [0..cl-1] are used to process the next column */
    old = 0;      /* Indexes [cl..2*cl-1] were used for column 0 */
    
    /************************************************/
    /* Next iterations : processing columns 1 .. nX */
    /************************************************/
    
    /* Main loop to vary the position for i=1..nX */
    for (i=1;i<=nX;i++) {
        /* Special update for positions (i=1..nX,j=0) */
        curpos = cur*cl;                  /* index of the state (i,0) */
        logM[curpos] = LOG0;
        /* Secondary loop to vary the position for j=1..nY */
        for (j=1;j<=nY;j++) {
            curpos = cur*cl + j;            /* index of the state (i,j) */
            if (logTriangularCoefficients[abs(i-j)]>LOG0) {
                frompos1 = old*cl + j;            /* index of the state (i-1,j) */
                frompos2 = cur*cl + j-1;          /* index of the state (i,j-1) */
                frompos3 = old*cl + j-1;          /* index of the state (i-1,j-1) */
                
                /* We first compute the kernel value */
                sum=0;
                for (ii=0;ii<dimvect;ii++) {
                    sum+=(seq1[i-1+ii*nX]-seq2[j-1+ii*nY])*(seq1[i-1+ii*nX]-seq2[j-1+ii*nY]);
                }
                gram= logTriangularCoefficients[abs(i-j)] + sum*Sig ;
                gram -=log(2-exp(gram));
                
                /* Doing the updates now, in two steps. */
                aux= LOGP(logM[frompos1], logM[frompos2] );
                logM[curpos] = LOGP( aux , logM[frompos3] ) + gram;
            }
            else {
                logM[curpos]=LOG0;
            }
        }
        /* Update the culumn order */
        cur = 1-cur;
        old = 1-old;
    }
    aux = logM[curpos];
    free(logM);
	free(logTriangularCoefficients);
    /* Return the logarithm of the Global Alignment Kernel */    
    return aux;
    
}





void mexFunction(int nlhs, mxArray *plhs[ ], int nrhs, const mxArray *prhs[ ]) {
    /* Matlab entry point: checks the arguments, then returns the logarithm of the
     * (triangular) global alignment kernel of two time series in plhs[0].
     *
     * Inputs are, in this order
     * A N1 x d matrix (d-variate time series with N1 observations)
     * A N2 x d matrix (d-variate time series with N2 observations)
     * A sigma >0 parameter for the Gaussian kernel's width
     * A triangular integer which parameterizes the band.
     *  when triangular = 0, the triangular kernel is not used, the evaluation is on the sum of all paths, and the complexity is of the order of N1 x N2
     *  when triangular > 0, the triangular kernel is used and downweights some of the paths that lie far from the diagonal.
     *
     * An error is raised through mexErrMsgTxt when an argument is rejected. */

    double *seq1, *seq2;
    int nX, nY, dimvect, triangular, k;
    double sigma;

    if (nrhs != 4) {
        mexErrMsgTxt("4 input arguments required. Two d-dimensional time series, that is two matrices n1 x d and n2 x d, a positive sigma kernel bandwidth, a triangular kernel band parameter (set to zero if one wants to sum over all paths) ");
    }

    if (nlhs > 1) {
        mexErrMsgTxt("Too many output arguments.");
    }

    /* The data are read through mxGetPr as a dense array of doubles: any other
     * storage class (single, integer, sparse, complex, cell, ...) must be rejected. */
    for (k=0;k<2;k++) {
        if (!mxIsDouble(prhs[k]) || mxIsComplex(prhs[k]) || mxIsSparse(prhs[k])) {
            mexErrMsgTxt("The two input time series must be real, full matrices of class double.");
        }
    }
    /* mxGetScalar needs a non-empty numeric (or logical) array. */
    for (k=2;k<4;k++) {
        if (mxIsEmpty(prhs[k]) || !(mxIsNumeric(prhs[k]) || mxIsLogical(prhs[k]))) {
            mexErrMsgTxt("Sigma and the triangular parameter must be non-empty numeric values.");
        }
    }

    seq1= (double *)mxGetPr(prhs[0]);
    seq2= (double *)mxGetPr(prhs[1]);
    nX = (int)(mxGetDimensions(prhs[0]))[0];
    nY = (int)(mxGetDimensions(prhs[1]))[0];
    dimvect = (int)(mxGetDimensions(prhs[0]))[1];

    if (nX<1 || nY<1) {
        mexErrMsgTxt("The two input time series must contain at least one observation each.");
    }

    if (dimvect != (int)(mxGetDimensions(prhs[1]))[1] ) {
        mexErrMsgTxt("The two input time series should describe time series that have the same dimension. They should have the same number of columns.");
    }
    sigma = mxGetScalar(prhs[2]);
    /* Written as !(sigma>0) so that a NaN bandwidth is rejected too. */
    if (!(sigma>0)) {
        mexErrMsgTxt("Sigma must be positive");
    }
    triangular = (int)mxGetScalar(prhs[3]);
    if (triangular<0) {
        mexErrMsgTxt("Triangular parameter must be non-negative");
    }
    /* If triangular is smaller than the difference in length of the time series, the kernel is equal to zero, i.e. its log is set to -100000 */
    if ((triangular>0)&&(abs(nX-nY)>triangular)) {
        plhs[0]= mxCreateDoubleScalar(-100000);
    }
    /* otherwise we call the C code above. */
    else {
        plhs[0]= mxCreateDoubleScalar( logGAK(seq1, seq2, nX, nY, dimvect, sigma, triangular) );
    }
}