OSR: allow using analytic gradient

kalman-mex
Johannes Pfeifer 2023-09-11 17:17:23 +02:00 committed by Sébastien Villemot
parent 6037b9f096
commit 885fda0e20
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
7 changed files with 161 additions and 21 deletions

View File

@ -11399,6 +11399,14 @@ Optimal Simple Rules (OSR)
future release of Dynare. Use ``optim`` instead to set
optimizer-specific values. Default: ``1e-7``.
.. option:: analytic_derivation
Triggers estimation with analytic gradient of the objective function.
.. option:: analytic_derivation_mode = INTEGER
See :opt:analytic_derivation_mode.
.. option:: silent_optimizer
See :opt:`silent_optimizer`.

View File

@ -1,5 +1,6 @@
function [loss,info,exit_flag,vx,junk]=objective(x,M_, oo_, options_,i_params,i_var,weights)
% objective function for optimal simple rules (OSR)
function [loss,info,exit_flag,df,vx]=objective(x,M_, oo_, options_,i_params,i_var,weights)
% [loss,info,exit_flag,df,vx]=objective(x,M_, oo_, options_,i_params,i_var,weights)
% Objective function for optimal simple rules (OSR)
% INPUTS
% x vector values of the parameters
% over which to optimize
@ -14,9 +15,8 @@ function [loss,info,exit_flag,vx,junk]=objective(x,M_, oo_, options_,i_params,i_
% loss scalar loss function returned to solver
% info vector info vector returned by resol
% exit_flag scalar exit flag returned to solver
% df vectcor Analytic Jacobian
% vx vector variances of the endogenous variables
% junk empty dummy output for conformable
% header
%
% SPECIAL REQUIREMENTS
% none
@ -37,19 +37,18 @@ function [loss,info,exit_flag,vx,junk]=objective(x,M_, oo_, options_,i_params,i_
% You should have received a copy of the GNU General Public License
% along with Dynare. If not, see <https://www.gnu.org/licenses/>.
junk = [];
exit_flag = 1;
vx = [];
% set parameters of the policiy rule
df=NaN(length(i_params),1);
% set parameters of the policy rule
M_.params(i_params) = x;
% don't change below until the part where the loss function is computed
[dr,info] = resol(0,M_,options_,oo_);
[oo_.dr,info] = resol(0,M_,options_,oo_);
if info(1)
if info(1) == 3 || info(1) == 4 || info(1) == 5 || info(1)==6 ||info(1) == 19 ||...
info(1) == 20 || info(1) == 21 || info(1) == 23 || info(1) == 26 || ...
info(1) == 81 || info(1) == 84 || info(1) == 85
info(1) == 20 || info(1) == 21 || info(1) == 23 || info(1) == 26 || ...
info(1) == 81 || info(1) == 84 || info(1) == 85
loss = 1e8;
info(4)=info(2);
return
@ -60,5 +59,26 @@ if info(1)
end
end
vx = osr.get_variance_of_endogenous_variables(M_,options_,dr,i_var);
loss = full(weights(:)'*vx(:));
if ~options_.analytic_derivation
vx = osr.get_variance_of_endogenous_variables(M_,options_,oo_.dr,i_var);
loss = full(weights(:)'*vx(:));
else
totparam_nbr=length(i_params);
oo_.dr.derivs = get_perturbation_params_derivs(M_, options_, [], oo_, i_params, [], [], 0); %analytic derivatives of perturbation matrices
pruned_state_space = pruned_state_space_system(M_, options_, oo_.dr, i_var, 0, 0, 1);
vx = pruned_state_space.Var_y + pruned_state_space.E_y*pruned_state_space.E_y';
dE_yy = pruned_state_space.dVar_y;
for jp=1:length(i_params)
dE_yy(:,:,jp) = dE_yy(:,:,jp) + pruned_state_space.dE_y(:,jp)*pruned_state_space.E_y' + pruned_state_space.E_y*pruned_state_space.dE_y(:,jp)';
end
model_moments_params_derivs = reshape(dE_yy,length(i_var)^2,totparam_nbr);
df = NaN(totparam_nbr,1);
loss = full(weights(:)'*vx(:));
for jp=1:length(i_params)
df(jp,1) = sum(weights(:).*model_moments_params_derivs(:,jp));
end
end

@ -1 +1 @@
Subproject commit bd0ba65a61c0d97e3f537194136e3cdad0f4f3b2
Subproject commit 978789d02a4a27c479094512e791576414c54a73

View File

@ -120,6 +120,7 @@ MODFILES = \
irfs/example1_unit_std.mod \
optimal_policy/OSR/osr_example.mod \
optimal_policy/OSR/osr_example_objective_correctness.mod \
optimal_policy/OSR/osr_objective_correctness_anal_deriv.mod \
optimal_policy/OSR/osr_example_obj_corr_non_stat_vars.mod \
optimal_policy/OSR/osr_example_param_bounds.mod \
optimal_policy/OSR/osr_obj_corr_algo_1.mod \

View File

@ -10,6 +10,12 @@ kappa = 0.18;
alpha = 0.48;
sigma = -0.06;
gammarr = 0;
gammax0 = 0.2;
gammac0 = 1.5;
gamma_y_ = 8;
gamma_inf_ = 3;
model(linear);
y = delta * y(-1) + (1-delta)*y(+1)+sigma *(r - inflation(+1)) + y_;
@ -28,14 +34,11 @@ end;
optim_weights;
inflation 1;
y 1;
y,inflation 0.1;
end;
osr_params gammax0 gammac0 gamma_y_ gamma_inf_;
gammarr = 0;
gammax0 = 0.2;
gammac0 = 1.5;
gamma_y_ = 8;
gamma_inf_ = 3;
osr;
osr(analytic_derivation,opt_algo=4);
osr(analytic_derivation,opt_algo=1,optim=('DerivativeCheck','on','FiniteDifferenceType','central'));

View File

@ -26,7 +26,6 @@ end;
options_.nograph=1;
options_.nocorr=1;
options_.osr.tolf=1e-20;
osr_params gammax0 gammac0 gamma_y_ gamma_inf_;
@ -42,7 +41,7 @@ gammac0 = 1.5;
gamma_y_ = 8;
gamma_inf_ = 3;
osr;
osr(optim=('TolFun',1e-20));
%compute objective function manually
objective=oo_.var(strmatch('y',M_.endo_names,'exact'),strmatch('y',M_.endo_names,'exact'))+oo_.var(strmatch('inflation',M_.endo_names,'exact'),strmatch('inflation',M_.endo_names,'exact'))+oo_.var(strmatch('dummy_var',M_.endo_names,'exact'),strmatch('dummy_var',M_.endo_names,'exact'));

View File

@ -0,0 +1,109 @@
// Example of optimal simple rule
var y inflation r dummy_var;
varexo y_ inf_;
parameters delta sigma alpha kappa gammax0 gammac0 gamma_y_ gamma_inf_;
delta = 0.44;
kappa = 0.18;
alpha = 0.48;
sigma = -0.06;
model(linear);
y = delta * y(-1) + (1-delta)*y(+1)+sigma *(r - inflation(+1)) + y_;
inflation = alpha * inflation(-1) + (1-alpha) * inflation(+1) + kappa*y + inf_;
dummy_var=0.9*dummy_var(-1)+0.01*y;
r = gammax0*y(-1)+gammac0*inflation(-1)+gamma_y_*y_+gamma_inf_*inf_;
end;
shocks;
var y_;
stderr 0.63;
var inf_;
stderr 0.4;
end;
options_.nograph=1;
options_.nocorr=1;
osr_params gammax0 gammac0 gamma_y_ gamma_inf_;
optim_weights;
inflation 1;
y 1;
dummy_var 1;
end;
gammax0 = 0.2;
gammac0 = 1.5;
gamma_y_ = 8;
gamma_inf_ = 3;
osr(analytic_derivation,optim=('TolFun',1e-20));
%compute objective function manually
objective=oo_.var(strmatch('y',M_.endo_names,'exact'),strmatch('y',M_.endo_names,'exact'))+oo_.var(strmatch('inflation',M_.endo_names,'exact'),strmatch('inflation',M_.endo_names,'exact'))+oo_.var(strmatch('dummy_var',M_.endo_names,'exact'),strmatch('dummy_var',M_.endo_names,'exact'));
if abs(oo_.osr.objective_function-objective)>1e-8
error('Objective Function is wrong')
end
%redo computation with covariance specified
optim_weights;
inflation 1;
y 1;
dummy_var 1;
y,inflation 0.5;
end;
osr(analytic_derivation);
%compute objective function manually
objective=oo_.var(strmatch('y',M_.endo_names,'exact'),strmatch('y',M_.endo_names,'exact'))+oo_.var(strmatch('inflation',M_.endo_names,'exact'),strmatch('inflation',M_.endo_names,'exact'))+oo_.var(strmatch('dummy_var',M_.endo_names,'exact'),strmatch('dummy_var',M_.endo_names,'exact'))+0.5*oo_.var(strmatch('y',M_.endo_names,'exact'),strmatch('inflation',M_.endo_names,'exact'));
if abs(oo_.osr.objective_function-objective)>1e-8
error('Objective Function is wrong')
end
gammax0=1.35533;
gammac0=1.39664;
gamma_y_=16.6667;
gamma_inf_=9.13199;
%redo computation with double weight on one covariance
optim_weights;
inflation 1;
y 1;
dummy_var 1;
y,inflation 1;
end;
osr(analytic_derivation);
%compute objective function manually
objective=oo_.var(strmatch('y',M_.endo_names,'exact'),strmatch('y',M_.endo_names,'exact'))+oo_.var(strmatch('inflation',M_.endo_names,'exact'),strmatch('inflation',M_.endo_names,'exact'))+oo_.var(strmatch('dummy_var',M_.endo_names,'exact'),strmatch('dummy_var',M_.endo_names,'exact'))+1*oo_.var(strmatch('y',M_.endo_names,'exact'),strmatch('inflation',M_.endo_names,'exact'));
if abs(oo_.osr.objective_function-objective)>1e-8
error('Objective Function is wrong')
end
oo_covar_single=oo_;
%redo computation with single weight on both covariances
optim_weights;
inflation 1;
y 1;
dummy_var 1;
y,inflation 0.5;
inflation,y 0.5;
end;
osr(analytic_derivation);
%compute objective function manually
objective=oo_.var(strmatch('y',M_.endo_names,'exact'),strmatch('y',M_.endo_names,'exact'))+oo_.var(strmatch('inflation',M_.endo_names,'exact'),strmatch('inflation',M_.endo_names,'exact'))+oo_.var(strmatch('dummy_var',M_.endo_names,'exact'),strmatch('dummy_var',M_.endo_names,'exact'))+0.5*oo_.var(strmatch('y',M_.endo_names,'exact'),strmatch('inflation',M_.endo_names,'exact'))+0.5*oo_.var(strmatch('inflation',M_.endo_names,'exact'),strmatch('y',M_.endo_names,'exact'));
if abs(oo_.osr.objective_function-objective)>1e-8
error('Objective Function is wrong')
end
if abs(oo_.osr.objective_function-oo_covar_single.osr.objective_function)>1e-8
error('Objective Function is wrong')
end
if max(abs(cell2mat(struct2cell(oo_.osr.optim_params))-cell2mat(struct2cell(oo_covar_single.osr.optim_params))))>1e-5
error('Parameters should be identical')
end