function [Model,Info] = ...
         vbmeg_multi_fmri(B, Gall, Gact, COV, vb_parm)
% Bayesian Wiener estimation for MEG. 
% Estimate current variance by using time sequence of all trials. 
% A mixture of multiple fMRI activity patterns is used as the prior:
% the current variance is represented by a weighted mixture of the fMRI
% patterns and the weight coefficient of each pattern is estimated.
%
% -- Syntax
% function [Model,Info] = vbmeg_multi_fmri(B, Gall, Gact, COV, vb_parm),
%
% -- Input
% B{Nsession} : MEG magnetic field data matrix
%                   size(B{i}) = [Nsensors(i), Tsample, Ntrials(i)]
% Gall{Nsession}: Lead field for whole brain region as background activity
%                   size(Gall{i}) = [Nsensors(i) Ndipole]
%                  If Gall = [], the sparse-mode (focal dipoles only) is
%                  executed.
% Gact{Nsession} : Lead field (Current basis function)
%                  for region with high estimation accuracy 
%                   size(Gact{i})  = [Nsensors(i) Ndipole]
%                  If Gact = [], the iso-mode (no focal dipoles) is
%                  executed.
% COV{Nsession}  : Noise covariance matrix
%                   size(COV{i})   = [Nsensors(i) Nsensors(i)]
% vb_parm        : Constants parameter
%
% (Field of vb_parm)
% .Ntrain     : total number of update iterations per time window
% .Nskip      : interval (in iterations) of progress display and of the
%               log-determinant (Hj) evaluation
% .a_min      : lower bound applied to the initial variances a0 and v0
% .a_max      : read but not used in this function
% .Npre_train (optional) : number of leading iterations that use the
%               original VB update rule (default 0)
% .Ntrials    : number of trials in each session
% .Nsensors   : number of sensor channels in each session
% .Nsession   : number of sessions
% .Nwindow    : number of time windows
% .Twindow    : time window sample index, one row [Tstart, Tend] per window
% .Ta0        : confidence parameter of focal (fMRI pattern) variances
% .Tv0        : confidence parameter of background variance
% .sx0        : initial observation noise variance
% .a0         : initial weight of the fMRI patterns
% .v0         : initial variance of background activity
% .wiener (optional)
%    .update_sx : if nonzero, sensor noise variance is estimated (default 1)
%    .update_v  : if nonzero, background variance is estimated when both
%                 Gall and Gact are given (default 0 in that case)
% .act (optional) : fMRI activity patterns, size = [Ndipole(Gact), Npattern]
%               Required when Gact is not empty.
% .cont_pr    : if OFF, the variance is re-initialized by the prior in
%               every time window; otherwise the estimate of the previous
%               window is used as the initial value
% 
% -- Output
% Model.a   : Current variance (1/alpha),  [Njact x Nwindow]
% Model.v   : Current variance for background activity, [1 x Nwindow]
%             (all zeros when no background area is given, i.e. Gall = [])
% Model.sx  : Observation noise variance, [1 x Nwindow]
%
% Model.act   : Multiple fMRI activity patterns for focal area
% Model.act_v : Their estimated weights, [Npattern x Nwindow]
%               (both fields are set only when Gact is not empty)
%
% Info      : Convergence information, [Nwindow*Ntrain x 8].
%             One row per iteration with the columns
%             [FE, Ev, Hj, LP, Err, sx, max(ax_j), sum(gain_j)]
%             Hj (log-determinant term) is evaluated only at iterations
%             where mod(k,Nskip)==0 or k==Ntrain; it is 0 otherwise, so
%             FE and Ev of the other iterations do not include it.
%       
% -- Note
% Brain noise : Background brain activity is assumed 
%               as independent Gaussian noise with same variance
%
% "Nvact" (# of vertex for active dipoles) and 
% "Nt"  (length of time window) are common for all sessions. 
%
% "Nsensors" (# of sensor channels in a session) and 
% "Ntrial"  (# of trials in a session) could depend on each session. 
% 
% Note that ax0_j, ax0_z, and sx are defined as variance:
% the inverse of inverse-variance (i.e. having the dimension of variance).
% These notations are different from those in Sato-san's note. 
%
% --- History
% 2004-12-5 M. Sato
% 2005-04-12 O.Yamashita
% * The whole program are revised 
%   so that variables have consistent names with Sato-san's note.
% * The noise variance, sx, ax_z and ax_j are independently handled. 
% * A field 'fix_parm' is added to 'vb_parm' struct.
% * The input 'COV' must be given.
% 2005-04-16 M. Sato
% * Modified free-energy calculation
% 2005-05-07 M. Sato
% * Modified variance update (Faster convergence)
% 2005-05-13 O.Yamashita
% * 'Modelinit' is added as the 6th input optional argument.
% * Iso-mode is implemented. (when Gact is [])
% * A field 'fix_parm' is changed to two flag paramters, 'update_v' and
% 'update_sx'.
% * lower bound of 'gain_z'  
% 2005-05-16 O.Yamashita
% * Sparse-mode is implementd. (when Gall is [])
% * lower bound of 'gain_j'
% 2005-05-31 O.Yamashita
% * constant term of free energy
% 2005-06-02 O.Yamashita
% * constat term of free energy revised
% 2005/08/22 O.Yamashita    --- ver.30b 
% * vb_parm is introduced
% * Output argument is modified
% * Minor bug fix
% 2005/09/29 O. Yamashita
% * Minor bug fix (when Norinet=2 & iso-mode)
%
% 2005/12/5 M.Sato
% * New version
% * Prior is given for each time window
% 2006/8/11 M.Sato
%  Bayesian Wiener estimation
%  Mixture of Multiple fMRI activity patterns are used for prior
% 2008-11-28 Taku Yoshioka
%   Use vb_disp() for displaying message
%
% Copyright (C) 2011, ATR All Rights Reserved.
% License : New BSD License(see VBMEG_LICENSE.txt)

const = vb_define_verbose;
VERBOSE_LEVEL_NOTICE = const.VERBOSE_LEVEL_NOTICE;
VERBOSE_LEVEL_INFO = const.VERBOSE_LEVEL_INFO;

Ntrain = vb_parm.Ntrain;
Nskip  = vb_parm.Nskip;
a_min  = vb_parm.a_min;
a_max  = vb_parm.a_max;   % read for interface compatibility; not used below

if isfield(vb_parm, 'Npre_train')
    % Number of iteration using original VB-update rule
    % for stable estimation
    Npre_train = vb_parm.Npre_train;
else
    Npre_train = 0;
end

vb_disp(['--- Initial VB-update iteration = ' num2str(Npre_train) ], ...
        VERBOSE_LEVEL_NOTICE);
vb_disp(['--- Total update iteration      = ' num2str(Ntrain) ], ...
        VERBOSE_LEVEL_NOTICE);

% Number parameters
Ntrials     = vb_parm.Ntrials;    % Number of trials
Nsensors    = vb_parm.Nsensors;   % Number of sensors
Nsession    = vb_parm.Nsession;   % Number of sessions
Nwindow     = vb_parm.Nwindow;    % Number of time windows
Twindow     = vb_parm.Twindow;    % Time window index (Tstart, Tend)

Ntotaltrials = sum(Ntrials);      % Total number of trials in all sessions
Nerror  = sum(Ntrials.*Nsensors); % Total MEG channel used for error estimate

% Set prior parameters
Ta0 = vb_parm.Ta0 ;             % confidence parameter of focal
Tv0 = vb_parm.Tv0 ;             % confidence parameter of global
sx0 = vb_parm.sx0;              % observation noise variance
a0  = max(vb_parm.a0, a_min );  % initial variance of focal
v0  = max(vb_parm.v0, a_min );  % initial weight of global & multiple fMRI

% Sensor noise update flag
if isfield(vb_parm, 'wiener')
    update_sx = vb_parm.wiener.update_sx;
else
    update_sx = 1;
end

% Set fMRI activity pattern
if isfield(vb_parm, 'act')
    act = vb_parm.act;
    Nact_ptn = size(act,2);
else
    act = [];
    Nact_ptn = 0;
end

% Nvarall = 0 (No Global variance estimation) 
%         = 1 (   Global variance estimation)

if isempty(Gall)
    % No global area case
    Njall    = []; 
    Tv0      = [];
    v0       = [];
    Nvarall  = 0;
    update_v = 0;
else
    % Global area variance estimation case
    Njall    = size(Gall{1},2);    % Number of global area dipoles
    Nvarall  = 1;
    update_v = 1;
end

% Narea = # of variance parameters (background + fMRI patterns)
Narea = (Nvarall + Nact_ptn);

if isempty(Gact) 
    % No focal area case
    Njact    = 1; % avoid empty
    act_id   = []; 
    act      = 0;
    update_z = 0;
elseif Nact_ptn > 0
    % Focal area variance estimation case
    Njact   = size(Gact{1},2);    % Number of active dipoles

    % Each pattern must give one value per focal dipole.
    % (error() with a single argument does not interpret escape
    %  sequences, so no trailing '\n' is put in the message.)
    if Njact ~= size(act,1)
        error('Activity pattern does not match with focal area')
    end

    Nvaract = Nact_ptn; 
    act_id  = (Nvarall+1):Narea; % variance index for focal areas
    Njall   = [Njall , Njact * ones(1,Nvaract)]; 
    Tv0     = [Tv0   , Ta0 * ones(1,Nvaract)];
    v0      = [ v0   , a0 * ones(1,Nvaract)];
    
    update_z = 1;
    % The background variance can only be updated when a background
    % area exists.  In sparse-mode (Nvarall == 0) index 1 belongs to the
    % first fMRI pattern, which is already updated through act_id, so
    % the flag is forced to 0 to keep the log message below truthful.
    if Nvarall > 0 && isfield(vb_parm, 'wiener')
        update_v = vb_parm.wiener.update_v;
    else
        update_v = 0;
    end
else
    error('No active patterns are given')
end

mesg = {'is fixed', 'is estimated'};
vb_disp(['Sensor noise variance ' mesg{update_sx + 1}], ...
        VERBOSE_LEVEL_NOTICE);
vb_disp(['Background current variance ' mesg{update_v + 1}], ...
        VERBOSE_LEVEL_NOTICE);

% Calculate  (G * Az(j) * G') for multiple patterns.
% GGall{ns} is [Nch^2 x Narea]: one vectorized (Nch x Nch) matrix per
% variance parameter (background first, then each fMRI pattern).
GGall = cell(Nsession,1);

for ns = 1 : Nsession
    Nch   = Nsensors(ns); % number of channel available in each session
    
    if isempty(Gall)
        GGall{ns} = [];
    else
        GG = Gall{ns}*Gall{ns}';
        GGall{ns} = GG(:);
    end
    
    if ~isempty(Gact) 
        G  = Gact{ns};
        Gt = G';
        for j = 1:Nact_ptn
            % Az.*Gt scales each dipole row of G' by the pattern value
            Az = repmat(act(:,j) , [1  Nch]);  
            GG = G*(Az.*Gt);
            
            GGall{ns} = [GGall{ns} GG(:)];
        end
    end
end

%%% calculate BB' (summed over trials) for each session and time window
BBall = cell(Nsession, Nwindow);

for ns = 1 : Nsession
    for nw = 1 : Nwindow
        B0 = 0;
        for i = 1 : Ntrials(ns)
            Bt = B{ns}(:,Twindow(nw,1):Twindow(nw,2),i);
            B0 = B0 + Bt * Bt';
        end
        BBall{ns,nw} = B0;
    end
end
    
clear B Gall

% v0 must supply one initial value per variance parameter
Nv2 = size(v0, 2);

if Nv2 ~= Narea, error('Dimension mismatch for V0'); end;

% Temporal variable for variance (ratio)
aall  = zeros(Njact, Nwindow); 
vall  = zeros(Narea,Nwindow);
sxall = zeros(1,Nwindow);

% Free energy and Error history
Info    = zeros(Nwindow*Ntrain, 8); 
k_check = 0;

%%%%% Time window loop %%%%%%%%
for nw=1:Nwindow
    
    Nt  = Twindow(nw,2) - Twindow(nw,1) + 1;
    Ntt = Nt*Ntotaltrials; 

    % confidence parameters
    gx0_j = Tv0; 
    
    gx_j = 0.5 * Ntt * Njall + gx0_j;
    
    % expected variance parameters
    ax0_j = v0;
    % NOTE(review): sx is reset to sx0 for every window, also when
    % cont_pr is not OFF -- only ax_j is carried over between windows.
    sx    = sx0;
    
    % variance hyper-parameter initialization
    if vb_parm.cont_pr == OFF | nw == 1
        % use the prior variance as the initial expected value
        ax_j = ax0_j; 
        sx   = sx0;
    else
        % use the estimated variance in the previous time window as the
        % initial expected value (ax_j is left as it is)
    end
    
    %%%%%% Estimation Loop %%%%%% 
    for k=1:Ntrain 
        
        %%% VB E-step -- current estimation
        
        % Initialize averaging variable
        tr1_b   = 0;                % Reconstruction error
        tr2_b   = 0;                % Error variance
        mean_jj = zeros(1,Narea);  %  Z*Z
        gain_j  = zeros(1,Narea);  % Estimation Gain for Z
        Hj      = 0;                % log(det(SB))
        
        for ns=1:Nsession
            
            Nch   = Nsensors(ns);
            Nch2  = Nch^2;
            Ntry  = Ntrials(ns);
            
            % Lead field for each session
            GG    = GGall{ns};  % ( G * Az(j) * G' ) 
            
            % MEG covariance for each session
            BB    = BBall{ns,nw};    % (B B')
            
            % Noise covariance for each session
            Cov   = COV{ns};   % Sigma_G^{-1}   
            
            % Inverse filter for each session
            GSG     = reshape(GG*ax_j(:), [Nch Nch]); % Ga * Sigma_alpha *Ga'
            SB      = GSG + Cov;             % Sigma_B
            SB_inv  = inv( SB );             % Sigma_B^{-1}
            SBBS    = SB_inv*BB*SB_inv;      % Sigma_B^{-1}*BB*Sigma_B^{-1} 
            
            % Reconstruction error = sum( (B-G*J)'*Cov*(B-G*J) )
            tr1_b = tr1_b + sum(sum(Cov.*SBBS));
            % Error variance 
            tr2_b = tr2_b + sum(sum(BB.*SB_inv));
            
            % Current magnitude = diag( sum( Z * Z ) )
            mean_jj = mean_jj + ( reshape(SBBS,[1 Nch2]) * GG ).*ax_j;
            % Estimation gain for Z = diag(ax_z*Gt*SB_inv*G)
            gain_j  = gain_j + Nt*Ntry*( reshape(SB_inv,[1 Nch2]) * GG ).*ax_j;
            
            if ( mod(k, Nskip)==0 || k==Ntrain )
                % Model complexity for Current parameter
                % (evaluated only on displayed iterations)
                Hj  = Hj - 0.5*Nt*Ntry*vb_log_det(SB);
            end
        end;
        
        % Noise variance estimation
%        sx_total = (tr1_b + mean_jj + sum(mean_zz));
        sx_total = tr2_b;
        
        % update the observation noise variance
        if update_sx
              sx   = sx_total/(Nt*Nerror);   
        end
      
        %%%%% VB M-step -- variance update 
        
        % Save old value for Free energy calculation
        ax_j_old = ax_j;
        
        % variance estimation
        mean_jj_all = 0.5*ax_j.*mean_jj/sx + gx0_j.*ax0_j;
      
        % Update the current variance parameters
        % VB update equation
        %                    gx_j * ax_j = ...
        % (0.5*Ntt*Njall + gx0_j) * ax_j = ...
        %       mean_jj_all + 0.5*ax_j*( Ntt * Njall - gain_j ) ;

        % Update the current variance parameters
        if k <= Npre_train
            % --- VB update rule
            ax_j_total = mean_jj_all + 0.5*ax_j.*(Ntt*Njall - gain_j) ;
            
            % Background activity variance estimation
            if update_v
                ax_j(1)  = ax_j_total(1) ./ gx_j(1);
            end
            % Active current variance estimation
            if update_z
                ax_j(act_id) = ax_j_total(act_id) ./ gx_j(act_id);  
            end
        else
            % --- Faster convergence update rule
            gain_j = max(gain_j, eps);   % lower bound of the gain

            % Background activity variance estimation
            if update_v
                ax_j(1) = mean_jj_all(1) ./ (0.5*gain_j(1) + gx0_j(1));  
            end
            % Active current variance estimation
            if update_z
                ax_j(act_id) = mean_jj_all(act_id) ...
                            ./ (0.5*gain_j(act_id) + gx0_j(act_id));  
            end
        end
            
        %%%%%%%% Free Energy Calculation %%%%%%%%
       
        % Model complexity for current variance
        % (only for parameters with a non-vanishing prior)
        ix = find( gx0_j > eps & ax0_j > eps );
        
        if ~isempty(ix)
            Haj = sum(gx0_j(ix).*( log(ax0_j(ix)./ax_j_old(ix)) ...
                                     - ax0_j(ix)./ax_j_old(ix) + 1));
        else
            Haj = 0;
        end;

        LP = - 0.5*Nt*Nerror*log(sx);

        % Data Likelihood & Free Energy 
        Ea  = - 0.5*( sx_total/sx - Nt*Nerror );
        FE  = LP + Hj + Ea + Haj;
        Ev  = LP + Hj;  % evidence (without constants)
        % NOTE(review): BB is the data covariance of the LAST session
        % only and all of its elements (not only the diagonal) are
        % summed, while tr1_b is accumulated over all sessions.
        % Confirm that this normalization is intended.
        Err = tr1_b/sum(sum(BB));% Normalized Error
                   
        % DEBUG Info
        k_check = k_check + 1;
        Info(k_check,:) = [...
            FE, Ev, Hj, LP, Err,...
            sx, max(ax_j), sum(gain_j)];
        % end of free energy calculation
                   
        if mod(k, Nskip)==0 || k==Ntrain 
            vb_disp(['Tn=' num2str(nw,'%3d') ', Iter=' num2str(k,'%4d') ...
                     ', FE=' num2str(FE,'%6.2f') ', Error=' ...
                     num2str(Err,'%6.2e')],VERBOSE_LEVEL_INFO);
        end; 
        
    end % end of learning loop 

    sxall(nw)  = sx;
    vall(:,nw) = ax_j(:);
    
    if ~isempty(act_id)
        % dipole-wise variance = weighted sum of the fMRI patterns
        aall(:,nw) = act*ax_j(act_id)';
    end
end % end of time window loop


Model.sx  = sxall;    % Sensor noise variance
Model.a   = aall ; % Current variance 

% Current variance for background activity.
% The presence of a background area is indicated by Nvarall.  (Njall
% cannot be used for this test: in sparse-mode it holds the focal dipole
% counts, and vall(1,:) is then the weight of the first fMRI pattern.)
if Nvarall == 0
    Model.v   = zeros(1,Nwindow) ;
else
    Model.v   = vall(1,:) ; 
end

% Model.act   : Multiple fMRI activity patterns for focal area
% Model.act_v : Their estimated weights
if ~isempty(act_id)
    Model.act   = act;
    Model.act_v = vall(act_id,:);
end

%fprintf('Wiener estimation for (multi) fmri patterns\n')

return

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% functions specific for this function
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function FEconst = free_energy_constant(gx_j, gx0_j, gx_z, gx0_z,...
    Nt, Ntotaltrials, Nerror, Njall)
% Constant term of the free energy.
%
% gx_j , gx0_j : posterior / prior confidence parameter passed to qG/HG0
%                (a single value is added as is, without sum)
% gx_z , gx0_z : posterior / prior confidence parameters whose
%                contributions are summed
% Nt, Ntotaltrials, Nerror, Njall : scaling factors of each term
%
% FEconst = L0 + H0, where L0 collects the qG() terms and H0 the HG0()
% terms.  The confidence parameter of the noise variance is 0.5*Nt*Nerror
% and its prior confidence is taken as 0.

conf_sx = 0.5*Nt*Nerror;

% qG contributions
q_j  = Njall*Ntotaltrials*qG(gx_j);
q_z  = Ntotaltrials*sum(qG(gx_z));
q_sx = Nerror*qG(conf_sx);
L0   = 0.5*Nt*(q_j + q_z + q_sx - Nerror);

% HG0 contributions
h_sx = HG0(conf_sx,0);
h_j  = HG0(gx_j,gx0_j);
h_z  = sum(HG0(gx_z, gx0_z));
H0   = h_sx + h_j + h_z;

FEconst = L0+H0;

%%%%%
function y = qG(gx)
% y = psi(gx) - log(gx) for the non-zero components of gx.
% Components where gx == 0 give 0.
% The result is a column vector with length(gx) elements.
idx    = find(gx ~= 0);          % non-zero components
g      = gx(idx);
y      = zeros(length(gx),1);
y(idx) = psi(g) - log(g);

%%%%%%
function y =HG0(gx, gx0)
% Sum of three terms built from gx and gx0 (column vector result,
% initialized with length(gx0) elements):
%   post  : gammaln(gx)  - gx .*psi(gx)   + gx    (where gx  ~= 0)
%   prior : gammaln(gx0) - gx0.*log(gx0)  + gx0   (where gx0 ~= 0)
%   cross : gx0.*(psi(gx) - log(gx))              (where gx0 ~= 0)
%   y = post - prior + cross
% Components excluded by the conditions above contribute 0.
n    = length(gx0);
idx  = find(gx  ~= 0);   % non-zero components of gx
idx0 = find(gx0 ~= 0);   % non-zero components of gx0

% term of gx alone
g    = gx(idx);
post = zeros(n,1);
post(idx) = gammaln(g) - g.*psi(g) + g;

% term of gx0 alone
g0    = gx0(idx0);
prior = zeros(n,1);
prior(idx0) = gammaln(g0) - g0.*log(g0) + g0;

% cross term: gx evaluated at the non-zero positions of gx0
gc    = gx(idx0);
cross = zeros(n,1);
cross(idx0) = g0.*(psi(gc) - log(gc));

y = post - prior + cross;




