mr_hessian.m: refined algorithm that calibrates gradient step according to target variation in objective function

covariance-quadratic-approximation
Marco Ratto 2023-12-13 19:13:46 +01:00 committed by Johannes Pfeifer
parent e1e79d3177
commit 3931451250
1 changed file with 55 additions and 22 deletions

View File

@ -24,7 +24,7 @@ function [hessian_mat, gg, htol1, ihh, hh_mat0, hh1, hess_info] = mr_hessian(x,f
% derivatives
% - hess_info structure storing the step sizes for
% computation of Hessian
% - bounds prior bounds of parameters
% - bounds prior bounds of parameters
% - prior_std prior standard deviation of parameters (can be NaN)
% - Save_files indicator whether files should be saved
% - varargin other inputs
@ -78,7 +78,7 @@ else
end
hess_info.h1 = min(hess_info.h1,0.5.*hmax);
hess_info.h1 = min(hess_info.h1,0.9.*hmax);
if htol0<hess_info.htol
hess_info.htol=htol0;
@ -106,22 +106,39 @@ while i<n
ic=0;
icount = 0;
h0=hess_info.h1(i);
while (abs(dx(it))<0.5*hess_info.htol || abs(dx(it))>(3*hess_info.htol)) && icount<10 && ic==0
istoobig=false;
istoosmall=false;
give_it_up=false;
while (abs(dx(it))<0.5*hess_info.htol || abs(dx(it))>(3*hess_info.htol)) && icount<10 && ic==0 && hmax(i)>2*1.e-10 && not(give_it_up)
icount=icount+1;
istoobig(it)=false;
istoosmall(it)=false;
if abs(dx(it))<0.5*hess_info.htol
if abs(dx(it)) ~= 0
hess_info.h1(i)=min(max(1.e-10,0.3*abs(x(i))), 0.9*hess_info.htol/abs(dx(it))*hess_info.h1(i));
else
hess_info.h1(i)=2.1*hess_info.h1(i);
istoosmall(it)=true;
if hess_info.h1(i)==0.9*hmax(i)% || hess_info.h1(i)==0.3*abs(x(i))
give_it_up=true;
end
hess_info.h1(i) = min(hess_info.h1(i),0.5*hmax(i));
hess_info.h1(i) = max(hess_info.h1(i),1.e-10);
if abs(dx(it)) ~= 0
% htmp=min(max(1.e-10,0.3*abs(x(i))), 0.9*hess_info.htol/abs(dx(it))*hess_info.h1(i));
htmp=0.9*hess_info.htol/abs(dx(it))*hess_info.h1(i);
else
htmp=2.1*hess_info.h1(i);
end
htmp = min(htmp,0.9*hmax(i));
if any(h0(istoobig(1:it))) %&& htmp>=min(h0(istoobig(1:it)))
htmp = 0.5*(min(h0(istoobig))+max(h0(istoosmall)));
end
hess_info.h1(i) = max(htmp,1.e-10);
xh1(i)=x(i)+hess_info.h1(i);
[fx,~,ffx]=penalty_objective_function(xh1,func,penalty,varargin{:});
end
if abs(dx(it))>(3*hess_info.htol)
hess_info.h1(i)= hess_info.htol/abs(dx(it))*hess_info.h1(i);
hess_info.h1(i) = max(hess_info.h1(i),1e-10);
istoobig(it)=true;
htmp= hess_info.htol/abs(dx(it))*hess_info.h1(i);
if any(h0(istoosmall(1:it))) %&& htmp<=max(h0(istoosmall(1:it)))
htmp = 0.5*(min(h0(istoobig))+max(h0(istoosmall)));
end
hess_info.h1(i) = max(htmp,1e-10);
xh1(i)=x(i)+hess_info.h1(i);
[fx,~,ffx]=penalty_objective_function(xh1,func,penalty,varargin{:});
iter=0;
@ -136,11 +153,27 @@ while i<n
it=it+1;
dx(it)=(fx-f0);
h0(it)=hess_info.h1(i);
if (hess_info.h1(i)<1.e-12*min(1,h2(i)) && hess_info.h1(i)<0.5*hmax(i))
if (hess_info.h1(i)<1.e-12*min(1,h2(i)) && hess_info.h1(i)<0.9*hmax(i))
ic=1;
hcheck=1;
end
end
if icount == 10 || hess_info.h1(i)==1.e-10
istoobig(it)=false;
istoosmall(it)=false;
if abs(dx(it))<0.5*hess_info.htol
istoosmall(it)=true;
end
if abs(dx(it))>(3*hess_info.htol)
istoobig(it)=true;
end
if any(istoobig) && (istoosmall(it) || istoobig(it))
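% if the target variation could not be matched, prefer a step that overshoots it (too big) rather than one that undershoots it (too small), to avoid numerical noise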
[ddx, ij]=min(dx(istoobig));
fx=f0+ddx;
hess_info.h1(i)=h0(ij);
end
end
f1(:,i)=fx;
if outer_product_gradient
if any(isnan(ffx)) || isempty(ffx)
@ -207,8 +240,8 @@ if outer_product_gradient
dum = (f1(:,i)+f_1(:,i)-2*f0)./(hess_info.h1(i)*h_1(i));
hessian_mat(:,(i-1)*n+i)=dum;
if any(dum<=eps)
hessian_mat(dum<=eps,(i-1)*n+i)=max(eps, gg(i)^2);
end
hessian_mat(dum<=eps,(i-1)*n+i)=max(eps, gg(i)^2);
end
end
end
@ -216,7 +249,7 @@ if outer_product_gradient
hh_mat=gga'*gga; % rescaled outer product hessian
hh_mat0=ggh'*ggh; % outer product hessian
A=diag(2.*hess_info.h1); % rescaling matrix
% igg=inv(hh_mat); % inverted rescaled outer product hessian
% igg=inv(hh_mat); % inverted rescaled outer product hessian
ihh=A'*(hh_mat\A); % inverted outer product hessian (based on rescaling)
if hflag>0 && min(eig(reshape(hessian_mat,n,n)))>0
hh0 = A*reshape(hessian_mat,n,n)*A'; %rescaled second order derivatives
@ -224,7 +257,7 @@ if outer_product_gradient
sd0=sqrt(diag(hh0)); %rescaled 'standard errors' using second order derivatives
sd=sqrt(diag(hh_mat)); %rescaled 'standard errors' using outer product
hh_mat=hh_mat./(sd*sd').*(sd0*sd0'); %rescaled inverse outer product with 'true' std's
ihh=A'*(hh_mat\A); % update inverted outer product hessian with 'true' std's
ihh=A'*(hh_mat\A); % update inverted outer product hessian with 'true' std's
sd=sqrt(diag(ihh)); %standard errors
sdh=sqrt(1./diag(hh)); %diagonal standard errors
for j=1:length(sd)
@ -238,12 +271,12 @@ if outer_product_gradient
igg=inv_A'*ihh*inv_A; % inverted rescaled outer product hessian with modified std's
% hh_mat=inv(igg); % outer product rescaled hessian with modified std's
hh_mat0=inv_A'/igg*inv_A; % outer product hessian with modified std's
% sd0=sqrt(1./diag(hh0)); %rescaled 'standard errors' using second order derivatives
% sd=sqrt(diag(igg)); %rescaled 'standard errors' using outer product
% igg=igg./(sd*sd').*(sd0*sd0'); %rescaled inverse outer product with 'true' std's
% hh_mat=inv(igg); % rescaled outer product hessian with 'true' std's
% ihh=A'*igg*A; % inverted outer product hessian
% hh_mat0=inv(A)'*hh_mat*inv(A); % outer product hessian with 'true' std's
% sd0=sqrt(1./diag(hh0)); %rescaled 'standard errors' using second order derivatives
% sd=sqrt(diag(igg)); %rescaled 'standard errors' using outer product
% igg=igg./(sd*sd').*(sd0*sd0'); %rescaled inverse outer product with 'true' std's
% hh_mat=inv(igg); % rescaled outer product hessian with 'true' std's
% ihh=A'*igg*A; % inverted outer product hessian
% hh_mat0=inv(A)'*hh_mat*inv(A); % outer product hessian with 'true' std's
end
if hflag<2
hessian_mat=hh_mat0(:);