#include <limits.h>

#include "mex.h"
#include "mexutil.h"

/*
 *  z = matrix_product(x,y)
 *    = x * y
 *
 * Written by Masa-aki Sato 2009/07/01
 */

/*
   z(j,i) = sum_k  x(j,k) * y(k,i)

   z(j,i) : M x N
   x(j,k) : M x D
   y(k,i) : D x N

   All three matrices are stored column-major (first index varies fastest).
 */
void matrix_product(double *x, double *y, double *z, int m, int n, int d)
{
	/*
	 * Dense product z = x * y on column-major buffers:
	 *   x is m-by-d, y is d-by-n, z is m-by-n.
	 * Every element of z is overwritten (zeroed first, then accumulated),
	 * so z does not need to be initialised by the caller.
	 */
	int col, row, inner;
	double *out = z;            /* walks z in storage order            */
	const double *ycol = y;     /* start of the current column of y    */

	for (col = 0; col < n; col++, ycol += d) {
		for (row = 0; row < m; row++, out++) {
			const double *xp = x + row;   /* x(row, 0) */

			*out = 0;

			/* Dot product of row `row` of x with column `col` of y.
			   Consecutive elements of one row of x are m apart. */
			for (inner = 0; inner < d; inner++, xp += m) {
				*out += *xp * ycol[inner];
			}
		}
	}
}

/* The gateway routine */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[])
{
	/*
	 * MEX entry point for  z = matrix_product(x, y).
	 *
	 *   prhs[0] : x, real full double matrix (M x D)
	 *   prhs[1] : y, real full double matrix (D x N)
	 *   plhs[0] : z = x * y                  (M x N)
	 *
	 * Reports an error through mexErrMsgTxt when the argument count is
	 * wrong, an input is not a real full double array, the inner
	 * dimensions differ, or a dimension does not fit the int parameters
	 * of matrix_product().
	 */
	double *x, *y, *z;
	size_t  mx, nx, my, ny;
	int     arg;

	/*  Check for proper number of arguments. */
	if (nrhs != 2)
		mexErrMsgTxt("Two inputs required.");
	/*  plhs[0] may be assigned even when nlhs is 0 (the result goes to
	    `ans`), so only reject requests for more than one output. */
	if (nlhs > 1)
		mexErrMsgTxt("Too many outputs: at most one is returned.");

	/*  matrix_product() reads both buffers as dense real doubles, so any
	    other class or storage layout must be rejected up front. */
	for (arg = 0; arg < 2; arg++) {
		if (!mxIsDouble(prhs[arg]) || mxIsComplex(prhs[arg])
		    || mxIsSparse(prhs[arg]))
			mexErrMsgTxt("Inputs must be real, full double matrices.");
	}

	/*  Dimensions of x (mx-by-nx) and y (my-by-ny). */
	mx = mxGetM(prhs[0]);
	nx = mxGetN(prhs[0]);
	my = mxGetM(prhs[1]);
	ny = mxGetN(prhs[1]);

	if (nx != my)
		mexErrMsgTxt("dimension mismatch.");

	/*  matrix_product() takes int sizes; refuse instead of truncating. */
	if (mx > (size_t)INT_MAX || nx > (size_t)INT_MAX || ny > (size_t)INT_MAX)
		mexErrMsgTxt("matrix dimensions too large.");

	/*  Pointers to the input data. */
	x = mxGetPr(prhs[0]);
	y = mxGetPr(prhs[1]);

	/*  Allocate the mx-by-ny output with mxCreateDoubleMatrixE (from
	    mexutil.h; used here in place of mxCreateDoubleMatrix to get an
	    uninitialized matrix for speed).  matrix_product() writes every
	    element, so no prior zero fill is needed. */
	plhs[0] = mxCreateDoubleMatrixE(mx, ny, mxREAL);

	/*  Pointer to the output data (not a copy). */
	z = mxGetPr(plhs[0]);

	matrix_product(x, y, z, (int)mx, (int)ny, (int)nx);
}
