function [Model,Info] = ...
    vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm, Modelinit)
% VBMEG_ARD_ESTIMATE25_3  Windowed variational-Bayes estimation of source
% currents with a multivariate autoregressive (MAR) source model.
%
%   [Model, Info] = vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm)
%   [Model, Info] = vbmeg_ard_estimate25_3(B, Gall, Gact, COV, vb_parm, Modelinit)
%
% For every time window in vb_parm.Twindow the routine runs Ntrain+1 passes.
% Each pass
%   1. estimates the current J of every session/trial with a filtering
%      pass (current_estimation),
%   2. accumulates the whitened residual of all sessions and the MAR
%      sufficient statistics (NOTE: the MAR statistics use session 1 only),
%   3. updates the noise scale sx (unless vb_parm.update_sx is false),
%   4. except on the last pass, updates the MAR coefficients (mar, Vmar),
%      their prior variances (ieta) and the current variance ax_z,
%   5. evaluates the free energy and, for passes k > 1, stores it in Info.
%
% Inputs
%   B         : cell(Nsession,1); B{ns} is indexed as (sensor, time, trial).
%   Gall      : unused here; kept for interface compatibility.
%   Gact      : cell(Nsession,1); Gact{ns} multiplies the Nvact-row current
%               (lead field of the active vertices).  Must be non-empty.
%   COV       : cell(Nsession,1); inv(COV{ns}) weights the residual.
%   vb_parm   : parameter struct.  Fields read here: Ntrain, Nskip, a_min,
%               Ntrials, Nsensors, Nsession, Nwindow, Twindow, Nvact, Ind,
%               Delta, Ta0, sx0, a0, a, cont_pr, ieta0, g0 and the optional
%               update_sx (default 1).
%   Modelinit : optional struct; Modelinit.a (indexed as (:, window))
%               replaces the prior current variance a0 for each window.
%
% Outputs
%   Model : struct with fields
%             a, sx      - per-window current variance / noise scale
%             v          - zeros(Nwindow,1)
%             Z          - cell(Nsession,Nwindow) estimated currents
%             Vz         - covariance blocks of the LAST window, with
%                          VjT_full appended to each session's cell
%             vz_Nsample - sample counts [vfreq, Nt-sum(vfreq)] of the
%                          last window, matching the Vz blocks
%             mar, MAR, MAR0, Vmar, ieta - MAR quantities of the last window
%   Info  : (Nwindow*Ntrain) x 8, one row per pass k > 1:
%           [FE, Hj(1), Hj(2), Hbc, Hmar, Hieta, Haz, Err].
%
% Global
%   flag_pinv : forwarded to the Kalman-gain helpers (defaults to false).

global flag_pinv;
if isempty(flag_pinv), flag_pinv = false; end

Ntrain = vb_parm.Ntrain;
Nskip  = vb_parm.Nskip;
a_min  = vb_parm.a_min;

fprintf('\n--- New Dynamic estimation program ---\n\n')
fprintf('--- Total update iteration = %d\n',Ntrain)

Ntrials      = vb_parm.Ntrials;
Nsensors     = vb_parm.Nsensors;
Nsession     = vb_parm.Nsession;
Nwindow      = vb_parm.Nwindow;
Twindow      = vb_parm.Twindow;
Ntotaltrials = sum(Ntrials);
Nerror       = sum(Ntrials.*Nsensors);

Nvact = vb_parm.Nvact;

Ind   = vb_parm.Ind;
Delta = vb_parm.Delta;

% Prior hyper-parameters; the prior current variance is floored at a_min.
Ta0 = vb_parm.Ta0;
sx0 = vb_parm.sx0;
a0  = max(vb_parm.a0, a_min);

if exist('Modelinit','var') && ~isempty(Modelinit)
    a0s = max(Modelinit.a, a_min);
else
    a0s = repmat(a0, [1 Nwindow]);
end

if isempty(Gact)
    error('No leadfield is given')
end

