vbmeg_ard_estimate4

PURPOSE

Variational Bayes estimation for MEG.

SYNOPSIS

function [Model,Info] = vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm, Modelinit)

DESCRIPTION

 Variational Bayes estimation for MEG.
 Estimate the current variance using the time series of all trials.
 Prune algorithm (the prune operation itself is omitted since 2008/06/20;
 see the history below).

 -- Syntax
 [Model,Info] = vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm)
 [Model,Info] = vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm, Modelinit)

 -- Input
 B{Nsession} : MEG magnetic field data matrix
                   size(B{i}) = [Nsensors(i), Tsample, Ntrials(i)]
 Gall{Nsession}: Lead field for whole brain region as background activity
                   size(Gall{i}) = [Nsensors(i) Ndipole]
                  Not used in this version, which supports only the focal
                  model (see 2008/06/20 below).
 Gact{Nsession} : Lead field (current basis function)
                  for the region with high estimation accuracy
                   size(Gact{i})  = [Nsensors(i) Ndipole]
                  Must be non-empty: the iso-mode (no focal dipoles) is no
                  longer supported, and Gact = [] raises an error.
 COV{Nsession}  : Noise covariance matrix
                   size(COV{i})   = [Nsensors(i) Nsensors(i)]
 vb_parm        : Struct of constant parameters (Ntrain, Nskip, a_min,
                  Ta0, sx0, a0, Twindow, Ntrials, Nsensors, ...)
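A minimal sketch of inputs with the shapes listed above, using random data and made-up sizes (two sessions with different sensor and trial counts). The vb_parm field names are the ones the source listing reads; every value is an illustrative assumption, and cont_pr = 0 assumes the VBMEG constant OFF evaluates to 0. The last loop reproduces the per-window accumulation of B*B' that the function performs before estimation and checks it against a single vectorized product.

```matlab
% Hypothetical toy sizes, chosen only for illustration
Nsession = 2;
Nsensors = [6 8];          % sensor channels per session
Ntrials  = [3 4];          % trials per session
Tsample  = 40;             % samples per trial
Nvact    = 5;              % active vertices
Norient  = 1;              % normal current direction only
Twindow  = [1 20; 21 40];  % time windows, one [Tstart Tend] per row
Nwindow  = size(Twindow, 1);

% Random data with the documented shapes
B = cell(1, Nsession); Gact = cell(1, Nsession); COV = cell(1, Nsession);
for ns = 1:Nsession
    B{ns}    = randn(Nsensors(ns), Tsample, Ntrials(ns));
    Gact{ns} = randn(Nsensors(ns), Nvact * Norient);
    COV{ns}  = eye(Nsensors(ns));   % assumed whitened (identity) noise covariance
end

% Fields read by the listing; all values are assumptions.
% cont_pr = 0 assumes the VBMEG constant OFF evaluates to 0.
vb_parm = struct('Nsession', Nsession, 'Nsensors', Nsensors, ...
    'Ntrials', Ntrials, 'Nwindow', Nwindow, 'Twindow', Twindow, ...
    'Nvact', Nvact, 'Norient', Norient, 'Norient_var', 1, ...
    'Njact', Nvact * Norient, 'Njall', 0, 'Ntrain', 100, 'Nskip', 10, ...
    'a_min', 1e-10, 'a_max', 1e10, 'Ta0', 10, 'sx0', 1, 'a0', 1, ...
    'cont_pr', 0);

% Per-window data covariance, accumulated over trials as in the listing
BBall = cell(Nsession, Nwindow);
for ns = 1:Nsession
    for nw = 1:Nwindow
        t  = Twindow(nw,1):Twindow(nw,2);
        B0 = 0;
        for i = 1:Ntrials(ns)
            Bt = B{ns}(:, t, i);
            B0 = B0 + Bt * Bt';
        end
        BBall{ns, nw} = B0;
        % Same result in one product: stack the trials along time
        Bw = reshape(B{ns}(:, t, :), Nsensors(ns), []);
        assert(norm(B0 - Bw * Bw', 'fro') <= 1e-10 * norm(B0, 'fro'));
    end
end
```

Stacking the trials along the time axis gives one [Nsensors(ns) x Nt*Ntrials(ns)] matrix, so the sum over trials of Bt*Bt' becomes a single matrix product.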
 
 -- Optional input
 Modelinit      : Initial Model parameters, if given.
      Only Modelinit.a is used; it must have the same size as the output
      Model.a. If Modelinit is not given, vb_parm.a0 is used as the initial
      current variance of every time window. Both are clamped below at
      vb_parm.a_min.
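The choice of initial variance can be sketched as follows, mirroring the listing: values below vb_parm.a_min are raised to a_min, and without Modelinit the prior variance vb_parm.a0 is copied to every time window. All numbers are made up for illustration.

```matlab
% Hypothetical values, for illustration only
vb_parm.a_min = 1e-10;
vb_parm.a0    = [1; 0];    % prior variance, one entry per variance parameter
Nwindow       = 3;

% Without Modelinit: clamp a0 and copy it to every time window
a0  = max(vb_parm.a0, vb_parm.a_min);
a0s = repmat(a0, [1 Nwindow]);           % [Nvaract x Nwindow]

% With Modelinit: clamp Modelinit.a elementwise instead
Modelinit.a = [0.5 1e-12 2; 0 3 4];      % [Nvaract x Nwindow]
a0s_init    = max(Modelinit.a, vb_parm.a_min);
```

Column nw of the result is then used as the prior variance ax0_z of time window nw.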
 
 -- Output
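 Model.a   : Current variance (1/alpha), relative to Model.sx;
             size [Nvact*Norient_var, Nwindow]
 Model.v   : Current variance for background activity
             (always zero in this version; size [Nwindow 1])
 Model.sx  : Observation noise variance; size [Nwindow 1]
 Info      : Convergence history, one row per iteration and time window
             (size [Nwindow*Ntrain, 6]); columns are
             [FE, Ev, Hj, LP, Err, cond(SB)]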
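In the listing, Info is preallocated as zeros(Nwindow*Ntrain, 6) and filled one row per iteration, window after window, with the columns [FE, Ev, Hj, LP, Err, cond(SB)]. A minimal sketch of extracting the free-energy and error traces of one time window, using a fake Info matrix so it runs without VBMEG:

```matlab
% Hypothetical sizes; column 1 (FE) holds the global row index for checking
Nwindow = 2; Ntrain = 4;
Info = zeros(Nwindow*Ntrain, 6);
Info(:,1) = (1:Nwindow*Ntrain)';

nw   = 2;                              % time window of interest
rows = (nw-1)*Ntrain + (1:Ntrain);     % rows written while processing window nw
FE   = Info(rows, 1);                  % free energy per iteration
Err  = Info(rows, 5);                  % normalized error per iteration
```

Plotting FE against 1:Ntrain for each window is a quick way to check that the free energy has levelled off.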
       
 -- Note
 Brain noise : Background brain activity is assumed to be
               independent Gaussian noise with the same variance.

 "Nvact" (number of vertices for active dipoles) and
 "Nt"  (length of the time window) are common to all sessions.

 "Nsensors" (number of sensor channels in a session) and
 "Ntrials"  (number of trials in a session) may differ between sessions.

 Note that ax0_j, ax0_z, and sx are defined as variances
 (the inverse of the inverse variance, i.e., they have the dimension of
 variance). These notations differ from those in Sato-san's note.
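The per-session counts enter the estimation only through a few pooled quantities. A short worked example with made-up sizes, using the same expressions as the listing: Nt*Nerror is the denominator of the noise-variance update, and gx_z is the posterior confidence used in the current-variance update.

```matlab
% Hypothetical two-session setup
Ntrials  = [10 20];      % trials per session
Nsensors = [100 150];    % channels per session
Twindow  = [1 50];       % one time window: samples 1..50
Ta0      = 10;           % prior confidence (vb_parm.Ta0)
Norient  = 1; Norient_var = 1;

Ntotaltrials = sum(Ntrials);                 % 30
Nerror = sum(Ntrials .* Nsensors);           % 10*100 + 20*150 = 4000
Nt     = Twindow(1,2) - Twindow(1,1) + 1;    % 50, shared by all sessions
Ntt    = Nt * Ntotaltrials;                  % 1500
Ratio  = Norient / Norient_var;              % 1
gx_z   = 0.5 * Ntt * Ratio + Ta0;            % 760
```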

 2004-12-5 M. Sato
 2005-04-12 O.Yamashita
 * The whole program is revised
   so that variable names are consistent with Sato-san's note.
 * The noise variance, sx, ax_z and ax_j are independently handled. 
 * A field 'fix_parm' is added to 'vb_parm' struct.
 * The input 'COV' must be given.
 2005-04-16 M. Sato
 * Modified free-energy calculation
 2005-05-07 M. Sato
 * Modified variance update (Faster convergence)
 2005-05-13 O.Yamashita
 * 'Modelinit' is added as the 6th input optional argument.
 * Iso-mode is implemented. (when Gact is [])
 * A field 'fix_parm' is changed to two flag parameters, 'update_v' and
 'update_sx'.
 * lower bound of 'gain_z'  
 2005-05-16 O.Yamashita
 * Sparse-mode is implemented. (when Gall is [])
 * lower bound of 'gain_j'
 2005-05-31 O.Yamashita
 * constant term of free energy
 2005-06-02 O.Yamashita
 * constant term of free energy revised
 2005/08/22 O.Yamashita    --- ver.30b 
 * vb_parm is introduced
 * Output argument is modified
 * Minor bug fix
 2005/09/29 O. Yamashita
 * Minor bug fix (when Norient=2 & iso-mode)
 2006/04/16 M. Sato
 * vb_parm.Npre_train is introduced
   VB-update equation is used for iteration <= Npre_train
 2006/08/20 M. Sato
 * Soft normal constraint
 2008/06/20 M. Sato
 * Only support focal model with normal current direction  
 * Do not support soft normal constraint
 * prune operation is omitted since no pruning is done in Ta0>=10
 2012-02-08 taku-y
 [debug] Fixed so that the flag 'update_sx' works correctly when it is
         set to false.
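The two variance-update rules mentioned in this history (the VB rule for iterations k <= Npre_train, then the faster-convergence rule) can be compared on a single variance parameter. Solving either rule for a fixed point gives the same condition ax_z*(0.5*gain_z + gx0_z) = mean_zz_all, so both rules share their fixed points. This sketch holds the E-step statistics mean_zz_all and gain_z fixed at made-up values, a simplification: the real loop recomputes them from the current ax_z at every iteration.

```matlab
% Hypothetical one-step statistics for a single variance parameter
Ntt = 1500; Ratio = 1;     % pooled samples and orientation ratio
gx0_z = 10;                % prior confidence (Ta0)
gx_z  = 0.5*Ntt*Ratio + gx0_z;
mean_zz_all = 4.2;         % stands in for 0.5*ax_z.*mean_zz/sx + gx0_z.*ax0_z
gain_z      = 800;         % stands in for the summed estimation gain

% VB update rule (used while k <= Npre_train)
vb_update   = @(ax_z) (mean_zz_all + 0.5*ax_z.*(Ntt*Ratio - gain_z)) ./ gx_z;
% Faster-convergence rule (used afterwards)
fast_update = @(ax_z) mean_zz_all ./ (0.5*max(gain_z, eps) + gx0_z);

ax_z   = 1;
a_vb   = vb_update(ax_z);    % one VB step from ax_z = 1
a_fast = fast_update(ax_z);  % one fast step from ax_z = 1

% Common fixed point of both rules
a_star = mean_zz_all / (0.5*gain_z + gx0_z);
```

With these numbers, one VB step moves ax_z from 1 only to about 0.47, while the fast rule lands on a_star directly; this matches the description of the second rule as the faster-convergence update.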

 ------------ Relative variance calculation ----------------
 In this implementation, the current variance 'ax_z' is defined as a
 ratio with respect to sx:
   ax_z * sx = <Z*Z>
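Because Model.a stores this ratio and Model.sx stores one noise variance per time window, the absolute current variance of each window can be recovered by scaling each column of Model.a. A minimal sketch with made-up numbers; the precision alpha is taken as 1/variance, following the variance convention noted above.

```matlab
% Hypothetical outputs: 3 variance parameters, 2 time windows
Model.a  = [0.2 0.5; 1.0 0.1; 0.0 2.0];   % relative variance ax_z, [Nvaract x Nwindow]
Model.sx = [4; 0.5];                       % noise variance per window, [Nwindow x 1]

% <Z*Z> = ax_z * sx, applied window by window (column-wise scaling)
a_abs = bsxfun(@times, Model.a, Model.sx');

% Corresponding precision (inverse variance); zero variance maps to Inf
alpha = 1 ./ a_abs;
```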

 Copyright (C) 2011, ATR All Rights Reserved.
 License : New BSD License(see VBMEG_LICENSE.txt)

CROSS-REFERENCE INFORMATION

This function calls: vb_log_det, vb_repmultiply
This function is called by:

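SUBFUNCTIONS

 function FEconst = free_energy_constant(gx_j, gx0_j, gx_z, gx0_z, Nt, Ntotaltrials, Nerror, Njall)
 function y = qG(gx)
 function y = HG0(gx, gx0)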
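The subfunction qG computes psi(gx) - log(gx) for the nonzero entries of gx and returns zero elsewhere; for a Gamma distribution with shape gx, this is the gap E[log x] - log E[x]. In this version, free_energy_constant (the only caller of qG and HG0) is not called by the main function. A standalone sketch of the same computation with the built-in digamma function psi, on made-up shape values:

```matlab
% Hypothetical shape parameters, including a zero entry
gx = [0; 1; 760];

y  = zeros(length(gx), 1);
nz = find(gx ~= 0);                 % skip zeros: psi(0) and log(0) are not finite
y(nz) = psi(gx(nz)) - log(gx(nz));  % digamma minus log
```

For large gx the value approaches -1/(2*gx), so the correction becomes negligible once the confidence parameter is large.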

SOURCE CODE

function [Model,Info] = ...
         vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm, Modelinit)
% Variational Bayes estimation for MEG.
% Estimate the current variance using the time series of all trials.
% Prune algorithm (the prune operation itself is omitted since 2008/06/20;
% see the history below).
%
% -- Syntax
% [Model,Info] = vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm)
% [Model,Info] = vbmeg_ard_estimate4(B, Gall, Gact, COV, vb_parm, Modelinit)
%
% -- Input
% B{Nsession} : MEG magnetic field data matrix
%                   size(B{i}) = [Nsensors(i), Tsample, Ntrials(i)]
% Gall{Nsession}: Lead field for whole brain region as background activity
%                   size(Gall{i}) = [Nsensors(i) Ndipole]
%                  Not used in this version, which supports only the focal
%                  model (see 2008/06/20 below).
% Gact{Nsession} : Lead field (current basis function)
%                  for the region with high estimation accuracy
%                   size(Gact{i})  = [Nsensors(i) Ndipole]
%                  Must be non-empty: the iso-mode (no focal dipoles) is no
%                  longer supported, and Gact = [] raises an error.
% COV{Nsession}  : Noise covariance matrix
%                   size(COV{i})   = [Nsensors(i) Nsensors(i)]
% vb_parm        : Struct of constant parameters (Ntrain, Nskip, a_min,
%                  Ta0, sx0, a0, Twindow, Ntrials, Nsensors, ...)
%
% -- Optional input
% Modelinit      : Initial Model parameters, if given.
%      Only Modelinit.a is used; it must have the same size as the output
%      Model.a. If Modelinit is not given, vb_parm.a0 is used as the initial
%      current variance of every time window. Both are clamped below at
%      vb_parm.a_min.
%
% -- Output
% Model.a   : Current variance (1/alpha), relative to Model.sx;
%             size [Nvact*Norient_var, Nwindow]
% Model.v   : Current variance for background activity
%             (always zero in this version; size [Nwindow 1])
% Model.sx  : Observation noise variance; size [Nwindow 1]
% Info      : Convergence history, one row per iteration and time window
%             (size [Nwindow*Ntrain, 6]); columns are
%             [FE, Ev, Hj, LP, Err, cond(SB)]
%
% -- Note
% Brain noise : Background brain activity is assumed to be
%               independent Gaussian noise with the same variance.
%
% "Nvact" (number of vertices for active dipoles) and
% "Nt"  (length of the time window) are common to all sessions.
%
% "Nsensors" (number of sensor channels in a session) and
% "Ntrials"  (number of trials in a session) may differ between sessions.
%
% Note that ax0_j, ax0_z, and sx are defined as variances
% (the inverse of the inverse variance, i.e., they have the dimension of
% variance). These notations differ from those in Sato-san's note.
%
% 2004-12-5 M. Sato
% 2005-04-12 O.Yamashita
% * The whole program is revised
%   so that variable names are consistent with Sato-san's note.
% * The noise variance, sx, ax_z and ax_j are independently handled.
% * A field 'fix_parm' is added to 'vb_parm' struct.
% * The input 'COV' must be given.
% 2005-04-16 M. Sato
% * Modified free-energy calculation
% 2005-05-07 M. Sato
% * Modified variance update (Faster convergence)
% 2005-05-13 O.Yamashita
% * 'Modelinit' is added as the 6th input optional argument.
% * Iso-mode is implemented. (when Gact is [])
% * A field 'fix_parm' is changed to two flag parameters, 'update_v' and
% 'update_sx'.
% * lower bound of 'gain_z'
% 2005-05-16 O.Yamashita
% * Sparse-mode is implemented. (when Gall is [])
% * lower bound of 'gain_j'
% 2005-05-31 O.Yamashita
% * constant term of free energy
% 2005-06-02 O.Yamashita
% * constant term of free energy revised
% 2005/08/22 O.Yamashita    --- ver.30b
% * vb_parm is introduced
% * Output argument is modified
% * Minor bug fix
% 2005/09/29 O. Yamashita
% * Minor bug fix (when Norient=2 & iso-mode)
% 2006/04/16 M. Sato
% * vb_parm.Npre_train is introduced
%   VB-update equation is used for iteration <= Npre_train
% 2006/08/20 M. Sato
% * Soft normal constraint
% 2008/06/20 M. Sato
% * Only support focal model with normal current direction
% * Do not support soft normal constraint
% * prune operation is omitted since no pruning is done in Ta0>=10
% 2012-02-08 taku-y
% [debug] Fixed so that the flag 'update_sx' works correctly when it is
%         set to false.
%
% ------------ Relative variance calculation ----------------
% In this implementation, the current variance 'ax_z' is defined as a
% ratio with respect to sx:
%   ax_z * sx = <Z*Z>
%
% Copyright (C) 2011, ATR All Rights Reserved.
% License : New BSD License(see VBMEG_LICENSE.txt)

global flag_pinv;
if isempty(flag_pinv), flag_pinv = false; end

Ntrain = vb_parm.Ntrain;
Nskip  = vb_parm.Nskip;
a_min  = vb_parm.a_min;
a_max  = vb_parm.a_max;

if isfield(vb_parm, 'Npre_train')
    % Number of iterations using the original VB-update rule
    % for stable estimation
    Npre_train = vb_parm.Npre_train;
else
    Npre_train = Ntrain;
end

fprintf('\n--- New VBMEG estimation program ---\n\n')
fprintf('--- Initial VB-update iteration = %d\n',Npre_train)
fprintf('--- Total update iteration      = %d\n',Ntrain)

% Size parameters
Ntrials     = vb_parm.Ntrials;    % Number of trials
Nsensors    = vb_parm.Nsensors;   % Number of sensors
Nsession    = vb_parm.Nsession;   % Number of sessions
Nwindow     = vb_parm.Nwindow;    % Number of time windows
Twindow     = vb_parm.Twindow;    % Time window index [Tstart, Tend]
Ntotaltrials = sum(Ntrials);      % Total number of trials in all sessions
Nerror  = sum(Ntrials.*Nsensors); % Total number of channels x trials
                                  %        used for error estimate

Nvact       = vb_parm.Nvact;      % Number of active current vertices
Njact       = vb_parm.Njact;      % Number of active dipole parameters
Njall       = vb_parm.Njall;      % Number of whole brain vertices (dipole)
                                  %        for background activity
Norient     = vb_parm.Norient;    % Number of current orientation components
Norient_var = vb_parm.Norient_var;% Number of current orientation components
                                  %   for variance estimation
Nvaract     = Nvact * Norient_var;% Number of variance parameters
                                  % (=length(ax_z))
Ratio = Norient/Norient_var;        % = (# of orientation)
                                  %   /(# of orientation in variance)

% Set prior parameters
Ta0 = vb_parm.Ta0 ;            % confidence parameter of focal
sx0 = vb_parm.sx0;             % observation noise variance
a0  = max(vb_parm.a0, a_min ); % initial variance of focal

if exist('Modelinit','var') && ~isempty(Modelinit)
    a0s = max(Modelinit.a, a_min );
else
    a0s = repmat(a0, [1 Nwindow]);
end

if isempty(Gact)
    error('No lead field (Gact) is given')
end

% Flag controlling the update of the noise variance sx
update_sx = 1;
if isfield(vb_parm, 'update_sx')
   update_sx = vb_parm.update_sx;
end

%%% calculate B*B' for each session and time window
for ns = 1 : Nsession
    for nw = 1 : Nwindow
        B0 = 0;
        for i = 1 : Ntrials(ns)
            Bt = B{ns}(:,Twindow(nw,1):Twindow(nw,2),i);
            B0 = B0 + Bt * Bt';
        end
        BBall{ns,nw} = B0;
    end
end

clear B Gall

% Temporary variables for variance (ratio)
aall  = zeros(Nvaract,Nwindow); % Modified by TY 2005-06-15
sxall = zeros(Nwindow,1);

% Free energy and Error history
Info    = zeros(Nwindow*Ntrain, 6);
k_check = 0;

%%%%% Time window loop %%%%%%%%
for nw=1:Nwindow

    Nt  = Twindow(nw,2) - Twindow(nw,1) + 1;
    Ntt = Nt*Ntotaltrials;

    % confidence parameters
    gx0_z = Ta0;
    gx_z  = 0.5 * Ntt * Ratio + gx0_z;

    % initial variance parameters
    ax0_z = a0s(:,nw) ;

    % variance hyper-parameter initialization
    if vb_parm.cont_pr == OFF || nw == 1
        % use the prior variance as the initial expected value
        ax_z = ax0_z;
        sx   = sx0;
    else
        % use the estimated variance in the previous time window as the
        % initial expected value
        ix_zero = find(ax_z < a_min);  % indices taking small variance
        ax_z(ix_zero) = ax0_z(ix_zero);
    end

    %%%%%% Estimation Loop %%%%%%
    for k=1:Ntrain
        %%% VB E-step -- current estimation

        % Initialize averaging variable
        tr1_b   = 0;                %  (B - G*J)^2
        trBB    = 0;                %  trace(B*B'), summed over sessions
        mean_zz = zeros(Nvaract,1); % diag(  Z*Z )
        gain_z  = zeros(Nvaract,1); % Estimation Gain for Z
        Hj      = 0;                % log(det(SB))

        for ns=1:Nsession
            Ntry  = Ntrials(ns);

            % Lead field for each session
            G  = Gact{ns};
            Gt = G';

            % MEG covariance for each session
            BB    = BBall{ns,nw};    % (B * B')
            trBB  = trBB + trace(BB);

            % Noise covariance for each session
            Cov   = COV{ns};   % Sigma_G^{-1}

            if Ratio == 1
                A_z = ax_z;
            else
                A_z = repmat(ax_z,Ratio,1);
            end

            % Inverse filter for each session
            % GSG   = G*(A_z.*Gt) ; % Ga * alpha *Ga'
            GSG     = G * vb_repmultiply(Gt, A_z) ;
            SB      = GSG + Cov;        % Sigma_B
            if flag_pinv
                SB_inv = pinv(SB);      % Sigma_B^{-1}
            else
                SB_inv = inv(SB);
            end
            GSB     = Gt*SB_inv;        % Ga' * Sigma_B^{-1}
            SBBS    = SB_inv*BB*SB_inv; % Sigma_B^{-1}*BB*Sigma_B^{-1}

            % Reconstruction error = sum( (B-G*J)'*inv(Cov)*(B-G*J) )
            tr1_b = tr1_b + sum(sum(Cov.*SBBS));

            % Current magnitude = diag( sum( Z * Z ) )
            mean_zz = mean_zz + ax_z ...
                   .* sum(reshape(sum((GSB*BB).*GSB,2),[Nvaract Ratio]),2);
            % Estimation gain for Z = diag(ax_z*Gt*SB_inv*G)
            gain_z  = gain_z + Nt*Ntry*ax_z ...
                   .* sum(reshape(sum(Gt.*GSB, 2),[Nvaract Ratio]), 2);

            Hj  = Hj - 0.5*Nt*Ntry*vb_log_det(SB);
        end;

        % Noise variance estimation
        sx_total = (tr1_b + sum(mean_zz));
        if update_sx,
          sx  = sx_total/(Nt*Nerror);
        end

        %%%%% VB M-step -- variance update

        % Save old value for Free energy calculation
        ax_z_old = ax_z;

        % --- Active current variance estimation
        %
        % VB update equation
        %                    gx_z * ax_z = ...
        % (0.5*Ntt*Ratio + gx0_z) * ax_z = ...
        %     mean_zz_all + 0.5*ax_z.*( Ntt * Ratio - gain_z );

        mean_zz_all = 0.5*ax_z.*mean_zz/sx + gx0_z.*ax0_z;

        % Update the current variance parameters
        if k <= Npre_train
            % --- VB update rule
            % Active current variance estimation
               ax_z_total = mean_zz_all + 0.5*ax_z.*(Ntt*Ratio - gain_z);
               ax_z   = ax_z_total ./ gx_z;
        else
            % --- Faster convergence update rule
            % Active current variance estimation
            gain_z = max(gain_z, eps);
            ax_z = mean_zz_all ./ (0.5*gain_z + gx0_z);
        end

        %%%%%%%% Free Energy Calculation %%%%%%%%
        % Model complexity for current variance
        rz  = ax0_z./ax_z_old;
        Haz = sum( gx0_z.*(log(rz) - rz + 1));

        LP = - 0.5*Nt*Nerror*log(sx);

        % Data Likelihood & Free Energy
        Ea  = - 0.5*( sx_total/sx - Nt*Nerror );
        FE  = LP + Hj + Ea + Haz ;
        Ev  = LP + Hj;  % evidence (without constants)
        Err = tr1_b/trBB;  % Normalized error (all sessions)

        % DEBUG Info
        k_check = k_check + 1;
        Info(k_check,:) = [...
            FE, Ev, Hj, LP, Err, cond(SB)];
        % end of free energy calculation

        if mod(k, Nskip)==0 || k==Ntrain
            fprintf('Tn=%3d, Iter=%4d, FE=%f, Error=%e\n', ...
                    nw, k, FE, Err);
        end;

    end % end of learning loop

    aall(:,nw) = ax_z;
    sxall(nw)  = sx;
end % end of time window loop

Model.a   = aall ; % Current variance
Model.sx  = sxall;    % Sensor noise variance
Model.v   = zeros(Nwindow,1);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% functions specific for this function
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function FEconst = free_energy_constant(gx_j, gx0_j, gx_z, gx0_z,...
    Nt, Ntotaltrials, Nerror, Njall)

gx_sx = 0.5*Nt*Nerror;
L0 = 0.5*Nt*(Njall*Ntotaltrials*qG(gx_j) + Ntotaltrials*sum(qG(gx_z)) + Nerror*qG(gx_sx) - Nerror);
H0 = HG0(gx_sx,0) + HG0(gx_j,gx0_j)+sum(HG0(gx_z, gx0_z));
FEconst = L0+H0;

%%%%%
function y = qG(gx)
N=length(gx);
nz = find(gx ~= 0); % non-zero components
y = zeros(N,1);
y_tmp = psi(gx(nz)) - log(gx(nz));
y(nz) = y_tmp;

%%%%%%
function y =HG0(gx, gx0)
N=length(gx0);
nz = find(gx ~= 0); % non-zero components
nz0 = find(gx0 ~= 0); % non-zero components
% gamma
y1 = zeros(N,1);
y1_tmp = gammaln(gx(nz)) - gx(nz).*psi(gx(nz)) + gx(nz);
y1(nz) = y1_tmp;
% gamma0
y2 = zeros(N,1);
y2_tmp = gammaln(gx0(nz0)) - gx0(nz0).*log(gx0(nz0)) + gx0(nz0);
y2(nz0) = y2_tmp;
% gamma*gamma0
y3 = zeros(N,1);
y3_tmp = gx0(nz0).*(psi(gx(nz0))-log(gx(nz0)));
y3(nz0) = y3_tmp;

y = y1 - y2 + y3;

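A minimal numerical sketch of the per-session E-step quantities computed in the source above, for one session with random data and Ratio = 1 (all sizes and values are made up). It substitutes bsxfun for the VBMEG helper vb_repmultiply and uses the slash and backslash operators instead of inv. Under the linear Gaussian model B = G*J + noise with prior variance proportional to A_z and noise covariance proportional to Cov, the posterior mean of the current is J = diag(A_z)*G'*inv(SB)*B (the common factor sx cancels). The assertions check that tr1_b equals the residual B - G*J weighted by inv(Cov), and that mean_zz equals sum(J.^2, 2) ./ A_z.

```matlab
% Toy single-session problem (random, for illustration only)
Nsensors = 6; Njact = 4; Nt = 30;
G   = randn(Nsensors, Njact);
B   = randn(Nsensors, Nt);
Cov = eye(Nsensors) + 0.1*ones(Nsensors);   % symmetric positive definite
A_z = rand(Njact, 1) + 0.1;                  % relative current variances

% Quantities as computed in the source
SB      = G * bsxfun(@times, A_z, G') + Cov; % Sigma_B = G*diag(A_z)*G' + Cov
BB      = B * B';
GSB     = G' / SB;                           % G' * inv(SB)
SBBS    = (SB \ BB) / SB;                    % inv(SB) * BB * inv(SB)
tr1_b   = sum(sum(Cov .* SBBS));
mean_zz = A_z .* sum((GSB * BB) .* GSB, 2);

% Explicit posterior mean of the current and its weighted residual
J  = bsxfun(@times, A_z, GSB * B);           % diag(A_z) * G' * inv(SB) * B
R  = B - G * J;                              % equals Cov * inv(SB) * B
tr1_explicit = trace(R' * (Cov \ R));

assert(abs(tr1_b - tr1_explicit) < 1e-8 * abs(tr1_explicit));
assert(max(abs(mean_zz - sum(J.^2, 2) ./ A_z)) < 1e-8 * max(abs(mean_zz)));
```

The identity behind the first check is R = (SB - G*diag(A_z)*G')*inv(SB)*B = Cov*inv(SB)*B, so the weighted residual reduces to trace(Cov*inv(SB)*BB*inv(SB)) without ever forming J.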

Generated on Mon 22-May-2023 06:53:56 by m2html © 2005