#include <limits.h>

#include "mex.h"
#include "mexutil.h"

/*
 *  z = repdiag(x,y)
 *
 * Get the diagonal elements of the matrix X'*Y, i.e. the dot product of
 * each column of X with the corresponding column of Y.
 *
 * Written by M.Fukushima 2009/12/01
 * Original code is 'repmultiply.c'
 */


/*
 * Column-wise dot products:  z(i) = sum_j x(j,i) * y(j,i),
 * i.e. z = diag(x'*y) for M x N matrices stored column by column.
 *
 *   x, y : read-only inputs, each holding at least m*n doubles
 *   z    : output of at least n doubles; every z[0..n-1] is written
 *   m    : number of rows (length of each column)
 *   n    : number of columns
 *
 * The inputs are walked with pointers instead of a single flat int index,
 * so no m*n running count is ever held in an int (that count overflowed,
 * i.e. undefined behaviour, once the element count exceeded INT_MAX).
 * If m <= 0 every z[i] is set to 0; if n <= 0 nothing is written.
 */
void xtimesyrow(const double *x, const double *y, double *z, int m, int n)
{
  int i, j;

  for (i = 0; i < n; i++) {
    double acc = 0.0;  /* fresh accumulator for column i */

    for (j = 0; j < m; j++) {
      acc += *x++ * *y++;
    }
    z[i] = acc;
  }
}

/*
 * The gateway routine:  z = repdiag(x,y)
 *
 *   prhs[0] = x, prhs[1] = y : real, full (non-sparse) double matrices
 *                              of identical size M x N
 *   plhs[0] = z              : N x 1 double column, z = diag(x'*y)
 *
 * Raises a MATLAB error (mexErrMsgTxt) on a wrong argument count,
 * non-double / complex / sparse inputs, mismatched sizes, or dimensions
 * that do not fit the int parameters of xtimesyrow.
 */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[])
{
	mwSize m, n;
	int    k;

	/*  Check for proper number of arguments.  nlhs == 0 is allowed:
	    MATLAB still supplies plhs[0] and assigns it to "ans". */
	if (nrhs != 2) {
		mexErrMsgTxt("Two inputs required.");
	}
	if (nlhs > 1) {
		mexErrMsgTxt("Too many output arguments.");
	}

	/*  mxGetPr() only describes a dense block of m*n doubles for real,
	    full double arrays.  Sparse storage or a non-double class would
	    be misread (and read out of bounds by xtimesyrow), and an
	    imaginary part would be silently ignored, so reject them. */
	for (k = 0; k < 2; k++) {
		if (!mxIsDouble(prhs[k]) || mxIsComplex(prhs[k]) || mxIsSparse(prhs[k])) {
			mexErrMsgTxt("Inputs must be real, full double matrices.");
		}
	}

	/*  Validate sizes before allocating the output. */
	m = mxGetM(prhs[0]);
	n = mxGetN(prhs[0]);
	if (mxGetM(prhs[1]) != m || mxGetN(prhs[1]) != n) {
		mexErrMsgTxt("Do not match dim");
	}
	/*  xtimesyrow takes int dimensions; refuse sizes that would truncate. */
	if (m > INT_MAX || n > INT_MAX) {
		mexErrMsgTxt("Input dimensions too large.");
	}

	/*  Uninitialized output (helper from mexutil.h) for speed: xtimesyrow
	    writes every one of its n entries, so no zero-fill is needed. */
	plhs[0] = mxCreateDoubleMatrixE(n, 1, mxREAL);

	/*  Element-wise products commute, so diag(x'*y) == diag(y'*x) and
	    the argument order passed here does not affect the result. */
	xtimesyrow(mxGetPr(prhs[0]), mxGetPr(prhs[1]), mxGetPr(plhs[0]),
	           (int)m, (int)n);
}
