C++17 modernization: use std::optional for the storage of orig_symb_id and orig_lead_lag in SymbolTable

For the diff and unaryOp auxvar types, these value may be set or unset
depending on the complexity of the expression represented by the auxvar.
fix-tolerance-parameters
Sébastien Villemot 2022-05-16 15:15:17 +02:00
parent 8fd1505ca2
commit fa7a926143
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
4 changed files with 48 additions and 45 deletions

View File

@ -3376,8 +3376,8 @@ DynamicModel::getVARDerivIDs(int lhs_symb_id, int lead_lag) const
continue; continue;
} }
if (avi->get_type() == AuxVarType::endoLag && avi->get_orig_symb_id() == lhs_symb_id if (avi->get_type() == AuxVarType::endoLag && avi->get_orig_symb_id().value() == lhs_symb_id
&& avi->get_orig_lead_lag() + lead_lag2 == lead_lag) && avi->get_orig_lead_lag().value() + lead_lag2 == lead_lag)
deriv_ids.push_back(deriv_id2); deriv_ids.push_back(deriv_id2);
// Handle diff lag auxvar, possibly nested several times // Handle diff lag auxvar, possibly nested several times
@ -3392,7 +3392,7 @@ DynamicModel::getVARDerivIDs(int lhs_symb_id, int lead_lag) const
} }
try try
{ {
avi = &symbol_table.getAuxVarInfo(avi->get_orig_symb_id()); avi = &symbol_table.getAuxVarInfo(avi->get_orig_symb_id().value());
} }
catch (SymbolTable::UnknownSymbolIDException) catch (SymbolTable::UnknownSymbolIDException)
{ {

View File

@ -8970,9 +8970,9 @@ ExprNode::matchParamTimesTargetMinusVariable(int symb_id) const
return true; return true;
return (avi.get_type() == AuxVarType::unaryOp return (avi.get_type() == AuxVarType::unaryOp
&& avi.get_unary_op() == "log" && avi.get_unary_op() == "log"
&& avi.get_orig_symb_id() != -1 && avi.get_orig_symb_id()
&& !datatree.symbol_table.isAuxiliaryVariable(avi.get_orig_symb_id()) && !datatree.symbol_table.isAuxiliaryVariable(*avi.get_orig_symb_id())
&& target->lag + avi.get_orig_lead_lag() == -1); && target->lag + avi.get_orig_lead_lag().value() == -1);
} }
else else
return target->lag == -1; return target->lag == -1;

View File

