diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 5282909a..38c88eca 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3456,31 +3456,48 @@ DynamicModel::runTrendTest(const eval_context_t &eval_context) } void -DynamicModel::updateVarAndTrendModelRhs() const +DynamicModel::updateVarAndTrendModel() const { for (int i = 0; i < 2; i++) { - map> eqnums; + map> eqnums, trend_eqnums; if (i == 0) eqnums = var_model_table.getEqNums(); else if (i == 1) - eqnums = trend_component_model_table.getEqNums(); + { + eqnums = trend_component_model_table.getEqNums(); + trend_eqnums = trend_component_model_table.getTrendEqNums(); + } map> trend_varr; map>>> rhsr; for (const auto & it : eqnums) { - vector lhs; - vector trend_var; + vector lhs, trend_var, trend_lhs; vector>> rhs; - int lhs_idx = 0; + if (i == 1) - lhs = trend_component_model_table.getLhs(it.first); + { + lhs = trend_component_model_table.getLhs(it.first); + for (auto teqn : trend_eqnums.at(it.first)) + { + int eqnidx = 0; + for (auto eqn : it.second) + { + if (eqn == teqn) + trend_lhs.push_back(lhs[eqnidx]); + eqnidx++; + } + } + } + + int lhs_idx = 0; for (auto eqn : it.second) { set> rhs_set; equations[eqn]->get_arg2()->collectDynamicVariables(SymbolType::endogenous, rhs_set); rhs.push_back(rhs_set); + if (i == 1) { int lhs_symb_id = lhs[lhs_idx++]; @@ -3492,11 +3509,31 @@ DynamicModel::updateVarAndTrendModelRhs() const catch (...) { } - trend_var.push_back(equations[eqn]->get_arg2()->findTrendVariable(lhs_symb_id)); + int trend_var_symb_id = equations[eqn]->get_arg2()->findTrendVariable(lhs_symb_id); + trend_var.push_back(trend_var_symb_id); + if (trend_var_symb_id >= 0) + { + if (symbol_table.isAuxiliaryVariable(trend_var_symb_id)) + try + { + trend_var_symb_id = symbol_table.getOrigSymbIdForAuxVar(trend_var_symb_id); + } + catch (...) + { + } + if (find(trend_lhs.begin(), trend_lhs.end(), trend_var_symb_id) == trend_lhs.end()) + { + cerr << "ERROR: trend found in trend_component equation #" << eqn << " (" + << symbol_table.getName(trend_var_symb_id) << ") does not correspond to a trend equation" << endl; + exit(EXIT_FAILURE); + } + } } } + rhsr[it.first] = rhs; - trend_varr[it.first] = trend_var; + if (i == 1) + trend_varr[it.first] = trend_var; } if (i == 0) diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index 30bdc162..9224149b 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -313,8 +313,9 @@ public: void fillVarModelTableFromOrigModel(StaticModel &static_model) const; //! Update the rhs references in the var model and trend component tables - //! after substitution of auxiliary variables - void updateVarAndTrendModelRhs() const; + //! after substitution of auxiliary variables and find the trend variables + //! in the trend_component model + void updateVarAndTrendModel() const; //! Add aux equations (and aux variables) for variables declared in var_model //! at max order if they don't already exist diff --git a/src/ModFile.cc b/src/ModFile.cc index ee3512c0..e7183d5d 100644 --- a/src/ModFile.cc +++ b/src/ModFile.cc @@ -580,7 +580,7 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const dynamic_model.substituteEndoLagGreaterThanTwo(true); } - dynamic_model.updateVarAndTrendModelRhs(); + dynamic_model.updateVarAndTrendModel(); if (differentiate_forward_vars) dynamic_model.differentiateForwardVars(differentiate_forward_vars_subset); diff --git a/src/SubModel.cc b/src/SubModel.cc index 62613b39..f392c926 100644 --- a/src/SubModel.cc +++ b/src/SubModel.cc @@ -177,6 +177,12 @@ TrendComponentModelTable::getEqNums() const return eqnums; } +map> +TrendComponentModelTable::getTrendEqNums() const +{ + return trend_eqnums; +} + vector TrendComponentModelTable::getNonTrendEqNums(const string &name_arg) const { diff --git a/src/SubModel.hh b/src/SubModel.hh index e5cba518..96432780 100644 --- a/src/SubModel.hh +++ b/src/SubModel.hh @@ -59,6 +59,7 @@ public: vector getEqTags(const string &name_arg) const; map> getTrendEqTags() const; map> getEqNums() const; + map> getTrendEqNums() const; vector getEqNums(const string &name_arg) const; vector getMaxLags(const string &name_arg) const; int getMaxLag(const string &name_arg) const;