From 4c6e911d69abbebb858b000df027af6c4e96902a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= Date: Fri, 12 Apr 2019 18:18:13 +0200 Subject: [PATCH] k-order DLL: in MATLAB mode, get model derivatives from preprocessor at an arbitrary order MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We no longer use the old ‘modfile.dynamic’ compatibility layer. We directly call the ‘modfile.dynamic_g*’ functions. Ref #217 --- mex/sources/k_order_perturbation/dynamic_m.cc | 90 +++++++++++++------ 1 file changed, 64 insertions(+), 26 deletions(-) diff --git a/mex/sources/k_order_perturbation/dynamic_m.cc b/mex/sources/k_order_perturbation/dynamic_m.cc index 081887996..0116d7f83 100644 --- a/mex/sources/k_order_perturbation/dynamic_m.cc +++ b/mex/sources/k_order_perturbation/dynamic_m.cc @@ -72,37 +72,75 @@ void DynamicModelMFile::eval(const Vector &y, const Vector &x, const Vector &modParams, const Vector &ySteady, Vector &residual, std::vector &md) noexcept(false) { - constexpr int nlhs_dynamic = 4, nrhs_dynamic = 5; - mxArray *prhs[nrhs_dynamic], *plhs[nlhs_dynamic]; + mxArray *T_m = mxCreateDoubleMatrix(ntt, 1, mxREAL); - prhs[0] = mxCreateDoubleMatrix(y.length(), 1, mxREAL); - prhs[1] = mxCreateDoubleMatrix(1, x.length(), mxREAL); - prhs[2] = mxCreateDoubleMatrix(modParams.length(), 1, mxREAL); - prhs[3] = mxCreateDoubleMatrix(ySteady.length(), 1, mxREAL); - prhs[4] = mxCreateDoubleScalar(1.0); + mxArray *y_m = mxCreateDoubleMatrix(y.length(), 1, mxREAL); + std::copy_n(y.base(), y.length(), mxGetPr(y_m)); - std::copy_n(y.base(), y.length(), mxGetPr(prhs[0])); - std::copy_n(x.base(), x.length(), mxGetPr(prhs[1])); - std::copy_n(modParams.base(), modParams.length(), mxGetPr(prhs[2])); - std::copy_n(ySteady.base(), ySteady.length(), mxGetPr(prhs[3])); + mxArray *x_m = mxCreateDoubleMatrix(1, x.length(), mxREAL); + std::copy_n(x.base(), x.length(), mxGetPr(x_m)); - int retVal = mexCallMATLAB(nlhs_dynamic, plhs, nrhs_dynamic, prhs, DynamicMFilename.c_str()); - if (retVal != 0) - throw DynareException(__FILE__, __LINE__, "Trouble calling " + DynamicMFilename); + mxArray *params_m = mxCreateDoubleMatrix(modParams.length(), 1, mxREAL); + std::copy_n(modParams.base(), modParams.length(), mxGetPr(params_m)); - residual = Vector{plhs[0]}; + mxArray *steady_state_m = mxCreateDoubleMatrix(ySteady.length(), 1, mxREAL); + std::copy_n(ySteady.base(), ySteady.length(), mxGetPr(steady_state_m)); - assert(static_cast(mxGetM(plhs[1])) == md[0].nrows()); - assert(static_cast(mxGetN(plhs[1])) == md[0].ncols()); - std::copy_n(mxGetPr(plhs[1]), mxGetM(plhs[1])*mxGetN(plhs[1]), md[0].base()); + mxArray *it_m = mxCreateDoubleScalar(1.0); + mxArray *T_flag_m = mxCreateLogicalScalar(false); - if (md.size() >= 2) - unpackSparseMatrixAndCopyIntoTwoDMatData(plhs[2], md[1]); - if (md.size() >= 3) - unpackSparseMatrixAndCopyIntoTwoDMatData(plhs[3], md[2]); + { + // Compute temporary terms (for all orders) + std::string funcname = DynamicMFilename + "_g" + std::to_string(md.size()) + "_tt"; + mxArray *plhs[1], *prhs[] = { T_m, y_m, x_m, params_m, steady_state_m, it_m }; - for (int i = 0; i < nrhs_dynamic; i++) - mxDestroyArray(prhs[i]); - for (int i = 0; i < nlhs_dynamic; i++) - mxDestroyArray(plhs[i]); + int retVal = mexCallMATLAB(1, plhs, 6, prhs, funcname.c_str()); + if (retVal != 0) + throw DynareException(__FILE__, __LINE__, "Trouble calling " + funcname); + + mxDestroyArray(T_m); + T_m = plhs[0]; + } + + { + // Compute residuals + std::string funcname = DynamicMFilename + "_resid"; + mxArray *plhs[1], *prhs[] = { T_m, y_m, x_m, params_m, steady_state_m, it_m, T_flag_m }; + + int retVal = mexCallMATLAB(1, plhs, 7, prhs, funcname.c_str()); + if (retVal != 0) + throw DynareException(__FILE__, __LINE__, "Trouble calling " + funcname); + + residual = Vector{plhs[0]}; + mxDestroyArray(plhs[0]); + } + + for (size_t i = 1; i <= md.size(); i++) + { + // Compute model derivatives + std::string funcname = DynamicMFilename + "_g" + std::to_string(i); + mxArray *plhs[1], *prhs[] = { T_m, y_m, x_m, params_m, steady_state_m, it_m, T_flag_m }; + + int retVal = mexCallMATLAB(1, plhs, 7, prhs, funcname.c_str()); + if (retVal != 0) + throw DynareException(__FILE__, __LINE__, "Trouble calling " + funcname); + + assert(static_cast(mxGetM(plhs[0])) == md[i-1].nrows()); + assert(static_cast(mxGetN(plhs[0])) == md[i-1].ncols()); + + if (i == 1) + std::copy_n(mxGetPr(plhs[0]), mxGetM(plhs[0])*mxGetN(plhs[0]), md[i-1].base()); + else + unpackSparseMatrixAndCopyIntoTwoDMatData(plhs[0], md[i-1]); + + mxDestroyArray(plhs[0]); + } + + mxDestroyArray(T_m); + mxDestroyArray(y_m); + mxDestroyArray(x_m); + mxDestroyArray(params_m); + mxDestroyArray(steady_state_m); + mxDestroyArray(it_m); + mxDestroyArray(T_flag_m); }