diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index d200e06f..15daa76c 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3119,12 +3119,13 @@ DynamicModel::updateVarAndTrendModel() const { for (bool var : { true, false }) { - map> trend_varr; + map>> trend_varr; map>>> rhsr; for (const auto &[model_name, eqns] : (var ? var_model_table.getEqNums() : trend_component_model_table.getEqNums())) { - vector lhs, trend_var, trend_lhs; + vector lhs, trend_lhs; + vector> trend_var; vector>> rhs; if (!var) @@ -3160,25 +3161,25 @@ DynamicModel::updateVarAndTrendModel() const catch (...) { } - int trend_var_symb_id = equations[eqn]->arg2->findTargetVariable(lhs_symb_id); - if (trend_var_symb_id >= 0) + optional trend_var_symb_id = equations[eqn]->arg2->findTargetVariable(lhs_symb_id); + if (trend_var_symb_id) { - if (symbol_table.isDiffAuxiliaryVariable(trend_var_symb_id)) + if (symbol_table.isDiffAuxiliaryVariable(*trend_var_symb_id)) try { - trend_var_symb_id = symbol_table.getOrigSymbIdForAuxVar(trend_var_symb_id); + trend_var_symb_id = symbol_table.getOrigSymbIdForAuxVar(*trend_var_symb_id); } catch (...) { } - if (find(trend_lhs.begin(), trend_lhs.end(), trend_var_symb_id) == trend_lhs.end()) + if (find(trend_lhs.begin(), trend_lhs.end(), *trend_var_symb_id) == trend_lhs.end()) { cerr << "ERROR: trend found in trend_component equation #" << eqn << " (" - << symbol_table.getName(trend_var_symb_id) << ") does not correspond to a trend equation" << endl; + << symbol_table.getName(*trend_var_symb_id) << ") does not correspond to a trend equation" << endl; exit(EXIT_FAILURE); } } - trend_var.push_back(trend_var_symb_id); + trend_var.push_back(move(trend_var_symb_id)); } } diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 35738af6..3107f441 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -643,10 +643,10 @@ NumConstNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) { } -int +optional NumConstNode::findTargetVariable(int lhs_symb_id) const { - return -1; + return nullopt; } expr_t @@ -1593,13 +1593,13 @@ VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes); } -int +optional VariableNode::findTargetVariable(int lhs_symb_id) const { if (get_type() == SymbolType::modelLocalVariable) return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id); - return -1; + return nullopt; } expr_t @@ -3393,7 +3393,7 @@ UnaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const nodes[lag_equiv_repr][index] = const_cast(this); } -int +optional UnaryOpNode::findTargetVariable(int lhs_symb_id) const { return arg->findTargetVariable(lhs_symb_id); @@ -5310,14 +5310,14 @@ BinaryOpNode::findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const return false; } -int +optional BinaryOpNode::findTargetVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const { set params; arg1->collectVariables(SymbolType::parameter, params); if (params.size() != 1) - return -1; + return nullopt; set> endogs; arg2->collectDynamicVariables(SymbolType::endogenous, endogs); @@ -5331,18 +5331,18 @@ BinaryOpNode::findTargetVariableHelper(const expr_t arg1, const expr_t arg2, else if (findTargetVariableHelper1(lhs_symb_id, endogs.rbegin()->first)) return endogs.begin()->first; } - return -1; + return nullopt; } -int +optional BinaryOpNode::findTargetVariable(int lhs_symb_id) const { - int retval = findTargetVariableHelper(arg1, arg2, lhs_symb_id); - if (retval < 0) + optional retval = findTargetVariableHelper(arg1, arg2, lhs_symb_id); + if (!retval) retval = findTargetVariableHelper(arg2, arg1, lhs_symb_id); - if (retval < 0) + if (!retval) retval = arg1->findTargetVariable(lhs_symb_id); - if (retval < 0) + if (!retval) retval = arg2->findTargetVariable(lhs_symb_id); return retval; } @@ -6447,13 +6447,13 @@ TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) arg3->findUnaryOpNodesForAuxVarCreation(nodes); } -int +optional TrinaryOpNode::findTargetVariable(int lhs_symb_id) const { - int retval = arg1->findTargetVariable(lhs_symb_id); - if (retval < 0) + optional retval = arg1->findTargetVariable(lhs_symb_id); + if (!retval) retval = arg2->findTargetVariable(lhs_symb_id); - if (retval < 0) + if (!retval) retval = arg3->findTargetVariable(lhs_symb_id); return retval; } @@ -6871,14 +6871,14 @@ AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_ argument->findUnaryOpNodesForAuxVarCreation(nodes); } -int +optional AbstractExternalFunctionNode::findTargetVariable(int lhs_symb_id) const { for (auto argument : arguments) - if (int retval = argument->findTargetVariable(lhs_symb_id); - retval >= 0) + if (optional retval = argument->findTargetVariable(lhs_symb_id); + retval) return retval; - return -1; + return nullopt; } expr_t @@ -8391,10 +8391,10 @@ SubModelNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) { } -int +optional SubModelNode::findTargetVariable(int lhs_symb_id) const { - return -1; + return nullopt; } expr_t diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 9ea42414..3a77d815 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -25,6 +25,7 @@ #include #include #include +#include using namespace std; @@ -620,7 +621,7 @@ public: //! Substitute pac_target_nonstationary operator virtual expr_t substitutePacTargetNonstationary(const string &name, expr_t subexpr) = 0; - virtual int findTargetVariable(int lhs_symb_id) const = 0; + virtual optional findTargetVariable(int lhs_symb_id) const = 0; //! Add ExprNodes to the provided datatree virtual expr_t clone(DataTree &datatree) const = 0; @@ -811,7 +812,7 @@ public: expr_t substituteVarExpectation(const map &subst_table) const override; void findDiffNodes(lag_equivalence_table_t &nodes) const override; void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override; - int findTargetVariable(int lhs_symb_id) const override; + optional findTargetVariable(int lhs_symb_id) const override; expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substitutePacExpectation(const string &name, expr_t subexpr) override; @@ -884,7 +885,7 @@ public: expr_t substituteVarExpectation(const map &subst_table) const override; void findDiffNodes(lag_equivalence_table_t &nodes) const override; void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override; - int findTargetVariable(int lhs_symb_id) const override; + optional findTargetVariable(int lhs_symb_id) const override; expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substitutePacExpectation(const string &name, expr_t subexpr) override; @@ -989,7 +990,7 @@ public: void findDiffNodes(lag_equivalence_table_t &nodes) const override; bool createAuxVarForUnaryOpNode() const; void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override; - int findTargetVariable(int lhs_symb_id) const override; + optional findTargetVariable(int lhs_symb_id) const override; expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substitutePacExpectation(const string &name, expr_t subexpr) override; @@ -1096,8 +1097,8 @@ public: void findDiffNodes(lag_equivalence_table_t &nodes) const override; void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override; bool findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const; - int findTargetVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const; - int findTargetVariable(int lhs_symb_id) const override; + optional findTargetVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const; + optional findTargetVariable(int lhs_symb_id) const override; expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substitutePacExpectation(const string &name, expr_t subexpr) override; @@ -1232,7 +1233,7 @@ public: expr_t substituteVarExpectation(const map &subst_table) const override; void findDiffNodes(lag_equivalence_table_t &nodes) const override; void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override; - int findTargetVariable(int lhs_symb_id) const override; + optional findTargetVariable(int lhs_symb_id) const override; expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substitutePacExpectation(const string &name, expr_t subexpr) override; @@ -1342,7 +1343,7 @@ public: expr_t substituteVarExpectation(const map &subst_table) const override; void findDiffNodes(lag_equivalence_table_t &nodes) const override; void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override; - int findTargetVariable(int lhs_symb_id) const override; + optional findTargetVariable(int lhs_symb_id) const override; expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substitutePacExpectation(const string &name, expr_t subexpr) override; @@ -1520,7 +1521,7 @@ public: expr_t substituteModelLocalVariables() const override; void findDiffNodes(lag_equivalence_table_t &nodes) const override; void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override; - int findTargetVariable(int lhs_symb_id) const override; + optional findTargetVariable(int lhs_symb_id) const override; expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const override; void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; diff --git a/src/SubModel.cc b/src/SubModel.cc index ccc72512..f80b3b47 100644 --- a/src/SubModel.cc +++ b/src/SubModel.cc @@ -83,7 +83,7 @@ TrendComponentModelTable::setRhs(map>>> rhs_ar } void -TrendComponentModelTable::setTargetVar(map> target_vars_arg) +TrendComponentModelTable::setTargetVar(map>> target_vars_arg) { target_vars = move(target_vars_arg); } @@ -319,8 +319,8 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c i++; } output << "M_.trend_component." << name << ".target_vars = ["; - for (auto it : target_vars.at(name)) - output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " "; + for (const optional &it : target_vars.at(name)) + output << (it ? symbol_table.getTypeSpecificID(*it) + 1 : -1) << " "; output << "];" << endl; vector target_eqtags_vec = target_eqtags.at(name); diff --git a/src/SubModel.hh b/src/SubModel.hh index af19c6db..6c24d2cb 100644 --- a/src/SubModel.hh +++ b/src/SubModel.hh @@ -24,6 +24,7 @@ #include #include #include +#include #include "ExprNode.hh" #include "SymbolTable.hh" @@ -49,7 +50,7 @@ private: map>>> rhs; map> diff; map> lhs_expr_t; - map> target_vars; + map>> target_vars; map, expr_t>> AR; // name -> (eqn, lag, lhs_symb_id) -> expr_t /* Note that A0 in the trend-component model context is not the same thing as in the structural VAR context. */ @@ -89,7 +90,7 @@ public: void setMaxLags(map> max_lags_arg); void setDiff(map> diff_arg); void setOrigDiffVar(map> orig_diff_var_arg); - void setTargetVar(map> target_vars_arg); + void setTargetVar(map>> target_vars_arg); void setAR(map, expr_t>> AR_arg); void setA0(map, expr_t>> A0_arg, map, expr_t>> A0star_arg);