Provide block_trust_region MEX under solve_algo 13 and 14

- block trust region solver now available under solve_algo=13
  It is essentially the same as solve_algo=4, except that Jacobian by finite
  difference is not handled. A test file is added for that case
- block trust region solver with shortcut for equations that can be evaluated
  is now available under solve_algo=14 (in replacement of the pure-MATLAB solver)

Closes: Enterprise/dynare#3
time-shift
Sébastien Villemot 2020-07-16 18:20:07 +02:00
parent 7e21bf2a10
commit 865ab47fa9
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
6 changed files with 208 additions and 23 deletions

View File

@ -232,13 +232,13 @@ elseif options.solve_algo==9
[x, errorflag] = trust_region(f, x, 1:nn, 1:nn, jacobian_flag, options.gstep, ...
tolf, tolx, ...
maxit, options.debug, arguments{:});
elseif ismember(options.solve_algo, [2, 12, 4, 14])
elseif ismember(options.solve_algo, [2, 12, 4])
if ismember(options.solve_algo, [2, 12])
solver = @solve1;
else
solver = @trust_region;
end
specializedunivariateblocks = ismember(options.solve_algo, [12, 14]);
specializedunivariateblocks = options.solve_algo == 12;
if ~jacobian_flag
fjac = zeros(nn,nn) ;
dh = max(abs(x), options.gstep(1)*ones(nn,1))*eps^(1/3);
@ -251,19 +251,19 @@ elseif ismember(options.solve_algo, [2, 12, 4, 14])
[j1,j2,r,s] = dmperm(fjac);
JAC = abs(fjac(j1,j2))>0;
if options.debug
disp(['DYNARE_SOLVE (solve_algo=2|4|12|14): number of blocks = ' num2str(length(r)-1)]);
disp(['DYNARE_SOLVE (solve_algo=2|4|12): number of blocks = ' num2str(length(r)-1)]);
end
l = 0;
fre = false;
for i=length(r)-1:-1:1
blocklength = r(i+1)-r(i);
if options.debug
dprintf('DYNARE_SOLVE (solve_algo=2|4|12|14): solving block %u of size %u.', i, blocklength);
dprintf('DYNARE_SOLVE (solve_algo=2|4|12): solving block %u of size %u.', i, blocklength);
end
j = r(i):r(i+1)-1;
if specializedunivariateblocks
if options.debug
dprintf('DYNARE_SOLVE (solve_algo=2|4|12|14): solving block %u by evaluating RHS.', i);
dprintf('DYNARE_SOLVE (solve_algo=2|4|12): solving block %u by evaluating RHS.', i);
end
if isequal(blocklength, 1)
if i<length(r)-1
@ -304,7 +304,7 @@ elseif ismember(options.solve_algo, [2, 12, 4, 14])
end
else
if options.debug
dprintf('DYNARE_SOLVE (solve_algo=2|4|12|14): solving block %u with trust_region routine.', i);
dprintf('DYNARE_SOLVE (solve_algo=2|4|12): solving block %u with trust_region routine.', i);
end
end
[x, errorflag] = solver(f, x, j1(j), j2(j), jacobian_flag, ...
@ -356,6 +356,19 @@ elseif options.solve_algo == 11
catch
errorflag = true;
end
elseif ismember(options.solve_algo, [13, 14])
if ~jacobian_flag
error('DYNARE_SOLVE: option solve_algo=13|14 needs computed Jacobian')
end
auxstruct = struct();
if options.solve_algo == 14
auxstruct.lhs = lhs;
auxstruct.endo_names = endo_names;
auxstruct.isloggedlhs = isloggedlhs;
auxstruct.isauxdiffloggedrhs = isauxdiffloggedrhs;
end
[x, errorflag] = block_trust_region(f, x, tolf, options.solve_tolx, maxit, options.debug, auxstruct, arguments{:});
[fvec, fjac] = feval(f, x, arguments{:});
else
error('DYNARE_SOLVE: option solve_algo must be one of [0,1,2,3,4,9,10,11,12,14]')
error('DYNARE_SOLVE: option solve_algo must be one of [0,1,2,3,4,9,10,11,12,13,14]')
end

View File

@ -35,8 +35,8 @@ function [steady_state,params,info] = steady_(M_,options_,oo_)
% You should have received a copy of the GNU General Public License
% along with Dynare. If not, see <http://www.gnu.org/licenses/>.
if options_.solve_algo < 0 || options_.solve_algo > 12
error('STEADY: solve_algo must be between 0 and 12')
if options_.solve_algo < 0 || options_.solve_algo > 14
error('STEADY: solve_algo must be between 0 and 14')
end
if ~options_.bytecode && ~options_.block && options_.solve_algo > 4 && ...

View File

@ -11,6 +11,7 @@ nodist_block_trust_region_SOURCES = \
BUILT_SOURCES = $(nodist_block_trust_region_SOURCES)
CLEANFILES = $(nodist_block_trust_region_SOURCES)
dulmage_mendelsohn.o: matlab_mex.mod
dulmage_mendelsohn.mod: dulmage_mendelsohn.o
matlab_fcn_closure.mod: matlab_fcn_closure.o

View File

@ -1,4 +1,4 @@
! Copyright © 2019 Dynare Team
! Copyright © 2019-2020 Dynare Team
!
! This file is part of Dynare.
!
@ -31,39 +31,104 @@ subroutine mexFunction(nlhs, plhs, nrhs, prhs) bind(c, name='mexFunction')
real(real64), dimension(:), allocatable, target :: x
type(dm_block), dimension(:), allocatable, target :: blocks
integer :: info, i
real(real64), parameter :: tolf = 1e-6_real64
real(real64) :: tolf, tolx
integer :: maxiter
real(real64), dimension(:), allocatable :: fvec
real(real64), dimension(:,:), allocatable :: fjac
logical :: debug
logical :: debug, specializedunivariateblocks
character(len=80) :: debug_msg
logical(mxLogical), dimension(:), pointer :: isloggedlhs => null(), &
isauxdiffloggedrhs => null()
type(c_ptr) :: endo_names, lhs
logical :: fre ! True if the last block has been solved (i.e. not evaluated), so that residuals must be updated
integer, dimension(:), allocatable :: evaled_cols ! If fre=.false., lists the columns that have been evaluated so far without updating the residuals
if (nrhs < 3 .or. nlhs /= 2) then
call mexErrMsgTxt("Must have at least 3 inputs and exactly 2 outputs")
if (nrhs < 4 .or. nlhs /= 2) then
call mexErrMsgTxt("Must have at least 7 inputs and exactly 2 outputs")
return
end if
if (.not. ((mxIsChar(prhs(1)) .and. mxGetM(prhs(1)) == 1) .or. mxIsClass(prhs(1), "function_handle"))) then
call mexErrMsgTxt("First argument should be a string or a function handle")
call mexErrMsgTxt("First argument (function) should be a string or a function handle")
return
end if
if (.not. (mxIsDouble(prhs(2)) .and. (mxGetM(prhs(2)) == 1 .or. mxGetN(prhs(2)) == 1))) then
call mexErrMsgTxt("Second argument should be a real vector")
call mexErrMsgTxt("Second argument (initial guess) should be a real vector")
return
end if
if (.not. (mxIsLogicalScalar(prhs(3)))) then
call mexErrMsgTxt("Third argument should be a logical scalar")
if (.not. (mxIsScalar(prhs(3)) .and. mxIsNumeric(prhs(3)))) then
call mexErrMsgTxt("Third argument (tolf) should be a numeric scalar")
return
end if
if (.not. (mxIsScalar(prhs(4)) .and. mxIsNumeric(prhs(4)))) then
call mexErrMsgTxt("Fourth argument (tolx) should be a numeric scalar")
return
end if
if (.not. (mxIsScalar(prhs(5)) .and. mxIsNumeric(prhs(5)))) then
call mexErrMsgTxt("Fifth argument (maxiter) should be a numeric scalar")
return
end if
if (.not. (mxIsLogicalScalar(prhs(6)))) then
call mexErrMsgTxt("Sixth argument (debug) should be a logical scalar")
return
end if
if (.not. (mxIsStruct(prhs(7)) .and. &
(mxGetNumberOfFields(prhs(7)) == 0 .or. mxGetNumberOfFields(prhs(7)) == 4))) then
call mexErrMsgTxt("Seventh argument should be a struct with either 0 or 4 fields")
return
end if
specializedunivariateblocks = (mxGetNumberOfFields(prhs(7)) == 4)
func => prhs(1)
debug = mxGetScalar(prhs(3)) == 1._c_double
extra_args => prhs(4:nrhs)
tolf = mxGetScalar(prhs(3))
tolx = mxGetScalar(prhs(4))
maxiter = int(mxGetScalar(prhs(5)))
debug = mxGetScalar(prhs(6)) == 1._c_double
extra_args => prhs(8:nrhs) ! Extra arguments to func are in argument 8 and subsequent ones
associate (x_mat => mxGetPr(prhs(2)))
allocate(x(size(x_mat)))
x = x_mat
end associate
if (specializedunivariateblocks) then
block
type(c_ptr) :: tmp
tmp = mxGetField(prhs(7), 1_mwIndex, "isloggedlhs")
if (.not. (c_associated(tmp) .and. mxIsLogical(tmp) .and. mxGetNumberOfElements(tmp) == size(x))) then
call mexErrMsgTxt("Seventh argument must have a 'isloggedlhs' field of type logical, of same size as second argument")
return
end if
isloggedlhs => mxGetLogicals(tmp)
tmp = mxGetField(prhs(7), 1_mwIndex, "isauxdiffloggedrhs")
if (.not. (c_associated(tmp) .and. mxIsLogical(tmp) .and. mxGetNumberOfElements(tmp) == size(x))) then
call mexErrMsgTxt("Seventh argument must have a 'isauxdiffloggedrhs' field of type &
&logical, of same size as second argument")
return
end if
isauxdiffloggedrhs => mxGetLogicals(tmp)
lhs = mxGetField(prhs(7), 1_mwIndex, "lhs")
if (.not. (c_associated(lhs) .and. mxIsCell(lhs) .and. mxGetNumberOfElements(lhs) == size(x))) then
call mexErrMsgTxt("Seventh argument must have a 'lhs' field of type cell, of same size as second argument")
return
end if
endo_names = mxGetField(prhs(7), 1_mwIndex, "endo_names")
if (.not. (c_associated(endo_names) .and. mxIsCell(endo_names) .and. mxGetNumberOfElements(endo_names) == size(x))) then
call mexErrMsgTxt("Seventh argument must have a 'endo_names' field of type cell, of same size as second argument")
return
end if
end block
allocate(evaled_cols(0))
fre = .false.
end if
allocate(fvec(size(x)))
allocate(fjac(size(x), size(x)))
@ -79,7 +144,6 @@ subroutine mexFunction(nlhs, plhs, nrhs, prhs) bind(c, name='mexFunction')
end if
! Solve the system, starting from bottom-rightmost block
x_all => x
do i = size(blocks),1,-1
if (debug) then
write (debug_msg, "('DYNARE_SOLVE (solve_algo=13|14): solving block ', i0, ' of size ', i0)") &
@ -87,18 +151,77 @@ subroutine mexFunction(nlhs, plhs, nrhs, prhs) bind(c, name='mexFunction')
call mexPrintf_trim_newline(debug_msg)
end if
if (specializedunivariateblocks .and. size(blocks(i)%col_indices) == 1) then
if (debug) then
write (debug_msg, "('DYNARE_SOLVE (solve_algo=13|14): solving block ', i0, ' by evaluating RHS')") i
call mexPrintf_trim_newline(debug_msg)
end if
associate (eq => blocks(i)%row_indices(1), var => blocks(i)%col_indices(1))
if (fre .or. any(abs(fjac(eq, evaled_cols)) > 0._real64)) then
! Reevaluation of the residuals is required because the current RHS depends on
! variables that potentially have been updated previously.
nullify(x_indices, f_indices, x_all)
call matlab_fcn(x, fvec)
deallocate(evaled_cols) ! This shouldnt be necessary, but it crashes otherwise with gfortran 8
allocate(evaled_cols(0))
fre = .false.
end if
evaled_cols = [ evaled_cols, var]
block
! An associate() construct for lhs_eq and endo_name_var makes the
! code crash (with double free) using gfortran 8. Hence use a block
character(kind=c_char, len=:), allocatable :: lhs_eq, endo_name_var
lhs_eq = mxArrayToString(mxGetCell(lhs, int(eq, mwIndex)))
endo_name_var = mxArrayToString(mxGetCell(endo_names, int(var, mwIndex)))
if (lhs_eq == endo_name_var .or. lhs_eq == "log(" // endo_name_var // ")") then
if (isloggedlhs(eq)) then
x(var) = exp(log(x(var)) - fvec(eq))
else
x(var) = x(var) - fvec(eq)
end if
else
if (debug) then
write (debug_msg, "('LHS variable is not determined by RHS expression (', i0, ')')") eq
call mexPrintf_trim_newline(debug_msg)
write (debug_msg, "(a, ' -> ', a)") lhs_eq, endo_name_var
call mexPrintf_trim_newline(debug_msg)
end if
if (lhs_eq(1:9) == "AUX_DIFF_" .or. lhs_eq(1:13) == "log(AUX_DIFF_") then
if (isauxdiffloggedrhs(eq)) then
x(var) = exp(log(x(var)) + fvec(eq))
else
x(var) = x(var) + fvec(eq)
end if
else
call mexErrMsgTxt("Algorithm solve_algo=14 cannot be used with this nonlinear problem")
return
end if
end if
end block
end associate
cycle
else
if (debug) then
write (debug_msg, "('DYNARE_SOLVE (solve_algo=13|14): solving block ', i0, ' with trust region routine')") i
call mexPrintf_trim_newline(debug_msg)
end if
end if
block
real(real64), dimension(size(blocks(i)%col_indices)) :: x_block
x_indices => blocks(i)%col_indices
f_indices => blocks(i)%row_indices
x_all => x
if (size(x_indices) /= size(f_indices)) then
call mexErrMsgTxt("Non-square block")
return
end if
x_block = x(x_indices)
call trust_region_solve(x_block, matlab_fcn, info, tolf = tolf)
call trust_region_solve(x_block, matlab_fcn, info, tolx, tolf, maxiter)
x(x_indices) = x_block
end block
fre = .true.
end do
! Verify that we have a solution
@ -113,7 +236,7 @@ subroutine mexFunction(nlhs, plhs, nrhs, prhs) bind(c, name='mexFunction')
if (maxval(abs(fvec)) > tolf) then
if (debug) &
call mexPrintf_trim_newline("DYNARE_SOLVE (solve_algo=13|14): residuals still too large, solving for the whole model")
call trust_region_solve(x, matlab_fcn, info, tolf = tolf)
call trust_region_solve(x, matlab_fcn, info, tolx, tolf, maxiter)
else
info = 1
end if

View File

@ -134,6 +134,7 @@ MODFILES = \
steady_state/walsh1_ssm_block.mod \
steady_state/multi_leads.mod \
steady_state/example1_trust_region.mod \
steady_state/example1_block_trust_region.mod \
steady_state/Gali_2015_chapter_6_4.mod \
steady_state_operator/standard.mod \
steady_state_operator/use_dll.mod \

View File

@ -0,0 +1,47 @@
// Test block trust region nonlinear solver (solve_algo=13)
var y, c, k, a, h, b;
varexo e, u;
parameters beta, rho, alpha, delta, theta, psi, tau;
alpha = 0.36;
rho = 0.95;
tau = 0.025;
beta = 0.99;
delta = 0.025;
psi = 0;
theta = 2.95;
phi = 0.1;
model;
c*theta*h^(1+psi)=(1-alpha)*y;
k = beta*(((exp(b)*c)/(exp(b(+1))*c(+1)))
*(exp(b(+1))*alpha*y(+1)+(1-delta)*k));
y = exp(a)*(k(-1)^alpha)*(h^(1-alpha));
k = exp(b)*(y-c)+(1-delta)*k(-1);
a = rho*a(-1)+tau*b(-1) + e;
b = tau*a(-1)+rho*b(-1) + u;
end;
initval;
y = 1;
c = 0.8;
h = 0.3;
k = 10;
a = 0;
b = 0;
e = 0;
u = 0;
end;
options_.debug = true;
steady(solve_algo=13);
shocks;
var e; stderr 0.009;
var u; stderr 0.009;
var e, u = phi*0.009*0.009;
end;
stoch_simul(order=1,nomoments,irf=0);