Home > vbmeg > functions > estimation > bayes > dynamics > vbmeg_ard_estimate25_3.m

vbmeg_ard_estimate25_3

PURPOSE ^

Algorithm for jointly estimating source activities and interactions.

SYNOPSIS ^

function [Model,Info] =vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm, Modelinit)

DESCRIPTION ^

 Algorithm for jointly estimating source activities and interactions. 
 (2015/03/09 M.Fukushima: computation faster than vbmeg_ard_estimate25_2)
 -- Syntax
 [Model,Info] = vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm, Modelinit)

 2014/09/05 M.Fukushima
 * The original code is vbmeg_ard_estimate4.m
 * Explicitly compute the currents J
 * Change the way of calculation of tr1_b, mean_zz, gain_z, and Hj
 * Do not support multiple directions dipoles
 * Do not support multiple sessions data
 * Do not support multiple time windows
 * Do not support multiple trials data
 * Do not support extra dipoles
 * Do not support fast convergence rule
 * Compute the exact free energy at each iteration

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function [Model,Info] = ...
0002          vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm, Modelinit)
0003 % Algorithm for jointly estimating source activities and interactions.
0004 % (2015/03/09 M.Fukushima: computation faster than vbmeg_ard_estimate25_2)
0005 % -- Syntax
0006 % [Model,Info] = vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm, Modelinit)
0007 % Model: a, sx, v, Z (currents), Vz, vz_Nsample, mar, MAR, MAR0, Vmar, ieta
0008 % 2014/09/05 M.Fukushima
0009 % * The original code is vbmeg_ard_estimate4.m
0010 % * Explicitly compute the currents J
0011 % * Change the way of calculation of tr1_b, mean_zz, gain_z, and Hj
0012 % * Do not support multiple directions dipoles
0013 % * Do not support multiple sessions data
0014 % * Do not support multiple time windows
0015 % * Do not support multiple trials data
0016 % * Do not support extra dipoles
0017 % * Do not support fast convergence rule
0018 % * Compute the exact free energy at each iteration
0019 % Info: one row per iteration k > 1: [FE, Hj(1), Hj(2), Hbc, Hmar, Hieta, Haz, Err]
0020 global flag_pinv;
0021 if isempty(flag_pinv), flag_pinv = false; end
0022 
0023 Ntrain = vb_parm.Ntrain;
0024 Nskip  = vb_parm.Nskip;
0025 a_min  = vb_parm.a_min;
0026 a_max  = vb_parm.a_max;
0027 
0028 fprintf('\n--- New Dynamic estimation program ---\n\n')
0029 fprintf('--- Total update iteration      = %d\n',Ntrain)
0030 
0031 % Size parameters
0032 Ntrials     = vb_parm.Ntrials;    % Number of trials
0033 Nsensors    = vb_parm.Nsensors;   % Number of sensors
0034 Nsession    = vb_parm.Nsession;   % Number of sessions
0035 Nwindow     = vb_parm.Nwindow;    % Number of time windows
0036 Twindow     = vb_parm.Twindow;    % Time window index [Tstart, Tend]
0037 Ntotaltrials = sum(Ntrials);      % Total number of trials in all sessions
0038 Nerror  = sum(Ntrials.*Nsensors); % Total MEG channel used for error estimate
0039 
0040 Nvact       = vb_parm.Nvact;      % Number of active current vertices
0041 Njact       = vb_parm.Njact;      % Number of active dipoles parameters
0042 Njall       = vb_parm.Njall;      % Number of whole brain vertices (dipole)
0043                                   %        for background activity
0044 
0045 % Load structural connectivity matrix
0046 Ind = vb_parm.Ind;
0047 Delta = vb_parm.Delta;
0048 
0049 % Set prior parameters
0050 Ta0 = vb_parm.Ta0 ;            % confidence parameter of focal
0051 sx0 = vb_parm.sx0;             % observation noise variance
0052 a0  = max(vb_parm.a0, a_min ); % initial variance of focal
0053 
0054 if exist('Modelinit','var') && ~isempty(Modelinit)
0055     a0s = max(Modelinit.a, a_min );
0056 else
0057     a0s = repmat(a0, [1 Nwindow]);
0058 end
0059 
0060 if isempty(Gact) 
0061     error('No leadfield is given')
0062 end
0063 
0064 % Flag controlling updates of variance parameters
0065 update_sx = 1;
0066 if isfield(vb_parm, 'update_sx')
0067    update_sx = vb_parm.update_sx;
0068 end
0069 
0070 %%% Split B into time windows (Ball) and accumulate BB' over trials (BBall)
0071 for ns = 1 : Nsession
0072     for nw = 1 : Nwindow
0073         B0 = 0;
0074         Btr = B{ns}(:,Twindow(nw,1):Twindow(nw,2),:);
0075         for i = 1 : Ntrials(ns)
0076             Bt = B{ns}(:,Twindow(nw,1):Twindow(nw,2),i);
0077             B0 = B0 + Bt * Bt';
0078         end
0079         Ball{ns,nw} = Btr;
0080         BBall{ns,nw} = B0;
0081     end
0082 end
0083 
0084 % clear B Gall
0085 
0086 Jall  = cell(Nsession, Nwindow);
0087 J     = cell(Nsession,1);
0088 J0    = cell(Nsession,1);
0089 Vj     = cell(Nsession,1);
0090 VjT    = cell(Nsession,1);
0091 VjT_full = cell(Nsession,1);
0092 
0093 % Per-window storage of the estimated variances (ax_z and sx)
0094 aall  = zeros(Nvact,Nwindow);
0095 sxall = zeros(Nwindow,1);
0096 
0097 % Free energy and Error history
0098 Info    = zeros(Nwindow*Ntrain, 8);
0099 k_check = 0;
0100 
0101 tic;
0102 %%%%% Time window loop %%%%%%%%
0103 for nw=1:Nwindow
0104     
0105     Nt  = Twindow(nw,2) - Twindow(nw,1) + 1;
0106     Ntt = Nt*Ntotaltrials; 
0107 
0108     % confidence parameters
0109     gx0_z = Ta0; 
0110     gx_z  = 0.5 * Ntt + gx0_z; 
0111     
0112     % initial variance parameters
0113     ax0_z = a0s(:,nw); 
0114 
0115     % variance hyper-parameter initialization
0116     if vb_parm.cont_pr == OFF | nw == 1
0117         % use the prior variance as the initial expected value
0118         ax_z = vb_parm.a;
0119         sx   = sx0;
0120     else
0121         % use the estimated variance in the previous time window as the
0122         % initial expected value
0123         ix_zero = find(ax_z < a_min);  % indices taking small variance
0124         ax_z(ix_zero) = ax0_z(ix_zero);              
0125     end
0126     
0127     % For computing AtQA and AdtQA_V in forward_pass
0128     [Delta_set, Ndelta, indx, deltax, sum_ind, nl_indx, NL_indx] = ...
0129       init_index(Delta);
0130     AtQA0      = init_AtQA(Nvact, Ndelta, NL_indx);
0131     AdtQA_V0   = init_AdtQA_V(Nvact, Ndelta, NL_indx);
0132     sp_indx    = get_AtQA_index(Nvact, Ndelta, NL_indx, AtQA0);
0133     sp_dl_indx = get_AdtQA_V_index(Nvact, Ndelta, NL_indx, AdtQA_V0);
0134     
0135     % Initialization for estimating Q(A)
0136     mar = cell(Nvact,1);
0137     MAR = 0*speye(Nvact); MAR0 = MAR; % MAR0 stays defined even if Ntrain == 0
0138     Vmar = cell(Nvact,1);
0139     for nv = 1:Nvact
0140       mar{nv}  = 0 * ones(sum_ind(nv),1);
0141       Vmar{nv} = 0 * eye(sum_ind(nv));
0142     end
0143     
0144     % Initialization for estimating Q(\eta_{1:N})
0145     ieta0 = cell(Nvact,1);
0146     gx0_e = cell(Nvact,1);
0147     gx_e  = cell(Nvact,1);
0148     for nv = 1:Nvact
0149       ieta0{nv} = vb_parm.ieta0 * ones(sum_ind(nv),1);
0150       gx0_e{nv} = vb_parm.g0 * ones(sum_ind(nv),1);
0151       gx_e{nv} = gx0_e{nv} + 0.5;
0152     end
0153     ieta = ieta0;
0154     
0155     % Labeling time stamps for sumVj1 calculation
0156     vfreq = fliplr(-diff([(Nt-Delta_set)' 0]));
0157     cum_vfreq = [0 cumsum(vfreq)];
0158     tsp = repmat(1:Nt,size(Delta_set)) - repmat(Delta_set,[1 Nt]);
0159     vlb = zeros(size(tsp));
0160     vct = [];
0161     for nl = 1:Ndelta
0162       vlb = vlb + nl*(cum_vfreq(nl)<tsp).*(tsp<=cum_vfreq(nl+1));
0163     end
0164     for nl = 1:Ndelta
0165       vct = [vct sum(vlb==nl,2)];
0166     end
0167     
0168     %%%%%% Estimation Loop %%%%%%
0169     for k=1:(Ntrain+1)  % pass Ntrain+1 skips the updates and only evaluates FE
0170         %%% VB E-step -- current estimation
0171 
0172         % Initialize averaging variable
0173         tr1_b = 0;
0174         Hj = [0 0];
0175         
0176         for ns=1:Nsession
0177             Ntry  = Ntrials(ns);
0178             Nsensor = Nsensors(ns);
0179             
0180             % Lead field for each session
0181             G  = Gact{ns};
0182             Gt = G';
0183         
0184             % MEG data (B) and BB' for each session
0185             B     = Ball{ns,nw};
0186             BB    = BBall{ns,nw};
0187             
0188             % Noise covariance for each session
0189             Cov   = COV{ns};
0190             
0191             % Initial current
0192             if nw == 1
0193               J0{ns} = zeros(Nvact, Delta_set(end), Ntry);
0194             else
0195               J0{ns} = Jall{ns,nw-1}(:,end-Delta_set(end)+1:end,:); % tail of previous window
0196             end
0197             
0198             if k == 1, prevJ = zeros(Nvact, Nt, Ntry);
0199             else prevJ = J{ns}; end
0200             
0201             % Estimate current
0202             [J{ns}, Vj{ns}, VjT{ns}, GJ, Qj, VjT_full{ns}] = current_estimation ...
0203               (B, G, Gt, ax_z, Cov, J0{ns}, prevJ, mar, Vmar, ...
0204               Nvact, Nsensor, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
0205               nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx);
0206             
0207             % Reconstruction error weighted by inv(Cov) (solved with \, not inv)
0208             BGJ = reshape(B-GJ, Nsensor, Nt*Ntry);
0209             tr1_b = tr1_b + sum(repdiag( BGJ, Cov\BGJ));
0210                    
0211             Hj = Hj + 0.5*Qj;
0212         end;
0213         
0214         %%%%% VB M-step -- parameter update
0215         %% Sufficient statistics for Q(\beta) calculation
0216         % Initialize
0217         sumJ11 = zeros(Nvact,1);
0218         sumVj1 = zeros(Nvact,1);
0219         sumJ12 = zeros(Nvact,1);
0220         MARJ1  = zeros(Nvact,Nt*Ntrials(1));
0221         
0222         J22 = reshape(J{1}.^2, Nvact, Nt*Ntrials(1));
0223         sumJ22 = sum(J22,2);
0224         sumVj2 = Ntrials(1)*(Nt-sum(vfreq))*VjT{1};
0225         for nl = 1:Ndelta
0226           sumVj2 = sumVj2 + Ntrials(1)*vfreq(nl)*diag(Vj{1}{nl});
0227         end
0228         
0229         % Target currents
0230         J2 = reshape(J{1}, Nvact, Nt*Ntrials(1));
0231         
0232         Jtmp = [];
0233         t_all = [];
0234         for ntr = 1:Ntrials(1)
0235           % Concatenate J
0236           Jtmp = [Jtmp J0{1}(:,:,ntr) J{1}(:,:,ntr)];
0237           t_all = [t_all (1:Nt)+Delta_set(end) + (Nt+Delta_set(end))*(ntr-1)];
0238         end
0239         
0240         J1   = cell(Nvact,1);
0241         J1J1 = cell(Nvact,1);
0242         Vj1  = cell(Nvact,1);
0243         if k==1; mar_Vmar = cell(Nvact,1); end
0244         for nv = 1:Nvact
0245           % Seed Currents
0246           J1{nv} = select_entries(Jtmp, repmat(indx{nv},[1 Nt*Ntrials(1)]), ...
0247             repmat(t_all,size(deltax{nv}))-repmat(deltax{nv},[1 Nt*Ntrials(1)]));
0248           J1J1{nv} = J1{nv}*J1{nv}';
0249           
0250           Nind_nv = length(deltax{nv});
0251           Vj1{nv} = zeros(Nind_nv); % Must not be sparse (must be full)
0252           for nld = 1:Ndelta
0253             indlb = (deltax{nv}==Delta_set(nld));
0254             sum_indlb = sum(indlb);
0255             if sum_indlb~=0
0256               Vjtmp = zeros(sum_indlb);
0257               nl = 1;
0258               while nl <= Ndelta-nld+1
0259                 Vjtmp = Vjtmp + ...
0260                   vct(nld,nl)*Vj{1}{nl}(indx{nv}(indlb),indx{nv}(indlb));
0261                 nl = nl + 1;
0262               end
0263               Vj1{nv}(indlb,indlb) = Vj1{nv}(indlb,indlb) + Ntrials(1)*Vjtmp;
0264             end
0265           end
0266           
0267           if k==1; mar_Vmar{nv} = mar{nv}*mar{nv}' + Vmar{nv}; end
0268           MARJ1(nv,:) = mar{nv}'*J1{nv};
0269           sumJ11(nv) = sumJ11(nv) + sum(repdiag(mar_Vmar{nv},J1J1{nv}));
0270           sumVj1(nv) = sumVj1(nv) + Ntrials(1)*sum(repdiag(mar_Vmar{nv},Vj1{nv}));
0271         end
0272         sumJ12 = sumJ12 + repdiag(J2',(MARJ1)');
0273         
0274         mean_zz = sumJ22 - 2*sumJ12 + sumJ11;
0275         mean_Vj = sumVj2 + sumVj1;
0276         
0277         % Noise variance estimation
0278         sx_total = tr1_b + sum(mean_zz./ax_z);
0279         if update_sx, 
0280           sx  = sx_total/(Nt*Nerror);
0281         end
0282         
0283         Hj(1) = Hj(1) - 0.5*tr1_b/sx;
0284         Hj(2) = Hj(2) - 0.5*sum(mean_zz./ax_z)/sx - 0.5*sum(mean_Vj./ax_z);
0285         
0286         if k==(Ntrain+1)
0287           %% Skip the rest of updates
0288           % Save old value for Free energy calculation
0289           mar_old = mar;
0290           Vmar_old = Vmar;
0291           ieta_old = ieta;
0292           ax_z_old = ax_z;
0293         else
0294           %% Q(A) calculation
0295           % Save old value for Free energy calculation
0296           mar_old = mar;
0297           Vmar_old = Vmar;
0298 
0299           for nv = 1:Nvact
0300             ax_z_invVmar = (J1J1{nv}/sx + Vj1{nv}) + ax_z(nv)*diag(1./ieta{nv});
0301             J1J2_beta = (J1{nv}*J2(nv,:)'/sx);
0302 
0303             inv_ax_z_invVmar = inv(ax_z_invVmar);
0304             Vmar{nv} = ax_z(nv)*inv_ax_z_invVmar;
0305             mar{nv} = inv_ax_z_invVmar*J1J2_beta;
0306             MAR(nv,indx{nv}) = mar{nv}';
0307           end
0308 
0309           if k == 1; MAR0 = MAR; end
0310 
0311           %% Q(\eta_{1:N}) calculation
0312           % Save old value for Free energy calculation
0313           ieta_old = ieta;
0314           
0315           for nv = 1:Nvact
0316             ieta{nv} = (gx0_e{nv}.*ieta0{nv} + ...
0317               0.5*(mar{nv}.^2 + diag(Vmar{nv})))./gx_e{nv};
0318           end
0319 
0320           %% Sufficient statistics for Q(q) calculation
0321           % Initialize
0322           sumJ11 = zeros(Nvact,1);
0323           sumVj1 = zeros(Nvact,1);
0324           sumJ12 = zeros(Nvact,1);
0325           MARJ1  = zeros(Nvact,Nt*Ntrials(1));
0326           
0327           mar_Vmar = cell(Nvact,1);
0328           for nv = 1:Nvact
0329             mar_Vmar{nv} = mar{nv}*mar{nv}' + Vmar{nv};
0330             MARJ1(nv,:) = mar{nv}'*J1{nv};
0331             sumJ11(nv) = sumJ11(nv) + sum(repdiag(mar_Vmar{nv},J1J1{nv}));
0332             sumVj1(nv) = sumVj1(nv) + Ntrials(1)*sum(repdiag(mar_Vmar{nv},Vj1{nv}));
0333           end
0334           sumJ12 = sumJ12 + repdiag(J2',(MARJ1)');
0335 
0336           mean_zz = sumJ22 - 2*sumJ12 + sumJ11;
0337           mean_Vj = sumVj2 + sumVj1;
0338 
0339           % Save old value for Free energy calculation
0340           ax_z_old = ax_z;
0341 
0342           % Current noise variance estimation
0343           ax_z = (gx0_z.*ax0_z + 0.5*(mean_zz/sx + mean_Vj))./gx_z;
0344         end
0345         
0346         %%%%%%%% Free Energy Calculation %%%%%%%%
0347         rz  = ax0_z./ax_z_old;
0348         Haz = sum(gx0_z.*(log(rz) - rz + 1));
0349         
0350         Hmar = 0;
0351         Hieta = 0;
0352         mar_vmar_ieta_old = zeros(Nvact,1);
0353         for nv = 1:Nvact
0354           if isempty(indx{nv}) == 0
0355             Hmar = Hmar - sum(log(ieta_old{nv})) + vb_log_det(Vmar_old{nv});
0356             mar_vmar_ieta_old(nv) = ...
0357               sum((mar_old{nv}.^2+diag(Vmar_old{nv}))./ieta_old{nv});
0358           end
0359           re = ieta0{nv}./ieta_old{nv};
0360           Hieta = Hieta + sum(gx0_e{nv}.*(log(re) - re + 1));
0361         end
0362         Hmar = 0.5*(Hmar-sum(mar_vmar_ieta_old));
0363         
0364         Hj(1) = Hj(1) - 0.5*Nt*Nerror*log(sx);
0365 
0366         [L1c L2c Hbc HAc Hec Hqc difbeta] = free_energy_constant...
0367           (gx_z, gx0_z, gx_e, gx0_e, Nt, Ntotaltrials, Nerror, Nvact, Ind);
0368         
0369         % Log-likelihood & Penalty terms
0370         Hj(1) = Hj(1) + L1c;
0371         Hj(2) = Hj(2) + L2c + difbeta;
0372         Hmar = Hmar + HAc;
0373         Hieta = Hieta + Hec;
0374         Haz = Haz + Hqc;
0375         
0376         % Free energy
0377         FE  = sum(Hj) + Hbc + Hmar + Hieta + Haz;
0378         
0379         % Normalized Error (NOTE(review): divides by sum of ALL entries of BB', not trace(BB) -- confirm)
0380         Err = tr1_b/sum(sum(BB));
0381         
0382         if k > 1
0383           % DEBUG Info
0384           k_check = k_check + 1;
0385           Info(k_check,:) = [FE, Hj(1), Hj(2), Hbc, Hmar, Hieta, Haz, Err];
0386 
0387           if mod(k-1, Nskip) == 0
0388             fprintf('Tn=%3d, Iter=%4d, FE=%f, Error=%e\n', nw, k-1, FE, Err);
0389           end
0390         end
0391     end % end of learning loop
0392 
0393     for ns = 1:Nsession
0394       Jall{ns,nw} = J{ns};
0395     end
0396     aall(:,nw) = ax_z;
0397     sxall(nw)  = sx;
0398 end % end of time window loop
0399 toc;
0400 
0401 % Save estimates
0402 Model.a   = aall ; % Current noise variance
0403 Model.sx  = sxall; % Observation noise variance
0404 Model.v   = zeros(Nwindow,1);
0405 Model.Z   = Jall; % current sources
0406 Model.Vz  = Vj; % source covariances
0407 clear Vj
0408 for ns = 1:Nsession
0409   Model.Vz{ns}{end+1}  = VjT_full{ns}; % source covariances near the end
0410 end
0411 % # of time samples for each source covariance
0412 Model.vz_Nsample = [vfreq (Nt-sum(vfreq))];
0413 Model.mar  = mar;  % MAR coefficients
0414 Model.MAR  = MAR;  % MAR matrix
0415 Model.MAR0 = MAR0; % MAR matrix at the 1st iteration
0416 for nv = 1:Nvact
0417   Model.Vmar{nv} = Vmar{nv}; % Covariances of MAR coefficients
0418 end
0419 Model.ieta = ieta; % Variance parameter of MAR coefficients
0420 
0421 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0422 %%% functions specific for this function
0423 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0424 function [J, Vj, VjT, GJ, Qj, VjT_full] = current_estimation ...
0425   (B, G, Gt, ax_z, Cov, J0, prevJ, mar, Vmar, ...
0426     Nvact, Nsensor, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
0427     nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx)
0428 % Estimate currents with forward_pass, then predict the sensor data GJ = G*J.
0429 % Every argument except Nsensor is forwarded unchanged to forward_pass.
0430 [J, Vj, VjT, Qj, VjT_full] = forward_pass(B, G, Gt, ax_z, Cov, J0, prevJ, ...
0431   mar, Vmar, Nvact, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
0432   nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx);
0433 
0434 % Predicted magnetic fields, returned as [Nsensor, Nt, Ntry]
0435 Jmat = reshape(J, Nvact, Nt*Ntry);   % trials stacked along the time axis
0436 GJ   = reshape(G*Jmat, Nsensor, Nt, Ntry);
0437 
0438 
0439 %%%%%
0440 function [Jfilt, Vj, VjT, Qj, VjT_full] = forward_pass ...
0441   (B, G, Gt, ax_z, Cov, J0, prevJ, mar, Vmar, ...
0442     Nvact, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
0443     nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx)
0444 % Estimate currents Jfilt (Nvact x Nt x Ntry), covariances Vj/VjT/VjT_full, and Qj.
0445 % Initialize outputs and the per-lag sparse matrices A and AtQ
0446 Jfilt = zeros(Nvact, Nt, Ntry);
0447 Ndelta = length(Delta_set);
0448 AtQ = cell(Ndelta,1);
0449 A = cell(Ndelta,1);
0450 Vj = cell(Ndelta,1);
0451 for nl = 1:Ndelta
0452   AtQ{nl} = spalloc(Nvact, Nvact, Nvact);
0453   A{nl} = spalloc(Nvact, Nvact, Nvact);
0454   Vj{nl} = zeros(Nvact);
0455 end
0456 
0457 % Per-lag MAR matrices: A{nl}(nv,NL_indx{nv,nl}) from mar{nv}; AtQ{nl} = (diag(1./ax_z)*A{nl})'
0458 for nl = 1:Ndelta
0459   for nv = 1:Nvact
0460     if ~isempty(nl_indx{nv,nl})
0461       A{nl}(nv,NL_indx{nv,nl}) = mar{nv}(nl_indx{nv,nl});
0462     end
0463   end
0464   AtQ{nl} = (sparse(diag(1./ax_z))*A{nl})';
0465 end
0466 AtQA = set_AtQA(Nvact, Ndelta, ax_z, mar, Vmar, ...
0467         nl_indx, NL_indx, sp_indx, AtQA0);
0468 AdtQA_V = set_AdtQA_V(Nvact, Ndelta, ax_z, mar, Vmar, ...
0469         nl_indx, NL_indx, sp_dl_indx, AdtQA_V0);
0470 
0471 % Estimate current covariances; P, Ka and Vj are recomputed only when Nbound changes
0472 P  = cell(Ndelta+1,1);
0473 Ka = cell(Ndelta+1,1);
0474 Nbound = Inf;  % forces a computation at t = 1
0475 for t = 1:Nt-1
0476   if Nbound ~= sum((t + Delta_set) <= Nt) % number of lags with t+Delta <= Nt changed
0477     Nbound = sum((t + Delta_set) <= Nt);
0478 
0479     sumAtQA = spalloc(Nvact, Nvact, Nvact);
0480     for nl = 1:Nbound
0481       sumAtQA = sumAtQA + AtQA{nl};
0482     end
0483     QplusAtQA = diag(1./ax_z) + sumAtQA;
0484     P{Ndelta-Nbound+1} = inv(QplusAtQA);
0485     Ka{Ndelta-Nbound+1} = ...
0486       calc_kalman_gain2(G, Gt, P{Ndelta-Nbound+1}, Cov, flag_pinv);
0487 
0488     if Nbound ~= 0
0489       % Current covariances
0490       Vj{Ndelta-Nbound+1} = ...
0491         P{Ndelta-Nbound+1} - Ka{Ndelta-Nbound+1}*(G*P{Ndelta-Nbound+1});
0492     end
0493   end
0494 end
0495 PT = ax_z;  % at t = Nt only the diag(ax_z) part remains (no AtQA terms, cf. QplusAtQA)
0496 KaT = calc_kalman_gain(G, Gt, PT, Cov, flag_pinv);  % NOTE(review): non-"2" variant taking vector PT -- presumably intended; confirm
0497 % Current covariance for Nt
0498 VjT_full = diag(PT) - KaT*(G*diag(PT));
0499 VjT = diag(VjT_full);
0500 
0501 % Pack the cells (indx, deltax, mar) into zero-padded matrices (Indx, Deltax, MARt)
0502 Indx   = zeros(Nvact, max(sum_ind));
0503 Deltax = zeros(Nvact, max(sum_ind));
0504 MARt   = zeros(Nvact, max(sum_ind));
0505 for nv = 1:Nvact
0506   Indx(nv,1:size(indx{nv},1))     = indx{nv}';
0507   Deltax(nv,1:size(deltax{nv},1)) = deltax{nv}';
0508   MARt(nv,1:size(mar{nv},1))     = mar{nv}';
0509 end
0510 
0511 % Estimate currents
0512 for ntr = 1:Ntry
0513   Nbound = Inf;
0514   for t = 1:Nt-1
0515     t_end = t + Delta_set(end);
0516     
0517     if Nbound ~= sum((t + Delta_set) <= Nt) % track the same Nbound as the covariance loop
0518       Nbound = sum((t + Delta_set) <= Nt);
0519     end
0520     
0521     % Concatenate J: J0, Jfilt(:,1:t) (column t is filled below), prevJ(:,t+1:end)
0522     Jtmp = [J0(:,:,ntr) Jfilt(:,1:t,ntr) prevJ(:,t+1:end,ntr)];
0523     % Forward model prediction
0524     AJfilt = ...
0525       atimesj(MARt,Indx,Deltax,sum_ind,Jtmp(:,t_end:-1:(t_end-Delta_set(end))));
0526     
0527     % Sum the products of model parameters (future-lag terms, nld <= Nbound)
0528     saqj = zeros(Nvact,1);
0529     for nld = 1:Nbound
0530       saj = zeros(Nvact,1);
0531       adqsa_v_j = zeros(Nvact,1);
0532       for nl = 1:Ndelta
0533         if nl ~= nld
0534           t_tmp = t_end + Delta_set(nld) - Delta_set(nl);
0535           saj = saj + A{nl}*Jtmp(:,t_tmp);
0536           adqsa_v_j = adqsa_v_j + AdtQA_V{nld,nl}*Jtmp(:,t_tmp);
0537         end
0538       end
0539       saqj = saqj + AtQ{nld}* ...
0540         (Jtmp(:,t_end + Delta_set(nld)) - saj) - adqsa_v_j;
0541     end
0542     
0543     % Estimate currents: prediction, then correction with the gain Ka
0544     Jpred = P{Ndelta-Nbound+1}*(saqj + spdiags(1./ax_z,0,Nvact,Nvact)*AJfilt);
0545     Jfilt(:,t,ntr) = Jpred + Ka{Ndelta-Nbound+1}*(B(:,t,ntr) - G*Jpred);
0546   end
0547   
0548   % t = Nt
0549   Nt_end = Nt + Delta_set(end);
0550   
0551   % Concatenate J
0552   Jtmp = [J0(:,:,ntr) Jfilt(:,:,ntr)];
0553   % Forward model prediction
0554   AJfilt = ...
0555     atimesj(MARt,Indx,Deltax,sum_ind,Jtmp(:,Nt_end:-1:(Nt_end-Delta_set(end))));
0556   
0557   % Estimate currents
0558   Jpred = AJfilt;
0559   Jfilt(:,Nt,ntr) = Jpred + KaT*(B(:,Nt,ntr) - G*Jpred);
0560 end
0561 
0562 % Computing sufficient statistics for free energy calculation
0563 vfreq = fliplr(-diff([(Nt-Delta_set)' 0]));  % same definition as vfreq in the caller
0564 invCov = inv(Cov);
0565 
0566 sumlogVj = 0;
0567 sumTrGVjGtiC = 0;
0568 for nl = 1:Ndelta
0569   sumlogVj = sumlogVj + vfreq(nl)/2*vb_log_det(Vj{nl});
0570   sumTrGVjGtiC = sumTrGVjGtiC + vfreq(nl)/2*sum(repdiag(G*Vj{nl}*Gt,invCov));
0571 end
0572 
0573 Qj(1) = 2*Ntry*(-Nt/2*vb_log_det(Cov)...
0574   - sumTrGVjGtiC...
0575   - (Nt-sum(vfreq))/2*sum(repdiag(G*VjT_full*Gt,invCov)));
0576 
0577 Qj(2) = 2*Ntry*(...
0578   - Nt/2*sum(log(ax_z))...
0579   + sumlogVj...
0580   + (Nt-sum(vfreq))/2*vb_log_det(VjT_full));
0581 
0582 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0583 %%% functions for calculating constant terms in free energy
0584 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0585 function [L1c L2c Hbc HAc Hec Hqc difbeta] = free_energy_constant...
0586   (gx_z, gx0_z, gx_e, gx0_e, Nt, Ntotaltrials, Nerror, Nvact, Ind)
0587 % Constant terms of the free energy, built from the confidence values only.
0588 gx_sx = 0.5*Nt*Nerror;     % confidence value for sx (no prior term)
0589 q_sx  = qG(gx_sx);
0590 % Log-likelihood constants
0591 L1c = 0.5*Nt*(Nerror*q_sx) + 0.5*(-Nt*Nerror*log(2*pi));
0592 L2c = 0.5*Nt*(Ntotaltrials*sum(qG(gx_z))) + 0.5*(Nvact*Nt*Ntotaltrials);
0593 difbeta = 0.5*Nt*Nvact*q_sx;
0594 % Remaining constants (Hbc, Hqc and the per-vertex HAc, Hec)
0595 Hbc = HG0(gx_sx,0);
0596 Hqc = sum(HG0(gx_z, gx0_z));
0597 HAc = 0;  Hec = 0;
0598 for v = 1:Nvact
0599   HAc = HAc + 0.5*sum(qG(gx_e{v}));
0600   Hec = Hec + sum(HG0(gx_e{v}, gx0_e{v}));
0601 end
0602 HAc = HAc + 0.5*(sum(Ind(:)));
0603 
0604 %%%%%
0605 function y = qG(gx)
0606 % Element-wise psi(gx) - log(gx), returned as a column; entries with gx == 0 give 0.
0607 y   = zeros(length(gx),1);
0608 idx = find(gx ~= 0);     % non-zero components only (avoids log(0))
0609 g   = gx(idx);
0610 y(idx) = psi(g) - log(g);
0611 
0612 %%%%%%
0613 function y =HG0(gx, gx0)
0614 % Element-wise combination of gamma-function terms of gx and gx0:
0615 %   y = [gammaln(gx) - gx.*psi(gx) + gx]
0616 %     - [gammaln(gx0) - gx0.*log(gx0) + gx0]
0617 %     + gx0.*(psi(gx) - log(gx))
0618 % The first bracket is 0 where gx == 0; the last two are 0 where gx0 == 0.
0619 % For equal-length inputs the result is a length(gx0) x 1 column.
0620 N  = length(gx0);
0621 ip = find(gx ~= 0);    % non-zero components of gx
0622 i0 = find(gx0 ~= 0);   % non-zero components of gx0
0623 
0624 g  = gx(ip);
0625 t1 = zeros(N,1);  t1(ip) = gammaln(g) - g.*psi(g) + g;
0626 g0 = gx0(i0);
0627 t2 = zeros(N,1);  t2(i0) = gammaln(g0) - g0.*log(g0) + g0;
0628 gp = gx(i0);            % gx at the non-zero components of gx0
0629 t3 = zeros(N,1);  t3(i0) = g0.*(psi(gp) - log(gp));
0630 y = t1 - t2 + t3;

Generated on Mon 22-May-2023 06:53:56 by m2html © 2005