diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index bbc457ce..81afb452 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3376,8 +3376,8 @@ DynamicModel::getVARDerivIDs(int lhs_symb_id, int lead_lag) const continue; } - if (avi->get_type() == AuxVarType::endoLag && avi->get_orig_symb_id() == lhs_symb_id - && avi->get_orig_lead_lag() + lead_lag2 == lead_lag) + if (avi->get_type() == AuxVarType::endoLag && avi->get_orig_symb_id().value() == lhs_symb_id + && avi->get_orig_lead_lag().value() + lead_lag2 == lead_lag) deriv_ids.push_back(deriv_id2); // Handle diff lag auxvar, possibly nested several times @@ -3392,7 +3392,7 @@ DynamicModel::getVARDerivIDs(int lhs_symb_id, int lead_lag) const } try { - avi = &symbol_table.getAuxVarInfo(avi->get_orig_symb_id()); + avi = &symbol_table.getAuxVarInfo(avi->get_orig_symb_id().value()); } catch (SymbolTable::UnknownSymbolIDException) { diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 3107f441..6dd398e1 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -8970,9 +8970,9 @@ ExprNode::matchParamTimesTargetMinusVariable(int symb_id) const return true; return (avi.get_type() == AuxVarType::unaryOp && avi.get_unary_op() == "log" - && avi.get_orig_symb_id() != -1 - && !datatree.symbol_table.isAuxiliaryVariable(avi.get_orig_symb_id()) - && target->lag + avi.get_orig_lead_lag() == -1); + && avi.get_orig_symb_id() + && !datatree.symbol_table.isAuxiliaryVariable(*avi.get_orig_symb_id()) + && target->lag + avi.get_orig_lead_lag().value() == -1); } else return target->lag == -1; diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc index 8d4cc8d3..dae16a23 100644 --- a/src/SymbolTable.cc +++ b/src/SymbolTable.cc @@ -29,9 +29,9 @@ #include "SymbolTable.hh" -AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id_arg, int orig_lead_lag_arg, - int equation_number_for_multiplier_arg, int information_set_arg, - expr_t expr_node_arg, string unary_op_arg) : +AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional orig_symb_id_arg, + optional orig_lead_lag_arg, int equation_number_for_multiplier_arg, + int information_set_arg, expr_t expr_node_arg, string unary_op_arg) : symb_id{symb_id_arg}, type{type_arg}, orig_symb_id{orig_symb_id_arg}, @@ -363,16 +363,16 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false) case AuxVarType::diffLag: case AuxVarType::diffLead: case AuxVarType::diffForward: - output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl - << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; + output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1 << ";" << endl + << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl; break; case AuxVarType::unaryOp: output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl; // NB: Fallback! case AuxVarType::diff: - if (aux_vars[i].get_orig_symb_id() >= 0) - output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl - << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; + if (aux_vars[i].get_orig_symb_id()) + output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1 << ";" << endl + << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl; break; case AuxVarType::multiplier: output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl; @@ -589,7 +589,7 @@ SymbolTable::addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_i } int -SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false) +SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, optional orig_symb_id, optional orig_lag) noexcept(false) { ostringstream varname; int symb_id; @@ -606,13 +606,13 @@ SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, i exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::diff, orig_symb_id, orig_lag, 0, 0, expr_arg, ""); + aux_vars.emplace_back(symb_id, AuxVarType::diff, move(orig_symb_id), move(orig_lag), 0, 0, expr_arg, ""); return symb_id; } int -SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id, int orig_lag) noexcept(false) +SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, optional orig_symb_id, optional orig_lag) noexcept(false) { ostringstream varname; int symb_id; @@ -628,7 +628,7 @@ SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, orig_symb_id, orig_lag, 0, 0, expr_arg, unary_op); + aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, move(orig_symb_id), move(orig_lag), 0, 0, expr_arg, unary_op); return symb_id; } @@ -733,10 +733,10 @@ SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false) || aux_var.get_type() == AuxVarType::diffForward || aux_var.get_type() == AuxVarType::unaryOp) && aux_var.get_symb_id() == aux_var_symb_id) - if (int r = aux_var.get_orig_symb_id(); r >= 0) - return r; + if (optional r = aux_var.get_orig_symb_id(); r) + return *r; else - throw UnknownSymbolIDException(aux_var_symb_id); // Some diff and unaryOp auxvars have orig_symb_id == -1 + throw UnknownSymbolIDException(aux_var_symb_id); // Some diff and unaryOp auxvars have orig_symb_id unset throw UnknownSymbolIDException(aux_var_symb_id); } @@ -747,8 +747,8 @@ SymbolTable::unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false) if (aux_var.get_symb_id() == symb_id) if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead) { - auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id(), lag); - return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag() }; + auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id().value(), lag); + return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag().value() }; } return { symb_id, lag }; } @@ -1026,16 +1026,16 @@ SymbolTable::writeJsonOutput(ostream &output) const case AuxVarType::diffLag: case AuxVarType::diffLead: case AuxVarType::diffForward: - output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 - << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag(); + output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1 + << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value(); break; case AuxVarType::unaryOp: output << R"(, "unary_op": ")" << aux_vars[i].get_unary_op() << R"(")"; // NB: Fallback! case AuxVarType::diff: - if (aux_vars[i].get_orig_symb_id() >= 0) - output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 - << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag(); + if (aux_vars[i].get_orig_symb_id()) + output << R"(, "orig_index": )" << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1 + << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value(); break; case AuxVarType::multiplier: output << R"(, "eq_nbr": )" << aux_vars[i].get_equation_number_for_multiplier() + 1; diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh index 3d7e28ae..e5b21442 100644 --- a/src/SymbolTable.hh +++ b/src/SymbolTable.hh @@ -26,6 +26,7 @@ #include #include #include +#include #include "CodeInterpreter.hh" #include "ExprNode.hh" @@ -61,24 +62,26 @@ class AuxVarInfo private: int symb_id; //!< Symbol ID of the auxiliary variable AuxVarType type; //!< Its type - int orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of - the definition of this auxvar. - Used by endoLag, exoLag, diffForward, logTransform, diff, diffLag, - diffLead and unaryOp. - For diff and unaryOp, if the argument expression is more complex - than than a simple variable, this value is equal to -1. */ - int orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition - of this auxvar. Only used if orig_symb_id is used. - (in particular, for diff and unaryOp, unused if orig_symb_id == -1). - For diff and diffForward, since the definition of the - auxvar is a time difference, the value corresponds to the - time index of the first term of that difference. */ + optional orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of + the definition of this auxvar. + Used by endoLag, exoLag, diffForward, logTransform, diff, diffLag, + diffLead and unaryOp. + For diff and unaryOp, if the argument expression is more complex + than than a simple variable, this value is unset + (hence the need for std::optional). */ + optional orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition + of this auxvar. Only set if orig_symb_id is set + (in particular, for diff and unaryOp, unset + if orig_symb_id is unset). + For diff and diffForward, since the definition of the + auxvar is a time difference, the value corresponds to the + time index of the first term of that difference. */ int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier. int information_set; //! Argument of expectation operator. Only used for avExpectation. expr_t expr_node; //! Auxiliary variable definition string unary_op; //! Used with AuxUnaryOp public: - AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id, int orig_lead_lag, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg); + AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional orig_symb_id_arg, optional orig_lead_lag_arg, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg); int get_symb_id() const { @@ -94,12 +97,12 @@ public: { return static_cast(type); } - int + optional get_orig_symb_id() const { return orig_symb_id; }; - int + optional get_orig_lead_lag() const { return orig_lead_lag; @@ -349,13 +352,13 @@ public: diffLead increases it). */ pair unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false); //! Adds an auxiliary variable when the diff operator is encountered - int addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id = -1, int orig_lag = 0) noexcept(false); + int addDiffAuxiliaryVar(int index, expr_t expr_arg, optional orig_symb_id = nullopt, optional orig_lag = nullopt) noexcept(false); //! Takes care of timing between diff statements int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false); //! Takes care of timing between diff statements int addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lead) noexcept(false); //! An Auxiliary variable for a unary op - int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id = -1, int orig_lag = 0) noexcept(false); + int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, optional orig_symb_id = nullopt, optional orig_lag = nullopt) noexcept(false); //! An auxiliary variable for a pac_expectation operator int addPacExpectationAuxiliaryVar(const string &name, expr_t expr_arg); //! An auxiliary variable for a pac_target_nonstationary operator