function [Model,Info] = ...
         vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm, Modelinit)
% Algorithm for jointly estimating source activities and interactions. 
% (2015/03/09 M.Fukushima: computation faster than vbmeg_ard_estimate25_2)
% -- Syntax
% [Model,Info] = vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm, Modelinit)
%
% -- Input (as used in this file)
%  B         : cell array over sessions; B{ns}(:,t,i) is the sensor data at
%              time sample t of trial i.
%  Gall      : not referenced by this implementation (kept for interface
%              compatibility).
%  Gact      : cell array over sessions; Gact{ns} is the lead field of the
%              active current vertices. An error is raised when empty.
%  COV       : cell array over sessions; COV{ns} is the noise covariance
%              used to weight the reconstruction error.
%  vb_parm   : parameter struct. Fields read here: Ntrain, Nskip, a_min,
%              a_max, Ntrials, Nsensors, Nsession, Nwindow, Twindow, Nvact,
%              Njact, Njall, Ind, Delta, Ta0, sx0, a0, a, cont_pr, ieta0,
%              g0 and (optional) update_sx.
%  Modelinit : (optional) struct; Modelinit.a replaces the prior variance
%              a0 (clipped below at a_min).
%
% -- Output
%  Model : struct with fields a, sx, v, Z, Vz, vz_Nsample, mar, MAR, MAR0,
%          Vmar and ieta (see "Save estimates" at the end).
%  Info  : (Nwindow*Ntrain) x 8 history; one row per update iteration,
%          [FE, Hj(1), Hj(2), Hbc, Hmar, Hieta, Haz, Err].
%
% 2014/09/05 M.Fukushima
% * The original code is vbmeg_ard_estimate4.m
% * Explicitly compute the currents J
% * Change the way of calculation of tr1_b, mean_zz, gain_z, and Hj
% * Do not support multiple directions dipoles
% * Do not support multiple sessions data
% * Do not support multiple time windows
% * Do not support multiple trials data
% * Do not support extra dipoles
% * Do not support fast convergence rule
% * Compute the exact free energy at each iteration

% flag_pinv is forwarded to the Kalman-gain helpers; default to false when
% the caller has not set the global.
global flag_pinv;
if isempty(flag_pinv), flag_pinv = false; end

Ntrain = vb_parm.Ntrain;
Nskip  = vb_parm.Nskip;
a_min  = vb_parm.a_min;
a_max  = vb_parm.a_max;

fprintf('\n--- New Dynamic estimation program ---\n\n')
fprintf('--- Total update iteration      = %d\n',Ntrain)

% Number parameters
Ntrials     = vb_parm.Ntrials;    % Number of trials
Nsensors    = vb_parm.Nsensors;   % Number of sensors
Nsession    = vb_parm.Nsession;   % Number of sessions
Nwindow     = vb_parm.Nwindow;    % Number of time windows
Twindow     = vb_parm.Twindow;    % Time window index [Tstart, Tend]
Ntotaltrials = sum(Ntrials);      % Total number of trials in all sessions
Nerror  = sum(Ntrials.*Nsensors); % Total MEG channel used for error estimate

Nvact       = vb_parm.Nvact;      % Number of active current vertices
Njact       = vb_parm.Njact;      % Number of active dipoles parameters
Njall       = vb_parm.Njall;      % Number of whole brain vertices (dipole)
                                  %        for background activity

% Load structural connectivity matrix
Ind = vb_parm.Ind;
Delta = vb_parm.Delta;

% Set prior parameters
Ta0 = vb_parm.Ta0 ;            % confidence parameter of focal
sx0 = vb_parm.sx0;             % observation noise variance
a0  = max(vb_parm.a0, a_min ); % initial variance of focal

% Prior variance for each time window (column nw is used for window nw)
if exist('Modelinit','var') && ~isempty(Modelinit)
    a0s = max(Modelinit.a, a_min );
else
    a0s = repmat(a0, [1 Nwindow]);
end

if isempty(Gact) 
	error('No leadfield is given')
end

% Flag controlling updates of variance parameters
update_sx = 1;
if isfield(vb_parm, 'update_sx')
   update_sx = vb_parm.update_sx;
end

