function [Model,Info] = ...
         vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm, Modelinit)
% Variational Bayes estimation for MEG. 
% Estimate current variance by using time sequence of all trials. 
% (The prune operation is omitted in this version; see the 2008/06/20 note.)
%
% -- Syntax
% [Model,Info] = vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm),
% [Model,Info] = vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm, Modelinit)
%
% -- Input
% B{Nsession} : MEG magnetic field data matrix
%                   size(B{i}) = [Nsensors(i), Tsample, Ntrials(i)]
% Gall{Nsession}: Lead field for whole brain region as background activity
%                   size(Gall{i}) = [Nsensors(i) Ndipole]
%                  NOTE: this version does not use Gall; it is cleared
%                  right after the data covariance is computed, and
%                  Model.v is returned as zeros.
% Gact{Nsession} : Lead field (Current basis function)
%                  for region with high estimation accuracy 
%                   size(Gact{i})  = [Nsensors(i) Nvact*Norient]
%                  Gact must not be empty (an error is raised otherwise).
% COV{Nsession}  : Noise covariance matrix
%                   size(COV{i})   = [Nsensors(i) Nsensors(i)]
% vb_parm        : Constants parameter. Fields read here:
%      Ntrain, Nskip, a_min, Ntrials, Nsensors, Nsession, Nwindow, Twindow,
%      Nvact, Norient, Norient_var, Ta0, sx0, a0, cont_pr,
%      and optionally Npre_train (default Ntrain) and update_sx (default 1).
% 
% -- Optional input
% Modelinit      : Initial Model parameters, if given. 
%      If not given, vb_parm.a0 is used as the initial/prior variance.
%      The size of Modelinit.a must be the same as that of the output
%      Model.a.
% 
% -- Output
% Model.a   : Current variance (1/alpha), relative to sx (see below)
%             size = [Nvaract, Nwindow]
% Model.v   : Current variance for background activity
%             (not estimated in this version; zeros(Nwindow,1))
% Model.sx  : Observation noise variance, size = [Nwindow, 1]
% Info      : Convergence history, size = [Nwindow*Ntrain, 6]; one row per
%             iteration with columns [FE, Ev, Hj, LP, Err, cond(SB)],
%             where cond(SB) is taken from the last session.
%       
% -- Note
% Brain noise : Background brain activity is assumed 
%               as independent Gaussian noise with same variance
%
% "Nvact" (# of vertex for active dipoles) and 
% "Nt"  (length of time window) are common for all sessions. 
%
% "Nsensors" (# of sensor channels in a session) and 
% "Ntrial"  (# of trials in a session) could depend on each session. 
% 
% Note that ax0_j, ax0_z, and sx are defined as variance:
% the inverse of inverse-variance (i.e. having the dimension of variance).
% These notations are different from those in Sato-san's note. 
%
% 2004-12-5 M. Sato
% 2005-04-12 O.Yamashita
% * The whole program is revised 
%   so that variables have consistent names with Sato-san's note.
% * The noise variance, sx, ax_z and ax_j are independently handled. 
% * A field 'fix_parm' is added to 'vb_parm' struct.
% * The input 'COV' must be given.
% 2005-04-16 M. Sato
% * Modified free-energy calculation
% 2005-05-07 M. Sato
% * Modified variance update (Faster convergence)
% 2005-05-13 O.Yamashita
% * 'Modelinit' is added as the 6th input optional argument.
% * Iso-mode is implemented. (when Gact is [])
% * A field 'fix_parm' is changed to two flag parameters, 'update_v' and
% 'update_sx'.
% * lower bound of 'gain_z'  
% 2005-05-16 O.Yamashita
% * Sparse-mode is implemented. (when Gall is [])
% * lower bound of 'gain_j'
% 2005-05-31 O.Yamashita
% * constant term of free energy
% 2005-06-02 O.Yamashita
% * constant term of free energy revised
% 2005/08/22 O.Yamashita    --- ver.30b 
% * vb_parm is introduced
% * Output argument is modified
% * Minor bug fix
% 2005/09/29 O. Yamashita
% * Minor bug fix (when Norient=2 & iso-mode)
% 2006/04/16 M. Sato
% * vb_parm.Npre_train is introduced
%   VB-update equation is used for iteration <= Npre_train
% 2006/08/20 M. Sato
% * Soft normal constraint
% 2008/06/20 M. Sato
% * Only support focal model with normal current direction  
% * Do not support soft normal constraint
% * prune operation is omitted since no pruning is done in Ta0>=10
% 2012-02-08 taku-y
% [debug] Fixed so that flag 'update_sx' correctly works when put the
%         flag 'false'. 
% (review) 
% * A scalar prior variance is expanded to one value per variance parameter.
% * Normalized error 'Err' is normalized by the data power of all sessions.
%
% ------------ Relative variance calculation ----------------
% In this implementation,
% Current variance 'ax_z' is defined as ratio w.r.t sx :
% ax_z * sx = < Z*Z>
%
% Copyright (C) 2011, ATR All Rights Reserved.
% License : New BSD License(see VBMEG_LICENSE.txt)

% Global switch: use pinv instead of inv for Sigma_B^{-1}
global flag_pinv;
if isempty(flag_pinv), flag_pinv = false; end

Ntrain = vb_parm.Ntrain;
Nskip  = vb_parm.Nskip;
a_min  = vb_parm.a_min;

if isfield(vb_parm, 'Npre_train')
    % Number of iterations using the original VB-update rule
    % (for stable estimation)
    Npre_train = vb_parm.Npre_train;
else
    Npre_train = Ntrain;
end

fprintf('\n--- New VBMEG estimation program ---\n\n')
fprintf('--- Initial VB-update iteration = %d\n',Npre_train)
fprintf('--- Total update iteration      = %d\n',Ntrain)

% Number parameters
Ntrials     = vb_parm.Ntrials;    % Number of trials
Nsensors    = vb_parm.Nsensors;   % Number of sensors
Nsession    = vb_parm.Nsession;   % Number of sessions
Nwindow     = vb_parm.Nwindow;    % Number of time windows
Twindow     = vb_parm.Twindow;    % Time window index [Tstart, Tend]
Ntotaltrials = sum(Ntrials);      % Total number of trials in all sessions
Nerror  = sum(Ntrials.*Nsensors); % Total MEG channel used for error estimate

Nvact       = vb_parm.Nvact;      % Number of active current vertices
Norient     = vb_parm.Norient;    % Number of current orientation component
Norient_var = vb_parm.Norient_var;% Number of current orientation component
                                  %   for variance estimation
Nvaract     = Nvact * Norient_var;% Number of variance parameters
                                  % (=length(ax_z))
Ratio = Norient/Norient_var;      % = (# of orientation)
                                  %   /(# of orientation in variance)

% Set prior parameters
Ta0 = vb_parm.Ta0 ;            % confidence parameter of focal
sx0 = vb_parm.sx0;             % observation noise variance
a0  = max(vb_parm.a0, a_min ); % initial variance of focal

if exist('Modelinit','var') && ~isempty(Modelinit)
    a0s = max(Modelinit.a, a_min );
else
    a0s = repmat(a0, [1 Nwindow]);
end

if isempty(Gact) 
    error('No leadlield is given')
end

% Flag controlling the update of the noise variance
update_sx = 1;
if isfield(vb_parm, 'update_sx')
   update_sx = vb_parm.update_sx;
end

%%% Calculate the data covariance B*B' (summed over trials)
%%% for each session and time window
BBall = cell(Nsession, Nwindow);
for ns = 1 : Nsession
    for nw = 1 : Nwindow
        B0 = 0;
        for i = 1 : Ntrials(ns)
            Bt = B{ns}(:,Twindow(nw,1):Twindow(nw,2),i);
            B0 = B0 + Bt * Bt';
        end
        BBall{ns,nw} = B0;
    end
end

% Raw data and the (unused) background lead field are no longer needed
clear B Gall

% Temporal variable for variance (ratio)
aall  = zeros(Nvaract,Nwindow); % Modified by TY 2005-06-15
sxall = zeros(Nwindow,1);

% Free energy and Error history
Info    = zeros(Nwindow*Ntrain, 6); 
k_check = 0;

%%%%% Time window loop %%%%%%%%
for nw=1:Nwindow
    
    Nt  = Twindow(nw,2) - Twindow(nw,1) + 1; % # of samples in this window
    Ntt = Nt*Ntotaltrials;                   % # of samples over all trials

    % confidence parameters
    gx0_z = Ta0; 
    gx_z  = 0.5 * Ntt * Ratio + gx0_z; 
    
    % initial (prior) variance parameters
    ax0_z = a0s(:,nw) ; 
    if isscalar(ax0_z)
        % A scalar prior (e.g. scalar vb_parm.a0) is expanded to one value
        % per variance parameter, so that the element-wise indexing below
        % and repmat(ax_z,Ratio,1) produce vectors of the correct length.
        ax0_z = repmat(ax0_z, Nvaract, 1);
    end

    % variance hyper-parameter initialization
    % NOTE(review): 'OFF' is not defined in this file; presumably a
    % VBMEG helper returning the "off" flag value -- confirm.
    if vb_parm.cont_pr == OFF | nw == 1
        % use the prior variance as the initial expected value
        ax_z = ax0_z;
        sx   = sx0;
    else
        % use the estimated variance in the previous time window as the
        % initial expected value; reset entries that fell below a_min
        ix_zero = find(ax_z < a_min);  % indices taking small variance 
        ax_z(ix_zero) = ax0_z(ix_zero);              
    end
       
    %%%%%% Estimation Loop %%%%%% 
    for k=1:Ntrain 
        %%% VB E-step -- current estimation
        
        % Initialize averaging variable
        tr1_b   = 0;                %  (B - G*J)^2
        trBB    = 0;                %  sum over sessions of trace(B*B')
        mean_zz = zeros(Nvaract,1); % diag(  Z*Z )
        gain_z  = zeros(Nvaract,1); % Estimation Gain for Z
        Hj      = 0;                % log(det(SB))
        
        for ns=1:Nsession
            Ntry  = Ntrials(ns);
            
            % Lead field for each session
            G  = Gact{ns};
            Gt = G';
        
            % MEG covariance for each session
            BB    = BBall{ns,nw};    % (B * B')
            trBB  = trBB + trace(BB);
            
            % Noise covariance for each session
            Cov   = COV{ns};   % Sigma_G^{-1}   
            
            % Variance for every lead-field column (orientations share
            % one variance when Ratio > 1)
            if Ratio == 1
                A_z = ax_z;
            else          
                A_z = repmat(ax_z,Ratio,1);
            end

            % Inverse filter for each session
            % GSG   = G*(A_z.*Gt) ; % Ga * alpha *Ga'
            GSG     = G * vb_repmultiply(Gt, A_z) ; 
            SB      = GSG + Cov;        % Sigma_B
            if flag_pinv, SB_inv  = pinv( SB );       % Sigma_B^{-1}
            else SB_inv  = inv(SB); end
            GSB     = Gt*SB_inv;        % Ga' * Sigma_B^{-1}
            SBBS    = SB_inv*BB*SB_inv; % Sigma_B^{-1}*BB*Sigma_B^{-1} 
            
            % Reconstruction error = sum( (B-G*J)'*Cov*(B-G*J) )
            tr1_b = tr1_b + sum(sum(Cov.*SBBS));
            
            % Current magnitude = diag( sum( Z * Z ) )
            mean_zz = mean_zz + ax_z ...
                   .* sum(reshape(sum((GSB*BB).*GSB,2),[Nvaract Ratio]),2);
            % Estimation gain for Z = diag(ax_z*Gt*SB_inv*G)
            gain_z  = gain_z + Nt*Ntry*ax_z ...
                   .* sum(reshape(sum(Gt.*GSB, 2),[Nvaract Ratio]), 2);
            
            Hj  = Hj - 0.5*Nt*Ntry*vb_log_det(SB);
        end
        
        % Noise variance estimation
        sx_total = (tr1_b + sum(mean_zz));
        if update_sx
          sx  = sx_total/(Nt*Nerror);
        end
      
        %%%%% VB M-step -- variance update 
        
        % Save old value for Free energy calculation
        ax_z_old = ax_z;
        
        % --- Active current variance estimation
        %
        % VB update equation
        %                    gx_z * ax_z = ...
        % (0.5*Ntt*Ratio + gx0_z) * ax_z = ...
        %     mean_zz_all + 0.5*ax_z.*( Ntt * Ratio - gain_z );
        
        mean_zz_all = 0.5*ax_z.*mean_zz/sx + gx0_z.*ax0_z;
      
        % Update the current variance parameters
        if k <= Npre_train
            % --- VB update rule
            ax_z_total = mean_zz_all + 0.5*ax_z.*(Ntt*Ratio - gain_z);
            ax_z   = ax_z_total ./ gx_z;  
        else
            % --- Faster convergence update rule
            % (gain_z is floored at eps to avoid division by zero)
            gain_z = max(gain_z, eps); 
            ax_z = mean_zz_all ./ (0.5*gain_z + gx0_z);  
        end
        
        %%%%%%%% Free Energy Calculation %%%%%%%%
        % Model complexity for current variance
        rz  = ax0_z./ax_z_old;
        Haz = sum( gx0_z.*(log(rz) - rz + 1));
        
        LP = - 0.5*Nt*Nerror*log(sx);

        % Data Likelihood & Free Energy 
        Ea  = - 0.5*( sx_total/sx - Nt*Nerror );
        FE  = LP + Hj + Ea + Haz ;
        Ev  = LP + Hj;  % evidence (without constants)
        Err = tr1_b/trBB; % Normalized Error (all sessions)
                   
        % Convergence history (cond(SB) is from the last session)
        k_check = k_check + 1;
        Info(k_check,:) = [...
            FE, Ev, Hj, LP, Err, cond(SB)];
        % end of free energy calculation
                   
        if mod(k, Nskip)==0 || k==Ntrain 
            fprintf('Tn=%3d, Iter=%4d, FE=%f, Error=%e\n', ...
                    nw, k, FE, Err);
        end
        
    end % end of learning loop 

    aall(:,nw) = ax_z;
    sxall(nw)  = sx;
end % end of time window loop

Model.a   = aall ; % Current variance (relative to sx)
Model.sx  = sxall; % Sensor noise variance
Model.v   = zeros(Nwindow,1); % Background variance is not estimated

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% functions specific for this function
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function FEconst = free_energy_constant(gx_j, gx0_j, gx_z, gx0_z,...
    Nt, Ntotaltrials, Nerror, Njall)
% FREE_ENERGY_CONSTANT  Constant part of the free energy.
%   Combines qG-based likelihood terms (background j, focal z, noise sx)
%   with HG0-based terms for the same three shape parameters.
%   The noise shape parameter is 0.5*Nt*Nerror and its prior shape is 0.
%   NOTE(review): not called from the main function in this file.

% Shape parameter of the noise variance
gamma_sx = 0.5*Nt*Nerror;

% Likelihood-side constant
term_j  = Njall*Ntotaltrials*qG(gx_j);
term_z  = Ntotaltrials*sum(qG(gx_z));
term_sx = Nerror*qG(gamma_sx);
L0 = 0.5*Nt*(term_j + term_z + term_sx - Nerror);

% Prior/entropy-side constant
ent_sx = HG0(gamma_sx,0);
ent_j  = HG0(gx_j,gx0_j);
ent_z  = sum(HG0(gx_z, gx0_z));
H0 = ent_sx + ent_j + ent_z;

FEconst = L0 + H0;

%%%%%
function y = qG(gx)
% qG  Element-wise psi(gx) - log(gx) for the non-zero entries of gx.
%   Entries where gx == 0 are left at zero.
%   The output is a column vector with length(gx) rows.
y   = zeros(length(gx),1);
idx = find(gx ~= 0);     % skip zero entries (psi/log undefined there)
g   = gx(idx);
y(idx) = psi(g) - log(g);

%%%%%%
function y =HG0(gx, gx0)
% HG0  Gamma-shape-parameter term used by free_energy_constant.
%   y = [gammaln(gx) - gx.*psi(gx) + gx]
%     - [gammaln(gx0) - gx0.*log(gx0) + gx0]
%     + gx0.*(psi(gx) - log(gx))
%   The first bracket is evaluated only where gx ~= 0.
%   The second bracket and the last term are evaluated only where gx0 ~= 0.
%   Entries that are not evaluated contribute zero.
%   The output is a column vector with length(gx0) rows.
n    = length(gx0);
idx  = find(gx ~= 0);    % non-zero entries of gx
idx0 = find(gx0 ~= 0);   % non-zero entries of gx0

% term depending on gx only
g  = gx(idx);
hq = zeros(n,1);
hq(idx) = gammaln(g) - g.*psi(g) + g;

% term depending on gx0 only
g0 = gx0(idx0);
hp = zeros(n,1);
hp(idx0) = gammaln(g0) - g0.*log(g0) + g0;

% mixed gx / gx0 term
gm  = gx(idx0);
hqp = zeros(n,1);
hqp(idx0) = g0.*(psi(gm) - log(gm));

y = hq - hp + hqp;