@ -29,9 +29,9 @@
#include "SymbolTable.hh" #include "SymbolTable.hh"
AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id_arg, int orig_lead_lag_arg, AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional<int> orig_symb_id_arg,
int equation_number_for_multiplier_arg, int information_set_arg, optional<int> orig_lead_lag_arg, int equation_number_for_multiplier_arg,
expr_t expr_node_arg, string unary_op_arg) : int information_set_arg, expr_t expr_node_arg, string unary_op_arg) :
symb_id{symb_id_arg}, symb_id{symb_id_arg},
type{type_arg}, type{type_arg},
orig_symb_id{orig_symb_id_arg}, orig_symb_id{orig_symb_id_arg},
@ -363,16 +363,16 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
case AuxVarType::diffLag: case AuxVarType::diffLag:
case AuxVarType::diffLead: case AuxVarType::diffLead:
case AuxVarType::diffForward: case AuxVarType::diffForward:
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl;
break; break;
case AuxVarType::unaryOp: case AuxVarType::unaryOp:
output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl; output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl;
// NB: Fallback! // NB: Fallback!
case AuxVarType::diff: case AuxVarType::diff:
if (aux_vars[i].get_orig_symb_id() >= 0) if (aux_vars[i].get_orig_symb_id())
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl;
break; break;
case AuxVarType::multiplier: case AuxVarType::multiplier:
output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl; output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl;
@ -589,7 +589,7 @@ SymbolTable::addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_i
} }
int int
SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false) SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, optional<int> orig_symb_id, optional<int> orig_lag) noexcept(false)
{ {
ostringstream varname; ostringstream varname;
int symb_id; int symb_id;
@ -606,13 +606,13 @@ SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, i
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
aux_vars.emplace_back(symb_id, AuxVarType::diff, orig_symb_id, orig_lag, 0, 0, expr_arg, ""); aux_vars.emplace_back(symb_id, AuxVarType::diff, move(orig_symb_id), move(orig_lag), 0, 0, expr_arg, "");
return symb_id; return symb_id;
} }
int int
SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id, int orig_lag) noexcept(false) SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, optional<int> orig_symb_id, optional<int> orig_lag) noexcept(false)
{ {
ostringstream varname; ostringstream varname;
int symb_id; int symb_id;
@ -628,7 +628,7 @@ SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op,
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, orig_symb_id, orig_lag, 0, 0, expr_arg, unary_op); aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, move(orig_symb_id), move(orig_lag), 0, 0, expr_arg, unary_op);
return symb_id; return symb_id;
} }
@ -733,10 +733,10 @@ SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false)
|| aux_var.get_type() == AuxVarType::diffForward || aux_var.get_type() == AuxVarType::diffForward
|| aux_var.get_type() == AuxVarType::unaryOp) || aux_var.get_type() == AuxVarType::unaryOp)
&& aux_var.get_symb_id() == aux_var_symb_id) && aux_var.get_symb_id() == aux_var_symb_id)
if (int r = aux_var.get_orig_symb_id(); r >= 0) if (optional<int> r = aux_var.get_orig_symb_id(); r)
return r; return *r;
else else
throw UnknownSymbolIDException(aux_var_symb_id); // Some diff and unaryOp auxvars have orig_symb_id == -1 throw UnknownSymbolIDException(aux_var_symb_id); // Some diff and unaryOp auxvars have orig_symb_id unset
throw UnknownSymbolIDException(aux_var_symb_id); throw UnknownSymbolIDException(aux_var_symb_id);
} }
@ -747,8 +747,8 @@ SymbolTable::unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false)
if (aux_var.get_symb_id() == symb_id) if (aux_var.get_symb_id() == symb_id)
if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead) if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead)
{ {
auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id(), lag); auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id().value(), lag);
return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag() }; return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag().value() };
} }
return { symb_id, lag }; return { symb_id, lag };
} }
@ -1026,16 +1026,16 @@ SymbolTable::writeJsonOutput(ostream &output) const
case AuxVarType::diffLag: case AuxVarType::diffLag:
case AuxVarType::diffLead: case AuxVarType::diffLead:
case AuxVarType::diffForward: case AuxVarType::diffForward:
output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag(); << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value();
break; break;
case AuxVarType::unaryOp: case AuxVarType::unaryOp:
output << R"(, "unary_op": ")" << aux_vars[i].get_unary_op() << R"(")"; output << R"(, "unary_op": ")" << aux_vars[i].get_unary_op() << R"(")";
// NB: Fallback! // NB: Fallback!
case AuxVarType::diff: case AuxVarType::diff:
if (aux_vars[i].get_orig_symb_id() >= 0) if (aux_vars[i].get_orig_symb_id())
output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 output << R"(, "orig_index": )" << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag(); << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value();
break; break;
case AuxVarType::multiplier: case AuxVarType::multiplier:
output << R"(, "eq_nbr": )" << aux_vars[i].get_equation_number_for_multiplier() + 1; output << R"(, "eq_nbr": )" << aux_vars[i].get_equation_number_for_multiplier() + 1;

View File

@ -26,6 +26,7 @@
#include <vector> #include <vector>
#include <set> #include <set>
#include <ostream> #include <ostream>
#include <optional>
#include "CodeInterpreter.hh" #include "CodeInterpreter.hh"
#include "ExprNode.hh" #include "ExprNode.hh"
@ -61,24 +62,26 @@ class AuxVarInfo
private: private:
int symb_id; //!< Symbol ID of the auxiliary variable int symb_id; //!< Symbol ID of the auxiliary variable
AuxVarType type; //!< Its type AuxVarType type; //!< Its type
int orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of optional<int> orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of
the definition of this auxvar. the definition of this auxvar.
Used by endoLag, exoLag, diffForward, logTransform, diff, diffLag, Used by endoLag, exoLag, diffForward, logTransform, diff, diffLag,
diffLead and unaryOp. diffLead and unaryOp.
For diff and unaryOp, if the argument expression is more complex For diff and unaryOp, if the argument expression is more complex
than than a simple variable, this value is equal to -1. */ than than a simple variable, this value is unset
int orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition (hence the need for std::optional). */
of this auxvar. Only used if orig_symb_id is used. optional<int> orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition
(in particular, for diff and unaryOp, unused if orig_symb_id == -1). of this auxvar. Only set if orig_symb_id is set
For diff and diffForward, since the definition of the (in particular, for diff and unaryOp, unset
auxvar is a time difference, the value corresponds to the if orig_symb_id is unset).
time index of the first term of that difference. */ For diff and diffForward, since the definition of the
auxvar is a time difference, the value corresponds to the
time index of the first term of that difference. */
int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier. int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier.
int information_set; //! Argument of expectation operator. Only used for avExpectation. int information_set; //! Argument of expectation operator. Only used for avExpectation.
expr_t expr_node; //! Auxiliary variable definition expr_t expr_node; //! Auxiliary variable definition
string unary_op; //! Used with AuxUnaryOp string unary_op; //! Used with AuxUnaryOp
public: public:
AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id, int orig_lead_lag, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg); AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional<int> orig_symb_id_arg, optional<int> orig_lead_lag_arg, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg);
int int
get_symb_id() const get_symb_id() const
{ {
@ -94,12 +97,12 @@ public:
{ {
return static_cast<int>(type); return static_cast<int>(type);
} }
int optional<int>
get_orig_symb_id() const get_orig_symb_id() const
{ {
return orig_symb_id; return orig_symb_id;
}; };
int optional<int>
get_orig_lead_lag() const get_orig_lead_lag() const
{ {
return orig_lead_lag; return orig_lead_lag;
@ -349,13 +352,13 @@ public:
diffLead increases it). */ diffLead increases it). */
pair<int, int> unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false); pair<int, int> unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false);
//! Adds an auxiliary variable when the diff operator is encountered //! Adds an auxiliary variable when the diff operator is encountered
int addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id = -1, int orig_lag = 0) noexcept(false); int addDiffAuxiliaryVar(int index, expr_t expr_arg, optional<int> orig_symb_id = nullopt, optional<int> orig_lag = nullopt) noexcept(false);
//! Takes care of timing between diff statements //! Takes care of timing between diff statements
int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false); int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);
//! Takes care of timing between diff statements //! Takes care of timing between diff statements
int addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lead) noexcept(false); int addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lead) noexcept(false);
//! An Auxiliary variable for a unary op //! An Auxiliary variable for a unary op
int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id = -1, int orig_lag = 0) noexcept(false); int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, optional<int> orig_symb_id = nullopt, optional<int> orig_lag = nullopt) noexcept(false);
//! An auxiliary variable for a pac_expectation operator //! An auxiliary variable for a pac_expectation operator
int addPacExpectationAuxiliaryVar(const string &name, expr_t expr_arg); int addPacExpectationAuxiliaryVar(const string &name, expr_t expr_arg);
//! An auxiliary variable for a pac_target_nonstationary operator //! An auxiliary variable for a pac_target_nonstationary operator