%%% Slice the data per (session, time window) and compute
%%% BB = sum over trials of Bt*Bt'
for ns = 1 : Nsession
    for nw = 1 : Nwindow
        B0 = 0;
        Btr = B{ns}(:,Twindow(nw,1):Twindow(nw,2),:);
        for i = 1 : Ntrials(ns)
            Bt = B{ns}(:,Twindow(nw,1):Twindow(nw,2),i);
            B0 = B0 + Bt * Bt';
        end
        Ball{ns,nw} = Btr;
        BBall{ns,nw} = B0;
    end
end

% clear B Gall

Jall  = cell(Nsession, Nwindow);
J     = cell(Nsession,1);
J0    = cell(Nsession,1);
Vj     = cell(Nsession,1);
VjT    = cell(Nsession,1);
VjT_full = cell(Nsession,1);

% Temporal variable for variance (ratio)
aall  = zeros(Nvact,Nwindow);
sxall = zeros(Nwindow,1);

% Free energy and Error history
Info    = zeros(Nwindow*Ntrain, 8);
k_check = 0;

tic;
%%%%% Time window loop %%%%%%%%
for nw=1:Nwindow
    
    Nt  = Twindow(nw,2) - Twindow(nw,1) + 1;
    Ntt = Nt*Ntotaltrials; 

    % confidence parameters
    gx0_z = Ta0; 
    gx_z  = 0.5 * Ntt + gx0_z; 
    
    % initial variance parameters
    ax0_z = a0s(:,nw); 

    % variance hyper-parameter initialization
    if vb_parm.cont_pr == OFF | nw == 1
        % use the prior variance as the initial expected value
        ax_z = vb_parm.a;
        sx   = sx0;
    else
        % use the estimated variance in the previous time window as the
        % initial expected value
        ix_zero = find(ax_z < a_min);  % indices taking small variance 
        ax_z(ix_zero) = ax0_z(ix_zero);              
    end
    
    % For computing AtQA and AdtQA_V in forward_pass
    [Delta_set, Ndelta, indx, deltax, sum_ind, nl_indx, NL_indx] = ...
      init_index(Delta);
    AtQA0      = init_AtQA(Nvact, Ndelta, NL_indx);
    AdtQA_V0   = init_AdtQA_V(Nvact, Ndelta, NL_indx);
    sp_indx    = get_AtQA_index(Nvact, Ndelta, NL_indx, AtQA0);
    sp_dl_indx = get_AdtQA_V_index(Nvact, Ndelta, NL_indx, AdtQA_V0);
    
    % Initialization for estimating Q(A)
    mar = cell(Nvact,1);
    MAR = 0*speye(Nvact);
    % MAR0 is overwritten at the first update (k == 1) below. Initializing
    % it here keeps Model.MAR0 defined when Ntrain == 0 (no update is run).
    MAR0 = MAR;
    Vmar = cell(Nvact,1);
    for nv = 1:Nvact
      mar{nv}  = 0 * ones(sum_ind(nv),1);
      Vmar{nv} = 0 * eye(sum_ind(nv));
    end
    
    % Initialization for estimating Q(\eta_{1:N})
    ieta0 = cell(Nvact,1);
    gx0_e = cell(Nvact,1);
    gx_e  = cell(Nvact,1);
    for nv = 1:Nvact
      ieta0{nv} = vb_parm.ieta0 * ones(sum_ind(nv),1);
      gx0_e{nv} = vb_parm.g0 * ones(sum_ind(nv),1);
      gx_e{nv} = gx0_e{nv} + 0.5;
    end
    ieta = ieta0;
    
    % Labeling time stamps for sumVj1 calculation
    % vfreq(nl): number of time samples whose covariance is Vj{nl}
    % (also stored as Model.vz_Nsample together with the Nt remainder)
    vfreq = fliplr(-diff([(Nt-Delta_set)' 0]));
    cum_vfreq = [0 cumsum(vfreq)];
    tsp = repmat(1:Nt,size(Delta_set)) - repmat(Delta_set,[1 Nt]);
    vlb = zeros(size(tsp));
    vct = [];
    for nl = 1:Ndelta
      vlb = vlb + nl*(cum_vfreq(nl)<tsp).*(tsp<=cum_vfreq(nl+1));
    end
    for nl = 1:Ndelta
      vct = [vct sum(vlb==nl,2)];
    end
    
    %%%%%% Estimation Loop %%%%%% 
    % Iteration Ntrain+1 only runs the E-step and the free energy
    % evaluation; the parameter updates are skipped.
    for k=1:(Ntrain+1)
        %%% VB E-step -- current estimation

        % Initialize averaging variable
        tr1_b = 0;
        Hj = [0 0];
        
        for ns=1:Nsession
            Ntry  = Ntrials(ns);
            Nsensor = Nsensors(ns);
            
            % Lead field for each session
            G  = Gact{ns};
            Gt = G';
        
            % MEG covariance for each session
            B     = Ball{ns,nw};
            BB    = BBall{ns,nw};
            
            % Noise covariance for each session
            Cov   = COV{ns};
            
            % Initial current: the last Delta_set(end) samples preceding
            % this window (zeros for the first window)
            if nw == 1
              J0{ns} = zeros(Nvact, Delta_set(end), Ntry);
            else
              J0{ns} = Jall{ns,nw-1}(:,Nt-Delta_set(end)+1:Nt,:); 
            end
            
            if k == 1, prevJ = zeros(Nvact, Nt, Ntry);
            else prevJ = J{ns}; end
            
            % Estimate current
            [J{ns}, Vj{ns}, VjT{ns}, GJ, Qj, VjT_full{ns}] = current_estimation ...
              (B, G, Gt, ax_z, Cov, J0{ns}, prevJ, mar, Vmar, ...
              Nvact, Nsensor, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
              nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx);
            
            % Reconstruction error, weighted by inv(Cov)
            BGJ = reshape(B-GJ, Nsensor, Nt*Ntry);
            tr1_b = tr1_b + sum(repdiag( BGJ, inv(Cov)*BGJ));
                   
            Hj = Hj + 0.5*Qj;
        end;
        
        %%%%% VB M-step -- parameter update
        % NOTE: the statistics below use session 1 only (multiple sessions
        % are not supported; see header).
        %% Sufficient statistics for Q(\beta) calculation
        % Initialize
        sumJ11 = zeros(Nvact,1);
        sumVj1 = zeros(Nvact,1);
        sumJ12 = zeros(Nvact,1);
        MARJ1  = zeros(Nvact,Nt*Ntrials(1));
        
        J22 = reshape(J{1}.^2, Nvact, Nt*Ntrials(1));
        sumJ22 = sum(J22,2);
        sumVj2 = Ntrials(1)*(Nt-sum(vfreq))*VjT{1};
        for nl = 1:Ndelta
          sumVj2 = sumVj2 + Ntrials(1)*vfreq(nl)*diag(Vj{1}{nl});
        end
        
        % Target currents
        J2 = reshape(J{1}, Nvact, Nt*Ntrials(1));
        
        % Concatenate [J0 J] per trial; t_all holds the column of each
        % window sample inside the concatenated matrix
        Jtmp = [];
        t_all = [];
        for ntr = 1:Ntrials(1)
          % Concatenate J
          Jtmp = [Jtmp J0{1}(:,:,ntr) J{1}(:,:,ntr)];
          t_all = [t_all (1:Nt)+Delta_set(end) + (Nt+Delta_set(end))*(ntr-1)];
        end
        
        J1   = cell(Nvact,1);
        J1J1 = cell(Nvact,1);
        Vj1  = cell(Nvact,1);
        if k==1; mar_Vmar = cell(Nvact,1); end
        for nv = 1:Nvact
          % Seed Currents (lagged inputs of vertex nv)
          J1{nv} = select_entries(Jtmp, repmat(indx{nv},[1 Nt*Ntrials(1)]), ...
            repmat(t_all,size(deltax{nv}))-repmat(deltax{nv},[1 Nt*Ntrials(1)]));
          J1J1{nv} = J1{nv}*J1{nv}';
          
          Nind_nv = length(deltax{nv});
          Vj1{nv} = zeros(Nind_nv); % Must not be sparse (must be full)
          for nld = 1:Ndelta
            indlb = (deltax{nv}==Delta_set(nld));
            sum_indlb = sum(indlb);
            if sum_indlb~=0
              Vjtmp = zeros(sum_indlb);
              nl = 1;
              while nl <= Ndelta-nld+1
                Vjtmp = Vjtmp + ...
                  vct(nld,nl)*Vj{1}{nl}(indx{nv}(indlb),indx{nv}(indlb));
                nl = nl + 1;
              end
              Vj1{nv}(indlb,indlb) = Vj1{nv}(indlb,indlb) + Ntrials(1)*Vjtmp;
            end
          end
          
          if k==1; mar_Vmar{nv} = mar{nv}*mar{nv}' + Vmar{nv}; end
          MARJ1(nv,:) = mar{nv}'*J1{nv};
          sumJ11(nv) = sumJ11(nv) + sum(repdiag(mar_Vmar{nv},J1J1{nv}));
          sumVj1(nv) = sumVj1(nv) + Ntrials(1)*sum(repdiag(mar_Vmar{nv},Vj1{nv}));
        end
        sumJ12 = sumJ12 + repdiag(J2',(MARJ1)');
        
        mean_zz = sumJ22 - 2*sumJ12 + sumJ11;
        mean_Vj = sumVj2 + sumVj1;
        
        % Noise variance estimation
        sx_total = tr1_b + sum(mean_zz./ax_z);
        if update_sx, 
          sx  = sx_total/(Nt*Nerror);
        end
        
        Hj(1) = Hj(1) - 0.5*tr1_b/sx;
        Hj(2) = Hj(2) - 0.5*sum(mean_zz./ax_z)/sx - 0.5*sum(mean_Vj./ax_z);
        
        if k==(Ntrain+1)
          %% Skip the rest of updates
          % Save old value for Free energy calculation
          mar_old = mar;
          Vmar_old = Vmar;
          ieta_old = ieta;
          ax_z_old = ax_z;
        else
          %% Q(A) calculation
          % Save old value for Free energy calculation
          mar_old = mar;
          Vmar_old = Vmar;

          for nv = 1:Nvact
            ax_z_invVmar = (J1J1{nv}/sx + Vj1{nv}) + ax_z(nv)*diag(1./ieta{nv});
            J1J2_beta = (J1{nv}*J2(nv,:)'/sx);

            inv_ax_z_invVmar = inv(ax_z_invVmar);
            Vmar{nv} = ax_z(nv)*inv_ax_z_invVmar;
            mar{nv} = inv_ax_z_invVmar*J1J2_beta;
            MAR(nv,indx{nv}) = mar{nv}';
          end

          if k == 1; MAR0 = MAR; end

          %% Q(\eta_{1:N}) calculation
          % Save old value for Free energy calculation
          ieta_old = ieta;
          
          for nv = 1:Nvact
            ieta{nv} = (gx0_e{nv}.*ieta0{nv} + ...
              0.5*(mar{nv}.^2 + diag(Vmar{nv})))./gx_e{nv};
          end

          %% Sufficient statistics for Q(q) calculation
          % Recomputed with the updated mar/Vmar
          % Initialize
          sumJ11 = zeros(Nvact,1);
          sumVj1 = zeros(Nvact,1);
          sumJ12 = zeros(Nvact,1);
          MARJ1  = zeros(Nvact,Nt*Ntrials(1));
          
          mar_Vmar = cell(Nvact,1);
          for nv = 1:Nvact
            mar_Vmar{nv} = mar{nv}*mar{nv}' + Vmar{nv};
            MARJ1(nv,:) = mar{nv}'*J1{nv};
            sumJ11(nv) = sumJ11(nv) + sum(repdiag(mar_Vmar{nv},J1J1{nv}));
            sumVj1(nv) = sumVj1(nv) + Ntrials(1)*sum(repdiag(mar_Vmar{nv},Vj1{nv}));
          end
          sumJ12 = sumJ12 + repdiag(J2',(MARJ1)');

          mean_zz = sumJ22 - 2*sumJ12 + sumJ11;
          mean_Vj = sumVj2 + sumVj1;

          % Save old value for Free energy calculation
          ax_z_old = ax_z;

          % Current noise variance estimation
          ax_z = (gx0_z.*ax0_z + 0.5*(mean_zz/sx + mean_Vj))./gx_z;
        end
        
        %%%%%%%% Free Energy Calculation %%%%%%%%
        % Evaluated with the parameter values used in this iteration's
        % E-step (the *_old copies saved above).
        rz  = ax0_z./ax_z_old;
        Haz = sum(gx0_z.*(log(rz) - rz + 1));
        
        Hmar = 0;
        Hieta = 0;
        mar_vmar_ieta_old = zeros(Nvact,1);
        for nv = 1:Nvact
          if isempty(indx{nv}) == 0
            Hmar = Hmar - sum(log(ieta_old{nv})) + vb_log_det(Vmar_old{nv});
            mar_vmar_ieta_old(nv) = ...
              sum((mar_old{nv}.^2+diag(Vmar_old{nv}))./ieta_old{nv});
          end
          re = ieta0{nv}./ieta_old{nv};
          Hieta = Hieta + sum(gx0_e{nv}.*(log(re) - re + 1));
        end
        Hmar = 0.5*(Hmar-sum(mar_vmar_ieta_old));
        
        Hj(1) = Hj(1) - 0.5*Nt*Nerror*log(sx);

        [L1c L2c Hbc HAc Hec Hqc difbeta] = free_energy_constant...
          (gx_z, gx0_z, gx_e, gx0_e, Nt, Ntotaltrials, Nerror, Nvact, Ind);
        
        % Log-likelihood & Penalty terms 
        Hj(1) = Hj(1) + L1c;
        Hj(2) = Hj(2) + L2c + difbeta;
        Hmar = Hmar + HAc;
        Hieta = Hieta + Hec;
        Haz = Haz + Hqc;
        
        % Free energy
        FE  = sum(Hj) + Hbc + Hmar + Hieta + Haz;
        
        % Normalized Error (BB is the data term of the last session
        % processed in the E-step loop)
        Err = tr1_b/sum(sum(BB));
        
        if k > 1
          % DEBUG Info
          k_check = k_check + 1;
          Info(k_check,:) = [FE, Hj(1), Hj(2), Hbc, Hmar, Hieta, Haz, Err];

          if mod(k-1, Nskip) == 0
            fprintf('Tn=%3d, Iter=%4d, FE=%f, Error=%e\n', nw, k-1, FE, Err);
          end
        end
    end % end of learning loop

    for ns = 1:Nsession
      Jall{ns,nw} = J{ns};
    end
    aall(:,nw) = ax_z;
    sxall(nw)  = sx;
end % end of time window loop
toc;

% Save estimates
Model.a   = aall ; % Current noise variance 
Model.sx  = sxall; % Observation noise variance
Model.v   = zeros(Nwindow,1);
Model.Z   = Jall; % current sources
Model.Vz  = Vj; % source covariances
clear Vj
for ns = 1:Nsession
  Model.Vz{ns}{end+1}  = VjT_full{ns}; % source covariances near the end
end
% # of time samples for each source covariance
Model.vz_Nsample = [vfreq (Nt-sum(vfreq))];
Model.mar  = mar;  % MAR coefficients
Model.MAR  = MAR;  % MAR matrix
Model.MAR0 = MAR0; % MAR matrix at the 1st iteration
for nv = 1:Nvact
  Model.Vmar{nv} = Vmar{nv}; % Covariances of MAR coefficients
end
Model.ieta = ieta; % Variance parameter of MAR coefficients
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% functions specific for this function
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [J, Vj, VjT, GJ, Qj, VjT_full] = current_estimation ...
  (B, G, Gt, ax_z, Cov, J0, prevJ, mar, Vmar, ...
    Nvact, Nsensor, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
    nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx)
% Estimate currents for one session and predict the sensor data from them.
%
% All arguments except Nsensor are passed unchanged to forward_pass, which
% returns J, Vj, VjT, Qj and VjT_full. The estimated currents
% (Nvact x Nt x Ntry) are then mapped through the lead field G.
%
% Extra output
%   GJ : G*J for every time sample and trial, shaped [Nsensor, Nt, Ntry]

[J, Vj, VjT, Qj, VjT_full] = forward_pass(B, G, Gt, ax_z, Cov, J0, prevJ, ...
  mar, Vmar, Nvact, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
  nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx);

% Merge the time and trial axes, project to sensor space, then split again
Jcols = reshape(J, Nvact, Nt*Ntry);
GJ    = reshape(G*Jcols, Nsensor, Nt, Ntry);

%%%%%
function [Jfilt, Vj, VjT, Qj, VjT_full] = forward_pass ...
  (B, G, Gt, ax_z, Cov, J0, prevJ, mar, Vmar, ...
    Nvact, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
    nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx)
% Sequentially estimate the currents and their covariances for one session.
%
% For every trial, the currents are estimated for t = 1..Nt in order. For
% t < Nt, the prediction Jpred is built from two sources:
%   * past currents through the MAR coefficients (atimesj), and
%   * currents at the later lags t + Delta_set(nld) (taken from prevJ,
%     the previous iteration's estimate), through AtQ and AdtQA_V.
% The prediction is then corrected with the data,
% Jfilt = Jpred + Ka*(B - G*Jpred).
% At t = Nt only the MAR prediction is used, together with the gain KaT.
%
% Inputs (as used here)
%   B          : sensor data, indexed B(:,t,trial)
%   G, Gt      : lead field and its transpose
%   ax_z       : current variance, one entry per active vertex
%   Cov        : noise covariance
%   J0         : currents preceding the window, J0(:,:,trial). The caller
%                in this file gives it Delta_set(end) columns, so column
%                t + Delta_set(end) of the concatenated Jtmp is time t.
%   prevJ      : current estimate from the previous iteration
%   mar, Vmar  : per-vertex MAR coefficient means / covariances
%   flag_pinv  : forwarded to calc_kalman_gain / calc_kalman_gain2
%   indx, deltax, Delta_set, sum_ind, nl_indx, NL_indx, AtQA0, AdtQA_V0,
%   sp_indx, sp_dl_indx : index structures prepared by init_index,
%                init_AtQA, init_AdtQA_V and get_*_index
%
% Outputs
%   Jfilt      : estimated currents, Nvact x Nt x Ntry
%   Vj         : Vj{Ndelta-Nbound+1} is the current covariance used while
%                Nbound lags t+Delta_set stay inside the window
%   VjT_full   : current covariance at t = Nt; VjT is its diagonal
%   Qj         : 1x2 free-energy terms. Qj(1) holds the log det(Cov) and
%                trace(G*V*G'*inv(Cov)) terms; Qj(2) holds the log(ax_z) and
%                log det(V) terms.
%
% NOTE(review): the Nbound logic below sums AtQA{1:Nbound}, which assumes
% Delta_set is sorted in ascending order. TODO confirm in init_index.
 
% Initialize
Jfilt = zeros(Nvact, Nt, Ntry);
Ndelta = length(Delta_set);
AtQ = cell(Ndelta,1);
A = cell(Ndelta,1);
Vj = cell(Ndelta,1);
for nl = 1:Ndelta
  AtQ{nl} = spalloc(Nvact, Nvact, Nvact);
  A{nl} = spalloc(Nvact, Nvact, Nvact);
  Vj{nl} = zeros(Nvact);
end

% Compute products of model parameters
% A{nl} is the sparse MAR matrix for lag index nl, filled row by row from
% mar{nv}. AtQ{nl} = (diag(1./ax_z)*A{nl})'.
for nl = 1:Ndelta
  for nv = 1:Nvact
    if ~isempty(nl_indx{nv,nl})
      A{nl}(nv,NL_indx{nv,nl}) = mar{nv}(nl_indx{nv,nl});
    end
  end
  AtQ{nl} = (sparse(diag(1./ax_z))*A{nl})';
end
AtQA = set_AtQA(Nvact, Ndelta, ax_z, mar, Vmar, ...
		nl_indx, NL_indx, sp_indx, AtQA0);
AdtQA_V = set_AdtQA_V(Nvact, Ndelta, ax_z, mar, Vmar, ...
		nl_indx, NL_indx, sp_dl_indx, AdtQA_V0);

% Estimate current covariances
% Covariances and gains depend on t only through Nbound, the number of
% lags with t + Delta_set <= Nt. They are therefore computed once per
% distinct Nbound and stored at index Ndelta-Nbound+1.
P  = cell(Ndelta+1,1);
Ka = cell(Ndelta+1,1);
Nbound = Inf;
for t = 1:Nt-1
  if Nbound ~= sum((t + Delta_set) <= Nt) % If Vj has to be updated
    Nbound = sum((t + Delta_set) <= Nt);

    sumAtQA = spalloc(Nvact, Nvact, Nvact);
    for nl = 1:Nbound
      sumAtQA = sumAtQA + AtQA{nl};
    end
    QplusAtQA = diag(1./ax_z) + sumAtQA;
    P{Ndelta-Nbound+1} = inv(QplusAtQA);
    Ka{Ndelta-Nbound+1} = ...
      calc_kalman_gain2(G, Gt, P{Ndelta-Nbound+1}, Cov, flag_pinv);

    if Nbound ~= 0
      % Current covariances: V = P - Ka*G*P
      Vj{Ndelta-Nbound+1} = ...
        P{Ndelta-Nbound+1} - Ka{Ndelta-Nbound+1}*(G*P{Ndelta-Nbound+1});
    end
  end
end
% At t = Nt no later lag exists, so the prior covariance is diag(ax_z)
% (PT is kept as a vector and passed to calc_kalman_gain)
PT = ax_z;
KaT = calc_kalman_gain(G, Gt, PT, Cov, flag_pinv);
% Current covariance for Nt
VjT_full = diag(PT) - KaT*(G*diag(PT));
VjT = diag(VjT_full);

% Store (indx, deltax, mar) to (Indx, Deltax, MARt)
% Zero-padded matrices (one row per vertex) consumed by atimesj
Indx   = zeros(Nvact, max(sum_ind));
Deltax = zeros(Nvact, max(sum_ind));
MARt   = zeros(Nvact, max(sum_ind));
for nv = 1:Nvact
  Indx(nv,1:size(indx{nv},1))     = indx{nv}';
  Deltax(nv,1:size(deltax{nv},1)) = deltax{nv}';
  MARt(nv,1:size(mar{nv},1))     = mar{nv}';
end

% Estimate currents
for ntr = 1:Ntry
  Nbound = Inf;
  for t = 1:Nt-1
    % Column of time t inside Jtmp (J0 occupies the first Delta_set(end))
    t_end = t + Delta_set(end);
    
    if Nbound ~= sum((t + Delta_set) <= Nt) % If Vj has to be updated
      Nbound = sum((t + Delta_set) <= Nt);
    end
    
    % Concatenate J: [preceding samples, already-filtered 1..t,
    % previous-iteration estimate for t+1..Nt]
    Jtmp = [J0(:,:,ntr) Jfilt(:,1:t,ntr) prevJ(:,t+1:end,ntr)];
    % Forward model prediction
    AJfilt = ...
      atimesj(MARt,Indx,Deltax,sum_ind,Jtmp(:,t_end:-1:(t_end-Delta_set(end))));
    
    % Sum the products of model parameters
    % (contributions of the currents at t + Delta_set(nld) that lie inside
    % the window)
    saqj = zeros(Nvact,1);
    for nld = 1:Nbound
      saj = zeros(Nvact,1);
      adqsa_v_j = zeros(Nvact,1);
      for nl = 1:Ndelta
        if nl ~= nld
          t_tmp = t_end + Delta_set(nld) - Delta_set(nl);
          saj = saj + A{nl}*Jtmp(:,t_tmp);
          adqsa_v_j = adqsa_v_j + AdtQA_V{nld,nl}*Jtmp(:,t_tmp);
        end
      end
      saqj = saqj + AtQ{nld}* ...
        (Jtmp(:,t_end + Delta_set(nld)) - saj) - adqsa_v_j;
    end
    
    % Estimate currents
    Jpred = P{Ndelta-Nbound+1}*(saqj + spdiags(1./ax_z,0,Nvact,Nvact)*AJfilt);
    Jfilt(:,t,ntr) = Jpred + Ka{Ndelta-Nbound+1}*(B(:,t,ntr) - G*Jpred);
  end
  
  % t = Nt
  Nt_end = Nt + Delta_set(end);
  
  % Concatenate J
  Jtmp = [J0(:,:,ntr) Jfilt(:,:,ntr)];
  % Forward model prediction
  AJfilt = ...
    atimesj(MARt,Indx,Deltax,sum_ind,Jtmp(:,Nt_end:-1:(Nt_end-Delta_set(end))));
  
  % Estimate currents
  Jpred = AJfilt;
  Jfilt(:,Nt,ntr) = Jpred + KaT*(B(:,Nt,ntr) - G*Jpred);
end

% Computing sufficient statistics for free energy calculation
% vfreq(nl): number of time samples whose covariance is Vj{nl}; the
% remaining Nt-sum(vfreq) samples use VjT_full
vfreq = fliplr(-diff([(Nt-Delta_set)' 0]));
invCov = inv(Cov);

sumlogVj = 0;
sumTrGVjGtiC = 0;
for nl = 1:Ndelta
  sumlogVj = sumlogVj + vfreq(nl)/2*vb_log_det(Vj{nl});
  sumTrGVjGtiC = sumTrGVjGtiC + vfreq(nl)/2*sum(repdiag(G*Vj{nl}*Gt,invCov));
end

Qj(1) = 2*Ntry*(-Nt/2*vb_log_det(Cov)...
  - sumTrGVjGtiC...
  - (Nt-sum(vfreq))/2*sum(repdiag(G*VjT_full*Gt,invCov)));

Qj(2) = 2*Ntry*(...
  - Nt/2*sum(log(ax_z))...
  + sumlogVj...
  + (Nt-sum(vfreq))/2*vb_log_det(VjT_full));

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% functions for calculating constant terms in free energy
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [L1c L2c Hbc HAc Hec Hqc difbeta] = free_energy_constant...
  (gx_z, gx0_z, gx_e, gx0_e, Nt, Ntotaltrials, Nerror, Nvact, Ind)
% Data-independent terms of the free energy.
%
% Built from qG (psi(g)-log(g)) and HG0 evaluated on the Gamma shape
% parameters:
%   gx_z / gx0_z : shapes for the current variance (posterior / prior)
%   gx_e / gx0_e : cell arrays of shapes for the MAR-coefficient variances
%   0.5*Nt*Nerror: shape used for the observation noise
% Ind contributes 0.5*sum(Ind(:)) to HAc.

shape_sx = 0.5*Nt*Nerror;
q_sx     = qG(shape_sx);

% log-likelihood constants
L1c = 0.5*Nt*(Nerror*q_sx);
L1c = L1c + 0.5*(-Nt*Nerror*log(2*pi));

L2c = 0.5*Nt*(Ntotaltrials*sum(qG(gx_z)));
L2c = L2c + 0.5*(Nvact*Nt*Ntotaltrials);

difbeta = 0.5*Nt*Nvact*q_sx;

% entropy-type constants
Hbc = HG0(shape_sx,0);
Hqc = sum(HG0(gx_z, gx0_z));

HAc = 0;
Hec = 0;
iv  = 1;
while iv <= Nvact
  HAc = HAc + 0.5*sum(qG(gx_e{iv}));
  Hec = Hec + sum(HG0(gx_e{iv}, gx0_e{iv}));
  iv  = iv + 1;
end
HAc = HAc + 0.5*(sum(Ind(:)));

%%%%%
function y = qG(gx)
% Elementwise psi(gx) - log(gx), evaluated only where gx is non-zero.
% Zero entries give 0, so log(0) is never taken. The result is always a
% column vector with length(gx) rows.
y    = zeros(length(gx),1);
ipos = find(gx ~= 0);
g    = gx(ipos);
y(ipos) = psi(g) - log(g);

%%%%%%
function y =HG0(gx, gx0)
% Elementwise combination of Gamma-function terms for shapes gx and gx0:
%   y = [gammaln(gx) - gx.*psi(gx) + gx]                 (where gx  ~= 0)
%     - [gammaln(gx0) - gx0.*log(gx0) + gx0]             (where gx0 ~= 0)
%     + [gx0.*(psi(gx) - log(gx))]                       (where gx0 ~= 0)
% Entries whose shape is zero contribute 0 to the corresponding term.
% The output is a column vector with length(gx0) rows.
Nout  = length(gx0);
ipos  = find(gx  ~= 0);
ipos0 = find(gx0 ~= 0);

% term depending on gx only
term_g = zeros(Nout,1);
g = gx(ipos);
term_g(ipos) = gammaln(g) - g.*psi(g) + g;

% term depending on gx0 only
term_g0 = zeros(Nout,1);
g0 = gx0(ipos0);
term_g0(ipos0) = gammaln(g0) - g0.*log(g0) + g0;

% cross term gx0 * (psi(gx) - log(gx))
term_x = zeros(Nout,1);
gc = gx(ipos0);
term_x(ipos0) = g0.*(psi(gc)-log(gc));

y = term_g - term_g0 + term_x;
