#include "mex.h"
#include "mexutil.h"

/*
 *  z = matrix_product(x,y)
 *    = x * y
 *
 * Written by Masa-aki Sato 2009/07/01
 */

/*
 * Dense matrix product z = x * y on column-major buffers.
 *
 *   z(row,col) = sum over t of x(row,t) * y(t,col)
 *
 *   x : m x d  (input)
 *   y : d x n  (input)
 *   z : m x n  (output; every element is overwritten)
 *
 * Element (r,c) of a matrix with R rows lives at offset r + c*R.
 */
void matrix_product(double *x, double *y, double *z, int m, int n, int d)
{
	double *y_col = y;  /* first element of the current column of y */
	double *out   = z;  /* next element of z, visited in storage order */
	int col, row, t;

	for (col = 0; col < n; col++, y_col += d) {
		for (row = 0; row < m; row++, out++) {
			double *x_elem = x + row;  /* x(row, 0); stepping by m walks along the row */

			*out = 0;
			for (t = 0; t < d; t++, x_elem += m) {
				*out += *x_elem * y_col[t];
			}
		}
	}
}

/*
 * Matrix-vector product z = x * y on column-major storage.
 *
 *   z(row) = sum over t of x(row,t) * y(t)
 *
 *   x : m x d  (input)
 *   y : d x 1  (input)
 *   z : m x 1  (output; every element is overwritten)
 */
void matrix_vector(double *x, double *y, double *z, int m, int d)
{
	int row, t;

	for (row = 0; row < m; row++) {
		double *x_elem = x + row;  /* x(row, 0); stepping by m walks along the row */

		z[row] = 0;
		for (t = 0; t < d; t++, x_elem += m) {
			z[row] += *x_elem * y[t];
		}
	}
}
/*
 * Scale a vector by a scalar: z(j) = x(j) * y.
 *
 *   x : m x 1  (input)
 *   y : scalar
 *   z : m x 1  (output)
 */
void vector_scalar(double *x, double y, double *z, int m)
{
	int idx;

	for (idx = 0; idx < m; idx++) {
		z[idx] = x[idx] * y;
	}
}

/*
 * Element-wise sum of two buffers: z(i) = x(i) + y(i) for i = 0 .. n-1.
 *
 * x, y and z are treated as flat arrays of n doubles, so a matrix is
 * added by passing its total element count as n.
 */
void matrix_add(double *x, double *y, double *z, int n)
{
	int idx;

	for (idx = 0; idx < n; idx++) {
		z[idx] = x[idx] + y[idx];
	}
}
/*
 * Copy the first t doubles of x into y, in increasing index order.
 * Nothing is copied when t <= 0.
 */