update_sx = 1;
if isfield(vb_parm, 'update_sx')
    update_sx = vb_parm.update_sx;
end

% Cut the data into windows and accumulate sum_trials(B*B') per window.
Ball  = cell(Nsession, Nwindow);
BBall = cell(Nsession, Nwindow);
for ns = 1 : Nsession
    for nw = 1 : Nwindow
        tw  = Twindow(nw,1):Twindow(nw,2);
        Btr = B{ns}(:,tw,:);
        B0  = 0;
        for i = 1 : Ntrials(ns)
            Bt = B{ns}(:,tw,i);
            B0 = B0 + Bt * Bt';
        end
        Ball{ns,nw}  = Btr;
        BBall{ns,nw} = B0;
    end
end

% inv(COV{ns}) does not change inside the loops below; compute it once.
invCOV = cell(Nsession,1);
for ns = 1:Nsession
    invCOV{ns} = inv(COV{ns});
end

Jall     = cell(Nsession, Nwindow);
J        = cell(Nsession,1);
J0       = cell(Nsession,1);
Vj       = cell(Nsession,1);
VjT      = cell(Nsession,1);
VjT_full = cell(Nsession,1);

aall  = zeros(Nvact,Nwindow);
sxall = zeros(Nwindow,1);

Info    = zeros(Nwindow*Ntrain, 8);
k_check = 0;

tic;

% Lag bookkeeping depends only on Delta and Nvact, so it is the same for
% every window and is built once.
[Delta_set, Ndelta, indx, deltax, sum_ind, nl_indx, NL_indx] = ...
    init_index(Delta);
AtQA0      = init_AtQA(Nvact, Ndelta, NL_indx);
AdtQA_V0   = init_AdtQA_V(Nvact, Ndelta, NL_indx);
sp_indx    = get_AtQA_index(Nvact, Ndelta, NL_indx, AtQA0);
sp_dl_indx = get_AdtQA_V_index(Nvact, Ndelta, NL_indx, AdtQA_V0);

for nw = 1:Nwindow

    Nt  = Twindow(nw,2) - Twindow(nw,1) + 1;
    Ntt = Nt*Ntotaltrials;

    % Gamma hyper-parameters of the current variance.
    gx0_z = Ta0;
    gx_z  = 0.5 * Ntt + gx0_z;

    ax0_z = a0s(:,nw);

    if vb_parm.cont_pr == OFF | nw == 1
        ax_z = vb_parm.a;
        sx   = sx0;
    else
        % Continue from the previous window; entries that fell below a_min
        % are reset to the prior value.
        ix_zero = find(ax_z < a_min);
        ax_z(ix_zero) = ax0_z(ix_zero);
    end

    % MAR coefficients (per vertex) and their posterior covariance start at 0.
    mar  = cell(Nvact,1);
    MAR  = sparse(Nvact, Nvact);
    Vmar = cell(Nvact,1);
    for nv = 1:Nvact
        mar{nv}  = zeros(sum_ind(nv),1);
        Vmar{nv} = zeros(sum_ind(nv));
    end
    % MAR0 is overwritten on the first update pass; this keeps it defined
    % when Ntrain == 0 and no update pass happens.
    MAR0 = MAR;

    % Prior variances of the MAR coefficients and their Gamma shapes.
    ieta0 = cell(Nvact,1);
    gx0_e = cell(Nvact,1);
    gx_e  = cell(Nvact,1);
    for nv = 1:Nvact
        ieta0{nv} = vb_parm.ieta0 * ones(sum_ind(nv),1);
        gx0_e{nv} = vb_parm.g0 * ones(sum_ind(nv),1);
        gx_e{nv}  = gx0_e{nv} + 0.5;
    end
    ieta = ieta0;

    % vfreq(nl): number of samples handled with covariance block Vj{nl};
    % vct(nld,nl): counts used to assemble the lagged covariance Vj1.
    vfreq     = fliplr(-diff([(Nt-Delta_set)' 0]));
    cum_vfreq = [0 cumsum(vfreq)];
    tsp = repmat(1:Nt,size(Delta_set)) - repmat(Delta_set,[1 Nt]);
    vlb = zeros(size(tsp));
    vct = [];
    for nl = 1:Ndelta
        vlb = vlb + nl*(cum_vfreq(nl)<tsp).*(tsp<=cum_vfreq(nl+1));
    end
    for nl = 1:Ndelta
        vct = [vct sum(vlb==nl,2)];
    end

    for k = 1:(Ntrain+1)

        tr1_b    = 0;   % whitened residual, summed over sessions
        bb_total = 0;   % normaliser for Err, summed over the same sessions
        Hj = [0 0];

        for ns = 1:Nsession
            Ntry    = Ntrials(ns);
            Nsensor = Nsensors(ns);

            G  = Gact{ns};
            Gt = G';

            Bw  = Ball{ns,nw};
            Cov = COV{ns};

            % Initial lagged state: zeros for the first window, otherwise the
            % last Delta_set(end) samples of the previous window.  The previous
            % window's own length is used, since windows may differ in length.
            if nw == 1
                J0{ns} = zeros(Nvact, Delta_set(end), Ntry);
            else
                Jprev  = Jall{ns,nw-1};
                Nprev  = size(Jprev,2);
                J0{ns} = Jprev(:, Nprev-Delta_set(end)+1:Nprev, :);
            end

            if k == 1
                prevJ = zeros(Nvact, Nt, Ntry);
            else
                prevJ = J{ns};
            end

            [J{ns}, Vj{ns}, VjT{ns}, GJ, Qj, VjT_full{ns}] = current_estimation ...
                (Bw, G, Gt, ax_z, Cov, J0{ns}, prevJ, mar, Vmar, ...
                Nvact, Nsensor, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
                nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx);

            BGJ   = reshape(Bw-GJ, Nsensor, Nt*Ntry);
            tr1_b = tr1_b + sum(repdiag( BGJ, invCOV{ns}*BGJ));
            % NOTE(review): sum(sum(.)) adds every entry of sum(B*B'), not
            % its trace; kept as the original definition of Err.
            bb_total = bb_total + sum(sum(BBall{ns,nw}));

            Hj = Hj + 0.5*Qj;
        end

        % ---- MAR sufficient statistics (session 1 only) ----
        sumJ11 = zeros(Nvact,1);
        sumVj1 = zeros(Nvact,1);
        sumJ12 = zeros(Nvact,1);
        MARJ1  = zeros(Nvact,Nt*Ntrials(1));

        J22    = reshape(J{1}.^2, Nvact, Nt*Ntrials(1));
        sumJ22 = sum(J22,2);
        sumVj2 = Ntrials(1)*(Nt-sum(vfreq))*VjT{1};
        for nl = 1:Ndelta
            sumVj2 = sumVj2 + Ntrials(1)*vfreq(nl)*diag(Vj{1}{nl});
        end

        J2 = reshape(J{1}, Nvact, Nt*Ntrials(1));

        % Per trial, prepend the lagged state J0 so past samples can be indexed.
        Jtmp  = [];
        t_all = [];
        for ntr = 1:Ntrials(1)
            Jtmp  = [Jtmp J0{1}(:,:,ntr) J{1}(:,:,ntr)];
            t_all = [t_all (1:Nt)+Delta_set(end) + (Nt+Delta_set(end))*(ntr-1)];
        end

        J1   = cell(Nvact,1);
        J1J1 = cell(Nvact,1);
        Vj1  = cell(Nvact,1);
        if k==1; mar_Vmar = cell(Nvact,1); end
        for nv = 1:Nvact
            % Lagged regressors of vertex nv and their outer product.
            J1{nv} = select_entries(Jtmp, repmat(indx{nv},[1 Nt*Ntrials(1)]), ...
                repmat(t_all,size(deltax{nv}))-repmat(deltax{nv},[1 Nt*Ntrials(1)]));
            J1J1{nv} = J1{nv}*J1{nv}';

            % Posterior-covariance contribution of the lagged regressors.
            Nind_nv = length(deltax{nv});
            Vj1{nv} = zeros(Nind_nv);
            for nld = 1:Ndelta
                indlb     = (deltax{nv}==Delta_set(nld));
                sum_indlb = sum(indlb);
                if sum_indlb~=0
                    Vjtmp = zeros(sum_indlb);
                    nl = 1;
                    while nl <= Ndelta-nld+1
                        Vjtmp = Vjtmp + ...
                            vct(nld,nl)*Vj{1}{nl}(indx{nv}(indlb),indx{nv}(indlb));
                        nl = nl + 1;
                    end
                    Vj1{nv}(indlb,indlb) = Vj1{nv}(indlb,indlb) + Ntrials(1)*Vjtmp;
                end
            end

            if k==1; mar_Vmar{nv} = mar{nv}*mar{nv}' + Vmar{nv}; end
            MARJ1(nv,:) = mar{nv}'*J1{nv};
            sumJ11(nv)  = sumJ11(nv) + sum(repdiag(mar_Vmar{nv},J1J1{nv}));
            sumVj1(nv)  = sumVj1(nv) + Ntrials(1)*sum(repdiag(mar_Vmar{nv},Vj1{nv}));
        end
        sumJ12 = sumJ12 + repdiag(J2',(MARJ1)');

        mean_zz = sumJ22 - 2*sumJ12 + sumJ11;
        mean_Vj = sumVj2 + sumVj1;

        % ---- Noise-scale update ----
        sx_total = tr1_b + sum(mean_zz./ax_z);
        if update_sx,
            sx = sx_total/(Nt*Nerror);
        end

        Hj(1) = Hj(1) - 0.5*tr1_b/sx;
        Hj(2) = Hj(2) - 0.5*sum(mean_zz./ax_z)/sx - 0.5*sum(mean_Vj./ax_z);

        if k==(Ntrain+1)
            % Last pass: only evaluate the free energy, no parameter update.
            mar_old    = mar;
            Vmar_old   = Vmar;
            ieta_old   = ieta;
            ax_z_old   = ax_z;
        else
            mar_old  = mar;
            Vmar_old = Vmar;

            % ---- MAR coefficient update ----
            for nv = 1:Nvact
                ax_z_invVmar = (J1J1{nv}/sx + Vj1{nv}) + ax_z(nv)*diag(1./ieta{nv});
                J1J2_beta    = (J1{nv}*J2(nv,:)'/sx);

                inv_ax_z_invVmar = inv(ax_z_invVmar);
                Vmar{nv} = ax_z(nv)*inv_ax_z_invVmar;
                mar{nv}  = inv_ax_z_invVmar*J1J2_beta;
                MAR(nv,indx{nv}) = mar{nv}';
            end

            if k == 1; MAR0 = MAR; end

            % ---- Prior-variance (ieta) update ----
            ieta_old = ieta;
            for nv = 1:Nvact
                ieta{nv} = (gx0_e{nv}.*ieta0{nv} + ...
                    0.5*(mar{nv}.^2 + diag(Vmar{nv})))./gx_e{nv};
            end

            % Recompute the MAR statistics with the updated coefficients.
            sumJ11 = zeros(Nvact,1);
            sumVj1 = zeros(Nvact,1);
            sumJ12 = zeros(Nvact,1);
            MARJ1  = zeros(Nvact,Nt*Ntrials(1));

            mar_Vmar = cell(Nvact,1);
            for nv = 1:Nvact
                mar_Vmar{nv} = mar{nv}*mar{nv}' + Vmar{nv};
                MARJ1(nv,:)  = mar{nv}'*J1{nv};
                sumJ11(nv)   = sumJ11(nv) + sum(repdiag(mar_Vmar{nv},J1J1{nv}));
                sumVj1(nv)   = sumVj1(nv) + Ntrials(1)*sum(repdiag(mar_Vmar{nv},Vj1{nv}));
            end
            sumJ12 = sumJ12 + repdiag(J2',(MARJ1)');

            mean_zz = sumJ22 - 2*sumJ12 + sumJ11;
            mean_Vj = sumVj2 + sumVj1;

            % ---- Current-variance update ----
            ax_z_old = ax_z;
            ax_z = (gx0_z.*ax0_z + 0.5*(mean_zz/sx + mean_Vj))./gx_z;
        end

        % ---- Free energy (evaluated with the pre-update quantities) ----
        rz  = ax0_z./ax_z_old;
        Haz = sum(gx0_z.*(log(rz) - rz + 1));

        Hmar  = 0;
        Hieta = 0;
        mar_vmar_ieta_old = zeros(Nvact,1);
        for nv = 1:Nvact
            if ~isempty(indx{nv})
                Hmar = Hmar - sum(log(ieta_old{nv})) + vb_log_det(Vmar_old{nv});
                mar_vmar_ieta_old(nv) = ...
                    sum((mar_old{nv}.^2+diag(Vmar_old{nv}))./ieta_old{nv});
            end
            re    = ieta0{nv}./ieta_old{nv};
            Hieta = Hieta + sum(gx0_e{nv}.*(log(re) - re + 1));
        end
        Hmar = 0.5*(Hmar-sum(mar_vmar_ieta_old));

        Hj(1) = Hj(1) - 0.5*Nt*Nerror*log(sx);

        [L1c L2c Hbc HAc Hec Hqc difbeta] = free_energy_constant...
            (gx_z, gx0_z, gx_e, gx0_e, Nt, Ntotaltrials, Nerror, Nvact, Ind);

        Hj(1) = Hj(1) + L1c;
        Hj(2) = Hj(2) + L2c + difbeta;
        Hmar  = Hmar + HAc;
        Hieta = Hieta + Hec;
        Haz   = Haz + Hqc;

        FE = sum(Hj) + Hbc + Hmar + Hieta + Haz;

        % Residual relative to the data of every session that contributed
        % to tr1_b (previously only the last session's data was used).
        Err = tr1_b/bb_total;

        if k > 1
            k_check = k_check + 1;
            Info(k_check,:) = [FE, Hj(1), Hj(2), Hbc, Hmar, Hieta, Haz, Err];

            if mod(k-1, Nskip) == 0
                fprintf('Tn=%3d, Iter=%4d, FE=%f, Error=%e\n', nw, k-1, FE, Err);
            end
        end
    end

    for ns = 1:Nsession
        Jall{ns,nw} = J{ns};
    end
    aall(:,nw) = ax_z;
    sxall(nw)  = sx;
end
toc;

Model.a  = aall ;
Model.sx = sxall;
Model.v  = zeros(Nwindow,1);
Model.Z  = Jall;
Model.Vz = Vj;
clear Vj
for ns = 1:Nsession
    Model.Vz{ns}{end+1} = VjT_full{ns};
end

Model.vz_Nsample = [vfreq (Nt-sum(vfreq))];
Model.mar  = mar;
Model.MAR  = MAR;
Model.MAR0 = MAR0;
for nv = 1:Nvact
    Model.Vmar{nv} = Vmar{nv};
end
Model.ieta = ieta;
0421
0422
0423
function [J, Vj, VjT, GJ, Qj, VjT_full] = current_estimation ...
    (B, G, Gt, ax_z, Cov, J0, prevJ, mar, Vmar, ...
    Nvact, Nsensor, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
    nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx)
% CURRENT_ESTIMATION  Filter the current J and project it to the sensors.
%
% All estimation work (J, the covariance blocks Vj / VjT / VjT_full and the
% log-likelihood terms Qj) is delegated to forward_pass.  This wrapper only
% adds GJ = G*J, returned with shape [Nsensor, Nt, Ntry].

[J, Vj, VjT, Qj, VjT_full] = forward_pass(B, G, Gt, ax_z, Cov, J0, prevJ, ...
    mar, Vmar, Nvact, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, ...
    sum_ind, nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx);

% Stack all trials side by side, map through the lead field, then restore
% the (sensor, time, trial) layout.
Jstacked = reshape(J, Nvact, Nt*Ntry);
GJ = reshape(G*Jstacked, Nsensor, Nt, Ntry);
0438
0439
function [Jfilt, Vj, VjT, Qj, VjT_full] = forward_pass ...
    (B, G, Gt, ax_z, Cov, J0, prevJ, mar, Vmar, ...
    Nvact, Nt, Ntry, flag_pinv, indx, deltax, Delta_set, sum_ind, ...
    nl_indx, NL_indx, AtQA0, AdtQA_V0, sp_indx, sp_dl_indx)
% FORWARD_PASS  One filtering sweep over time for every trial.
%
% For t = 1..Nt-1 the predicted current combines
%   * the MAR prediction from past samples (atimesj on J0 / Jfilt), and
%   * the influence of later samples, taken from prevJ (the estimate of the
%     previous iteration), through the lag matrices A{nl},
% and is corrected with the gain Ka built from
%   P = inv(diag(1./ax_z) + sum of AtQA{nl} over the lags that still fit).
% The last sample t = Nt uses only the MAR prediction and the gain KaT.
%
% Delta_set is treated as a column of lags with Delta_set(end) the largest
% (assumed sorted ascending - TODO confirm against init_index).
%
% Outputs
%   Jfilt    : Nvact x Nt x Ntry filtered current.
%   Vj       : cell(Ndelta,1); Vj{ic} is the posterior covariance used when
%              ic = Ndelta - Nbound + 1, Nbound = #{lags d : t + d <= Nt}.
%              Blocks that are never reached stay zeros(Nvact).
%   VjT_full : posterior covariance used at t = Nt; VjT is its diagonal.
%   Qj       : 1x2; Qj(1) collects the Cov-dependent terms, Qj(2) the
%              ax_z / Vj-dependent terms (both scaled by 2*Ntry).

Jfilt  = zeros(Nvact, Nt, Ntry);
Ndelta = length(Delta_set);

% Sparse diag(1./ax_z); loop invariant, reused for AtQ and every prediction.
iQ = spdiags(1./ax_z, 0, Nvact, Nvact);

AtQ = cell(Ndelta,1);
A   = cell(Ndelta,1);
Vj  = cell(Ndelta,1);
for nl = 1:Ndelta
    AtQ{nl} = spalloc(Nvact, Nvact, Nvact);
    A{nl}   = spalloc(Nvact, Nvact, Nvact);
    Vj{nl}  = zeros(Nvact);
end

% Scatter the per-vertex MAR coefficients into one sparse matrix per lag.
for nl = 1:Ndelta
    for nv = 1:Nvact
        if ~isempty(nl_indx{nv,nl})
            A{nl}(nv,NL_indx{nv,nl}) = mar{nv}(nl_indx{nv,nl});
        end
    end
    AtQ{nl} = (iQ*A{nl})';
end
AtQA = set_AtQA(Nvact, Ndelta, ax_z, mar, Vmar, ...
    nl_indx, NL_indx, sp_indx, AtQA0);
AdtQA_V = set_AdtQA_V(Nvact, Ndelta, ax_z, mar, Vmar, ...
    nl_indx, NL_indx, sp_dl_indx, AdtQA_V0);

% Prior covariance P, gain Ka and posterior covariance Vj only change when
% the number of lags that fit inside the window (Nbound) changes.
P  = cell(Ndelta+1,1);
Ka = cell(Ndelta+1,1);
Nbound = Inf;
for t = 1:Nt-1
    nb = sum((t + Delta_set) <= Nt);
    if nb ~= Nbound
        Nbound = nb;
        ic = Ndelta - Nbound + 1;

        sumAtQA = spalloc(Nvact, Nvact, Nvact);
        for nl = 1:Nbound
            sumAtQA = sumAtQA + AtQA{nl};
        end
        QplusAtQA = diag(1./ax_z) + sumAtQA;
        P{ic}  = inv(QplusAtQA);
        Ka{ic} = calc_kalman_gain2(G, Gt, P{ic}, Cov, flag_pinv);

        if Nbound ~= 0
            Vj{ic} = P{ic} - Ka{ic}*(G*P{ic});
        end
    end
end
PT  = ax_z;
KaT = calc_kalman_gain(G, Gt, PT, Cov, flag_pinv);

VjT_full = diag(PT) - KaT*(G*diag(PT));
VjT      = diag(VjT_full);

% Zero-padded matrix form of indx / deltax / mar for atimesj.
Indx   = zeros(Nvact, max(sum_ind));
Deltax = zeros(Nvact, max(sum_ind));
MARt   = zeros(Nvact, max(sum_ind));
for nv = 1:Nvact
    Indx(nv,1:size(indx{nv},1))     = indx{nv}';
    Deltax(nv,1:size(deltax{nv},1)) = deltax{nv}';
    MARt(nv,1:size(mar{nv},1))      = mar{nv}';
end

for ntr = 1:Ntry
    for t = 1:Nt-1
        t_end  = t + Delta_set(end);
        Nbound = sum((t + Delta_set) <= Nt);
        ic     = Ndelta - Nbound + 1;

        % Lagged state | already-filtered samples | previous-iteration future.
        Jtmp = [J0(:,:,ntr) Jfilt(:,1:t,ntr) prevJ(:,t+1:end,ntr)];

        AJfilt = ...
            atimesj(MARt,Indx,Deltax,sum_ind,Jtmp(:,t_end:-1:(t_end-Delta_set(end))));

        % Contribution of the future samples t + Delta_set(nld).
        saqj = zeros(Nvact,1);
        for nld = 1:Nbound
            saj       = zeros(Nvact,1);
            adqsa_v_j = zeros(Nvact,1);
            for nl = 1:Ndelta
                if nl ~= nld
                    t_tmp     = t_end + Delta_set(nld) - Delta_set(nl);
                    saj       = saj + A{nl}*Jtmp(:,t_tmp);
                    adqsa_v_j = adqsa_v_j + AdtQA_V{nld,nl}*Jtmp(:,t_tmp);
                end
            end
            saqj = saqj + AtQ{nld}* ...
                (Jtmp(:,t_end + Delta_set(nld)) - saj) - adqsa_v_j;
        end

        Jpred = P{ic}*(saqj + iQ*AJfilt);
        Jfilt(:,t,ntr) = Jpred + Ka{ic}*(B(:,t,ntr) - G*Jpred);
    end

    % Last sample: no future samples, plain MAR prediction + update.
    Nt_end = Nt + Delta_set(end);
    Jtmp   = [J0(:,:,ntr) Jfilt(:,:,ntr)];

    AJfilt = ...
        atimesj(MARt,Indx,Deltax,sum_ind,Jtmp(:,Nt_end:-1:(Nt_end-Delta_set(end))));

    Jpred = AJfilt;
    Jfilt(:,Nt,ntr) = Jpred + KaT*(B(:,Nt,ntr) - G*Jpred);
end

% Expected log-likelihood terms; vfreq(nl) counts samples using Vj{nl}.
vfreq  = fliplr(-diff([(Nt-Delta_set)' 0]));
invCov = inv(Cov);

sumlogVj     = 0;
sumTrGVjGtiC = 0;
for nl = 1:Ndelta
    if vfreq(nl) == 0
        % No sample uses this block (e.g. Nt == Delta_set(end) for nl = 1),
        % so Vj{nl} is still the zeros placeholder; its log-det would turn
        % the zero-weighted term into NaN.
        continue;
    end
    sumlogVj     = sumlogVj + vfreq(nl)/2*vb_log_det(Vj{nl});
    sumTrGVjGtiC = sumTrGVjGtiC + vfreq(nl)/2*sum(repdiag(G*Vj{nl}*Gt,invCov));
end

Qj(1) = 2*Ntry*(-Nt/2*vb_log_det(Cov)...
    - sumTrGVjGtiC...
    - (Nt-sum(vfreq))/2*sum(repdiag(G*VjT_full*Gt,invCov)));

Qj(2) = 2*Ntry*(...
    - Nt/2*sum(log(ax_z))...
    + sumlogVj...
    + (Nt-sum(vfreq))/2*vb_log_det(VjT_full));
0581
0582
0583
0584
function [L1c L2c Hbc HAc Hec Hqc difbeta] = free_energy_constant...
    (gx_z, gx0_z, gx_e, gx0_e, Nt, Ntotaltrials, Nerror, Nvact, Ind)
% FREE_ENERGY_CONSTANT  Terms of the free energy that depend only on the
% Gamma shape parameters and the problem sizes (not on the estimates).
%
%   gx_z, gx0_z : posterior / prior shapes of the current variance.
%   gx_e, gx0_e : cell(Nvact,1) posterior / prior shapes of ieta.
%   Ind         : only sum(Ind(:)) is used.
%
% The noise-scale shape is gx_sx = 0.5*Nt*Nerror; qG(gx_sx) is shared by
% L1c and difbeta.

gx_sx = 0.5*Nt*Nerror;
q_sx  = qG(gx_sx);

% Data-fit (L1c) and current-prior (L2c) constants.
L1c = 0.5*Nt*(Nerror*q_sx);
L2c = 0.5*Nt*(Ntotaltrials*sum(qG(gx_z)));

% Gamma-shape terms for the noise scale and the current variance.
Hbc = HG0(gx_sx,0);
Hqc = sum(HG0(gx_z, gx0_z));

% Per-vertex MAR prior terms.
HAc = 0;
Hec = 0;
for iv = 1:Nvact
    HAc = HAc + 0.5*sum(qG(gx_e{iv}));
    Hec = Hec + sum(HG0(gx_e{iv}, gx0_e{iv}));
end

% Size-only offsets.
L1c = L1c + 0.5*(-Nt*Nerror*log(2*pi));
L2c = L2c + 0.5*(Nvact*Nt*Ntotaltrials);
HAc = HAc + 0.5*(sum(Ind(:)));

difbeta = 0.5*Nt*Nvact*q_sx;
0603
0604
function y = qG(gx)
% QG  Element-wise psi(gx) - log(gx) as a column of length(gx) entries.
%     Entries where gx is zero are left at 0 instead of being evaluated.
keep = find(gx ~= 0);
g    = gx(keep);
y    = zeros(length(gx), 1);
y(keep) = psi(g) - log(g);
0611
0612
function y = HG0(gx, gx0)
% HG0  Element-wise Gamma-shape term used by free_energy_constant:
%
%   y = [gammaln(gx) - gx.*psi(gx) + gx]
%     - [gammaln(gx0) - gx0.*log(gx0) + gx0]
%     + gx0.*(psi(gx) - log(gx))
%
% returned as a column of length(gx0).  The first bracket is only evaluated
% where gx ~= 0; the other two only where gx0 ~= 0 (elsewhere they are 0).
% NOTE(review): presumably the entropy/KL bookkeeping for Gamma posteriors
% with shape gx against priors with shape gx0 - confirm.
n = length(gx0);

% Part that depends on the posterior shape only.
ip = find(gx ~= 0);
gp = gx(ip);
post_part = zeros(n, 1);
post_part(ip) = gammaln(gp) - gp.*psi(gp) + gp;

% Parts that exist only where the prior shape is non-zero.
iq = find(gx0 ~= 0);
g0 = gx0(iq);
gq = gx(iq);

prior_part = zeros(n, 1);
prior_part(iq) = gammaln(g0) - g0.*log(g0) + g0;

mixed_part = zeros(n, 1);
mixed_part(iq) = g0.*(psi(gq) - log(gq));

y = post_part - prior_part + mixed_part;