From 46709ada3f5d73ef1f38642da014057a9a54f9aa Mon Sep 17 00:00:00 2001 From: Houtan Bastani Date: Fri, 7 Sep 2018 10:56:40 +0200 Subject: [PATCH] output AR matrix in file for trend component models --- src/DynamicModel.cc | 16 ++++++++++---- src/DynamicModel.hh | 2 +- src/ExprNode.cc | 22 ++++++++++++------- src/ModFile.cc | 2 +- src/SubModel.cc | 52 +++++++++++++++++++++++++++++++++++++++------ src/SubModel.hh | 9 +++++--- src/SymbolTable.cc | 25 ++++++++++++++++++++++ src/SymbolTable.hh | 4 ++++ 8 files changed, 109 insertions(+), 23 deletions(-) diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index bbad80b8..1bf017a7 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3709,19 +3709,22 @@ DynamicModel::fillVarModelTableFromOrigModel(StaticModel &static_model) const // Fill AR Matrix map, expr_t>> ARr; - fillAutoregressiveMatrix(ARr); + fillAutoregressiveMatrix(ARr, false); var_model_table.setAR(ARr); } void -DynamicModel::fillAutoregressiveMatrix(map, expr_t>> &ARr) const +DynamicModel::fillAutoregressiveMatrix(map, expr_t>> &ARr, bool is_trend_component_model) const { - for (const auto & it : var_model_table.getEqNums()) + auto eqnums = is_trend_component_model ? trend_component_model_table.getEqNums() : var_model_table.getEqNums(); + for (const auto & it : eqnums) { int i = 0; map, expr_t> AR; + vector lhs = is_trend_component_model ? + trend_component_model_table.getLhs(it.first) : var_model_table.getLhs(it.first); for (auto eqn : it.second) - equations[eqn]->get_arg2()->fillAutoregressiveRow(i++, var_model_table.getLhs(it.first), AR); + equations[eqn]->get_arg2()->fillAutoregressiveRow(i++, lhs, AR); ARr[it.first] = AR; } } @@ -3844,6 +3847,11 @@ DynamicModel::fillTrendComponentModelTable() const trend_component_model_table.setRhs(rhsr); trend_component_model_table.setLhsExprT(lhs_expr_tr); trend_component_model_table.setNonstationary(nonstationaryr); + + // Fill AR Matrix + map, expr_t>> ARr; + fillAutoregressiveMatrix(ARr, true); + trend_component_model_table.setAR(ARr); } void diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index 796cc537..b7bca905 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -305,7 +305,7 @@ public: void setNonZeroHessianEquations(map &eqs); //! Fill Autoregressive Matrix for var_model - void fillAutoregressiveMatrix(map, expr_t>> &ARr) const; + void fillAutoregressiveMatrix(map, expr_t>> &ARr, bool is_trend_component_model) const; //! Fill the Trend Component Model Table void fillTrendComponentModelTable() const; diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 88844180..ade8d751 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -5440,21 +5440,30 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2, const vector &lhs, map, expr_t> &AR) const { + if (op_code != BinaryOpcode::times) + return; + set> endogs, tmp; arg2->collectDynamicVariables(SymbolType::endogenous, endogs); if (endogs.size() != 1) return; int lhs_symb_id = endogs.begin()->first; - if (find(lhs.begin(), lhs.end(), lhs_symb_id) == lhs.end()) - return; + int lag = endogs.begin()->second; + if (datatree.symbol_table.isAuxiliaryVariable(lhs_symb_id)) + { + int orig_lhs_symb_id = datatree.symbol_table.getOrigSymbIdForDiffAuxVar(lhs_symb_id); + if (find(lhs.begin(), lhs.end(), orig_lhs_symb_id) == lhs.end()) + return; + lag = -(datatree.symbol_table.getOrigLeadLagForDiffAuxVar(lhs_symb_id) - 1); + lhs_symb_id = orig_lhs_symb_id; + } arg1->collectDynamicVariables(SymbolType::endogenous, tmp); arg1->collectDynamicVariables(SymbolType::exogenous, tmp); if (tmp.size() != 0) return; - int lag = endogs.begin()->second; if (AR.find(make_tuple(eqn, -lag, lhs_symb_id)) != AR.end()) { cerr << "BinaryOpNode::fillAutoregressiveRowHelper: Error filling AR matrix: lag/symb_id encountered more than once in equtaion" << endl; @@ -5466,11 +5475,8 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2, void BinaryOpNode::fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const { - if (op_code == BinaryOpcode::times) - { - fillAutoregressiveRowHelper(arg1, arg2, eqn, lhs, AR); - fillAutoregressiveRowHelper(arg2, arg1, eqn, lhs, AR); - } + fillAutoregressiveRowHelper(arg1, arg2, eqn, lhs, AR); + fillAutoregressiveRowHelper(arg2, arg1, eqn, lhs, AR); arg1->fillAutoregressiveRow(eqn, lhs, AR); arg2->fillAutoregressiveRow(eqn, lhs, AR); } diff --git a/src/ModFile.cc b/src/ModFile.cc index be36c11c..c94cb80f 100644 --- a/src/ModFile.cc +++ b/src/ModFile.cc @@ -878,7 +878,7 @@ ModFile::writeOutputFiles(const string &basename, bool clear_all, bool clear_glo symbol_table.writeOutput(mOutputFile); var_model_table.writeOutput(basename, mOutputFile); - trend_component_model_table.writeOutput(mOutputFile); + trend_component_model_table.writeOutput(basename, mOutputFile); // Initialize M_.Sigma_e, M_.Correlation_matrix, M_.H, and M_.Correlation_matrix_ME mOutputFile << "M_.Sigma_e = zeros(" << symbol_table.exo_nbr() << ", " diff --git a/src/SubModel.cc b/src/SubModel.cc index c937f1f6..917eaf84 100644 --- a/src/SubModel.cc +++ b/src/SubModel.cc @@ -45,18 +45,18 @@ void TrendComponentModelTable::setEqNums(map> eqnums_arg) { eqnums = move(eqnums_arg); - setUndiffEqnums(); + setNonTrendEqnums(); } void TrendComponentModelTable::setTrendEqNums(map> trend_eqnums_arg) { trend_eqnums = move(trend_eqnums_arg); - setUndiffEqnums(); + setNonTrendEqnums(); } void -TrendComponentModelTable::setUndiffEqnums() +TrendComponentModelTable::setNonTrendEqnums() { if (!nontrend_eqnums.empty() || eqnums.empty() || trend_eqnums.empty()) return; @@ -65,8 +65,7 @@ TrendComponentModelTable::setUndiffEqnums() { vector nontrend_vec; for (auto eq : it.second) - if (find(trend_eqnums[it.first].begin(), trend_eqnums[it.first].end(), eq) - == trend_eqnums[it.first].end()) + if (find(trend_eqnums[it.first].begin(), trend_eqnums[it.first].end(), eq) == trend_eqnums[it.first].end()) nontrend_vec.push_back(eq); nontrend_eqnums[it.first] = nontrend_vec; } @@ -120,6 +119,12 @@ TrendComponentModelTable::setOrigDiffVar(map> orig_diff_var_ orig_diff_var = move(orig_diff_var_arg); } +void +TrendComponentModelTable::setAR(map, expr_t>> AR_arg) +{ + AR = move(AR_arg); +} + map> TrendComponentModelTable::getEqTags() const { @@ -228,8 +233,20 @@ TrendComponentModelTable::getOrigDiffVar(const string &name_arg) const } void -TrendComponentModelTable::writeOutput(ostream &output) const +TrendComponentModelTable::writeOutput(const string &basename, ostream &output) const { + string filename = "+" + basename + "/trend_component_ar.m"; + ofstream ar_output; + ar_output.open(filename, ios::out | ios::binary); + if (!ar_output.is_open()) + { + cerr << "Error: Can't open file " << filename << " for writing" << endl; + exit(EXIT_FAILURE); + } + ar_output << "function ar = trend_component_ar(model_name, params)" << endl + << "%function ar = trend_component_ar(model_name, params)" << endl + << "% File automatically generated by the Dynare preprocessor" << endl << endl; + for (const auto &name : names) { output << "M_.trend_component." << name << ".model_name = '" << name << "';" << endl @@ -291,7 +308,30 @@ TrendComponentModelTable::writeOutput(ostream &output) const for (auto it : trend_vars.at(name)) output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " "; output << "];" << endl; + + vector nontrend_lhs; + vector lhsv = getLhs(name); + vector eqnumsv = getEqNums(name); + for (int nontrend_it : getNonTrendEqNums(name)) + nontrend_lhs.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), nontrend_it)))); + + ar_output << "if strcmp(model_name, '" << name << "')" << endl + << " ar = zeros(" << nontrend_lhs.size() << ", " << nontrend_lhs.size() << ", " << getMaxLag(name) << ");" << endl; + for (const auto & it : AR.at(name)) + { + int eqn, lag, lhs_symb_id; + tie (eqn, lag, lhs_symb_id) = it.first; + int colidx = (int) distance(nontrend_lhs.begin(), find(nontrend_lhs.begin(), nontrend_lhs.end(), lhs_symb_id)); + ar_output << " ar(" << eqn + 1 << ", " << colidx + 1 << ", " << lag << ") = "; + it.second->writeOutput(ar_output, ExprNodeOutputType::matlabDynamicModel); + ar_output << ";" << endl; + } + ar_output << " return" << endl + << "end" << endl << endl; } + ar_output << "error([model_name ' is not a valid trend_component_model name'])" << endl + << "end" << endl; + ar_output.close(); } void diff --git a/src/SubModel.hh b/src/SubModel.hh index 71307486..649b3896 100644 --- a/src/SubModel.hh +++ b/src/SubModel.hh @@ -45,6 +45,7 @@ private: map> diff, nonstationary; map> lhs_expr_t; map> trend_vars; + map, expr_t>> AR; // AR: name -> (eqn, lag, lhs_symb_id) -> param_expr_t public: TrendComponentModelTable(SymbolTable &symbol_table_arg); @@ -80,16 +81,18 @@ public: void setOrigDiffVar(map> orig_diff_var_arg); void setNonstationary(map> nonstationary_arg); void setTrendVar(map> trend_vars_arg); + void setAR(map, expr_t>> AR_arg); + void setNonTrendEqNums(map> trend_eqnums_arg); //! Write output of this class - void writeOutput(ostream &output) const; + void writeOutput(const string &basename, ostream &output) const; //! Write JSON Output void writeJsonOutput(ostream &output) const; private: void checkModelName(const string &name_arg) const; - void setUndiffEqnums(); + void setNonTrendEqnums(); }; inline bool @@ -120,7 +123,7 @@ private: public: VarModelTable(SymbolTable &symbol_table_arg); - //! Add a trend component model + //! Add a VAR model void addVarModel(string name, vector eqtags, pair symbol_list_and_order_arg); diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc index b08c6117..335a5418 100644 --- a/src/SymbolTable.cc +++ b/src/SymbolTable.cc @@ -874,6 +874,31 @@ SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false) throw UnknownSymbolIDException(aux_var_symb_id); } +int +SymbolTable::getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false) +{ + int lag = 0; + for (const auto & aux_var : aux_vars) + if ((aux_var.get_type() == AuxVarType::diff + || aux_var.get_type() == AuxVarType::diffLag) + && aux_var.get_symb_id() == diff_aux_var_symb_id) + lag += 1 + getOrigLeadLagForDiffAuxVar(aux_var.get_orig_symb_id()); + return lag; +} + +int +SymbolTable::getOrigSymbIdForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false) +{ + int orig_symb_id = -1; + for (const auto & aux_var : aux_vars) + if (aux_var.get_symb_id() == diff_aux_var_symb_id) + if (aux_var.get_type() == AuxVarType::diff) + orig_symb_id = diff_aux_var_symb_id; + else if (aux_var.get_type() == AuxVarType::diffLag) + orig_symb_id = getOrigSymbIdForDiffAuxVar(aux_var.get_orig_symb_id()); + return orig_symb_id; +} + expr_t SymbolTable::getAuxiliaryVarsExprNode(int symb_id) const noexcept(false) // throw exception if it is a Lagrange multiplier diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh index c2b47d21..9e30bd28 100644 --- a/src/SymbolTable.hh +++ b/src/SymbolTable.hh @@ -291,6 +291,10 @@ public: int searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false); //! Serches aux_vars for the aux var represented by aux_var_symb_id and returns its associated orig_symb_id int getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false); + //! Searches for diff aux var and finds the original lag associated with this variable + int getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false); + //! Searches for diff aux var and finds the symb id associated with this variable + int getOrigSymbIdForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false); //! Adds an auxiliary variable when var_model is used with an order that is greater in absolute value //! than the largest lag present in the model. int addVarModelEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) noexcept(false);