void copy_data(double *x, double *y, int t)
{
	int remaining;

	for (remaining = t; remaining > 0; remaining--) {
		*y++ = *x++;
	}
}
/***
% [Y, Z] = online_filter_loop(A, B, C, D, X, Z)
%	A: M x M,
%	B: 1 x M,
%	C: M x 1,
%	D: 1 x 1
% X : Input signal   (N x T)
% Z : Internal state (N x M); the returned Z is the state after the last sample
% Y : Output signal  (N x T)
%
% For every column Xt of X, in order:
%	Yt = Z * C + Xt * D; % Output update
%	Z  = Z * A + Xt * B; % Internal state variable update
***/
/*
 * The gateway routine.
 *
 * Requires exactly six inputs and two outputs.  All inputs must be real,
 * full double arrays; A must be square (M x M), Z must be N x M, B and C
 * must hold M elements each and D must hold one element.  Any violation
 * is reported through mexErrMsgTxt, which does not return.
 */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[])
{
	double *x, *y, *z, *a, *b, *c, *d;
	double *xt, *yt, *zt, *z1, *z2, *z3, *z4;
	int    N, T, M, t, k;
	mxArray *z1_ptr, *z2_ptr, *z3_ptr, *z4_ptr;

	/*  Check for proper number of arguments. */
	if (nrhs != 6) {
		mexErrMsgTxt("Six inputs required.");
	}
	if (nlhs != 2) {
		mexErrMsgTxt("Two outputs required.");
	}

	/* mxGetPr is only meaningful for real, full double arrays. */
	for (k = 0; k < 6; k++) {
		if (!mxIsDouble(prhs[k]) || mxIsComplex(prhs[k]) || mxIsSparse(prhs[k])) {
			mexErrMsgTxt("All inputs must be real, full double arrays.");
		}
	}

	/*  Get the dimensions of the inputs and check that they agree. */
	M = (int) mxGetM(prhs[0]);
	if ((int) mxGetN(prhs[0]) != M) {
		mexErrMsgTxt("dimension mismatch.");
	}

	N = (int) mxGetM(prhs[4]);
	T = (int) mxGetN(prhs[4]);

	if ((int) mxGetM(prhs[5]) != N) {
		mexErrMsgTxt("dimension mismatch.");
	}
	if ((int) mxGetN(prhs[5]) != M) {
		mexErrMsgTxt("dimension mismatch.");
	}

	/* The loop below reads M elements from B and C and one from D;
	   reject anything else so those reads stay inside the arrays. */
	if (mxGetNumberOfElements(prhs[1]) != (size_t) M) {
		mexErrMsgTxt("dimension mismatch: B must have M elements.");
	}
	if (mxGetNumberOfElements(prhs[2]) != (size_t) M) {
		mexErrMsgTxt("dimension mismatch: C must have M elements.");
	}
	if (mxGetNumberOfElements(prhs[3]) != 1) {
		mexErrMsgTxt("dimension mismatch: D must be a scalar.");
	}

	/*  Create pointers to the input data. */
	a = mxGetPr(prhs[0]);
	b = mxGetPr(prhs[1]);
	c = mxGetPr(prhs[2]);
	d = mxGetPr(prhs[3]);
	x = mxGetPr(prhs[4]);
	z = mxGetPr(prhs[5]);

	/*  Create the outputs.  mxCreateDoubleMatrixE comes from mexutil.h;
		per the original note it creates an uninitialized matrix for
		speed.  Y is filled column by column in the loop and the output
		state is seeded from the input Z right below.
	*/
	plhs[0] = mxCreateDoubleMatrixE(N,T,mxREAL);
	y = mxGetPr(plhs[0]);

	plhs[1] = mxCreateDoubleMatrixE(N,M,mxREAL);
	zt = mxGetPr(plhs[1]);

	/*  Scratch buffers reused on every time step. */
	z1_ptr = mxCreateDoubleMatrixE(N,1,mxREAL);   /* Z * C  */
	z1 = mxGetPr(z1_ptr);
	z2_ptr = mxCreateDoubleMatrixE(N,1,mxREAL);   /* Xt * D */
	z2 = mxGetPr(z2_ptr);
	z3_ptr = mxCreateDoubleMatrixE(N,M,mxREAL);   /* Z * A  */
	z3 = mxGetPr(z3_ptr);
	z4_ptr = mxCreateDoubleMatrixE(N,M,mxREAL);   /* Xt * B */
	z4 = mxGetPr(z4_ptr);

	xt = x;
	yt = y;
	/* Work on the output copy of the state so the input Z is never modified. */
	copy_data(z, zt, N*M);

	/*
	Xt : Input signal   (N x 1)
	Yt : Output signal  (N x 1)
	Z  : Internal state (N x M)
	A: M x M, 	B: 1 x M, 	C: M x 1, 	D: 1 x 1
	Yt = Z * C + Xt * D;
	Z  = Z * A + Xt * B;
	*/
	for (t = 0; t < T; t++) {
		/* Output update: uses the state from before this step. */
		matrix_vector(zt, c, z1, N, M);
		vector_scalar(xt, *d, z2, N);
		matrix_add(z1, z2, yt, N);

		/* State update: both products go to scratch first, so the old
		   state is fully consumed before zt is overwritten. */
		matrix_product(zt, a, z3, N, M, M);
		matrix_product(xt, b, z4, N, M, 1);
		matrix_add(z3, z4, zt, N*M);

		/* Advance to the next column of X and Y. */
		xt += N;
		yt += N;
	}

	/* Release the scratch arrays explicitly rather than relying on
	   MATLAB's automatic cleanup when the MEX function returns. */
	mxDestroyArray(z1_ptr);
	mxDestroyArray(z2_ptr);
	mxDestroyArray(z3_ptr);
	mxDestroyArray(z4_ptr);
}
