C++17 modernization: use std::optional for the storage of orig_symb_id and orig_lead_lag in SymbolTable

For the diff and unaryOp auxvar types, these value may be set or unset
depending on the complexity of the expression represented by the auxvar.
fix-tolerance-parameters
Sébastien Villemot 2022-05-16 15:15:17 +02:00
parent 8fd1505ca2
commit fa7a926143
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
4 changed files with 48 additions and 45 deletions

View File

@ -3376,8 +3376,8 @@ DynamicModel::getVARDerivIDs(int lhs_symb_id, int lead_lag) const
continue;
}
if (avi->get_type() == AuxVarType::endoLag && avi->get_orig_symb_id() == lhs_symb_id
&& avi->get_orig_lead_lag() + lead_lag2 == lead_lag)
if (avi->get_type() == AuxVarType::endoLag && avi->get_orig_symb_id().value() == lhs_symb_id
&& avi->get_orig_lead_lag().value() + lead_lag2 == lead_lag)
deriv_ids.push_back(deriv_id2);
// Handle diff lag auxvar, possibly nested several times
@ -3392,7 +3392,7 @@ DynamicModel::getVARDerivIDs(int lhs_symb_id, int lead_lag) const
}
try
{
avi = &symbol_table.getAuxVarInfo(avi->get_orig_symb_id());
avi = &symbol_table.getAuxVarInfo(avi->get_orig_symb_id().value());
}
catch (SymbolTable::UnknownSymbolIDException)
{

View File

@ -8970,9 +8970,9 @@ ExprNode::matchParamTimesTargetMinusVariable(int symb_id) const
return true;
return (avi.get_type() == AuxVarType::unaryOp
&& avi.get_unary_op() == "log"
&& avi.get_orig_symb_id() != -1
&& !datatree.symbol_table.isAuxiliaryVariable(avi.get_orig_symb_id())
&& target->lag + avi.get_orig_lead_lag() == -1);
&& avi.get_orig_symb_id()
&& !datatree.symbol_table.isAuxiliaryVariable(*avi.get_orig_symb_id())
&& target->lag + avi.get_orig_lead_lag().value() == -1);
}
else
return target->lag == -1;

View File

@ -29,9 +29,9 @@
#include "SymbolTable.hh"
AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id_arg, int orig_lead_lag_arg,
int equation_number_for_multiplier_arg, int information_set_arg,
expr_t expr_node_arg, string unary_op_arg) :
AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional<int> orig_symb_id_arg,
optional<int> orig_lead_lag_arg, int equation_number_for_multiplier_arg,
int information_set_arg, expr_t expr_node_arg, string unary_op_arg) :
symb_id{symb_id_arg},
type{type_arg},
orig_symb_id{orig_symb_id_arg},
@ -363,16 +363,16 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
case AuxVarType::diffLag:
case AuxVarType::diffLead:
case AuxVarType::diffForward:
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl;
break;
case AuxVarType::unaryOp:
output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl;
// NB: Fallback!
case AuxVarType::diff:
if (aux_vars[i].get_orig_symb_id() >= 0)
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
if (aux_vars[i].get_orig_symb_id())
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl;
break;
case AuxVarType::multiplier:
output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl;
@ -589,7 +589,7 @@ SymbolTable::addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_i
}
int
SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false)
SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, optional<int> orig_symb_id, optional<int> orig_lag) noexcept(false)
{
ostringstream varname;
int symb_id;
@ -606,13 +606,13 @@ SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, i
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, AuxVarType::diff, orig_symb_id, orig_lag, 0, 0, expr_arg, "");
aux_vars.emplace_back(symb_id, AuxVarType::diff, move(orig_symb_id), move(orig_lag), 0, 0, expr_arg, "");
return symb_id;
}
int
SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id, int orig_lag) noexcept(false)
SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, optional<int> orig_symb_id, optional<int> orig_lag) noexcept(false)
{
ostringstream varname;
int symb_id;
@ -628,7 +628,7 @@ SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op,
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, orig_symb_id, orig_lag, 0, 0, expr_arg, unary_op);
aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, move(orig_symb_id), move(orig_lag), 0, 0, expr_arg, unary_op);
return symb_id;
}
@ -733,10 +733,10 @@ SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false)
|| aux_var.get_type() == AuxVarType::diffForward
|| aux_var.get_type() == AuxVarType::unaryOp)
&& aux_var.get_symb_id() == aux_var_symb_id)
if (int r = aux_var.get_orig_symb_id(); r >= 0)
return r;
if (optional<int> r = aux_var.get_orig_symb_id(); r)
return *r;
else
throw UnknownSymbolIDException(aux_var_symb_id); // Some diff and unaryOp auxvars have orig_symb_id == -1
throw UnknownSymbolIDException(aux_var_symb_id); // Some diff and unaryOp auxvars have orig_symb_id unset
throw UnknownSymbolIDException(aux_var_symb_id);
}
@ -747,8 +747,8 @@ SymbolTable::unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false)
if (aux_var.get_symb_id() == symb_id)
if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead)
{
auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id(), lag);
return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag() };
auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id().value(), lag);
return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag().value() };
}
return { symb_id, lag };
}
@ -1026,16 +1026,16 @@ SymbolTable::writeJsonOutput(ostream &output) const
case AuxVarType::diffLag:
case AuxVarType::diffLead:
case AuxVarType::diffForward:
output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag();
output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value();
break;
case AuxVarType::unaryOp:
output << R"(, "unary_op": ")" << aux_vars[i].get_unary_op() << R"(")";
// NB: Fallback!
case AuxVarType::diff:
if (aux_vars[i].get_orig_symb_id() >= 0)
output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag();
if (aux_vars[i].get_orig_symb_id())
output << R"(, "orig_index": )" << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value();
break;
case AuxVarType::multiplier:
output << R"(, "eq_nbr": )" << aux_vars[i].get_equation_number_for_multiplier() + 1;

View File

@ -26,6 +26,7 @@
#include <vector>
#include <set>
#include <ostream>
#include <optional>
#include "CodeInterpreter.hh"
#include "ExprNode.hh"
@ -61,24 +62,26 @@ class AuxVarInfo
private:
int symb_id; //!< Symbol ID of the auxiliary variable
AuxVarType type; //!< Its type
int orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of
the definition of this auxvar.
Used by endoLag, exoLag, diffForward, logTransform, diff, diffLag,
diffLead and unaryOp.
For diff and unaryOp, if the argument expression is more complex
than than a simple variable, this value is equal to -1. */
int orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition
of this auxvar. Only used if orig_symb_id is used.
(in particular, for diff and unaryOp, unused if orig_symb_id == -1).
For diff and diffForward, since the definition of the
auxvar is a time difference, the value corresponds to the
time index of the first term of that difference. */
optional<int> orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of
the definition of this auxvar.
Used by endoLag, exoLag, diffForward, logTransform, diff, diffLag,
diffLead and unaryOp.
For diff and unaryOp, if the argument expression is more complex
than than a simple variable, this value is unset
(hence the need for std::optional). */
optional<int> orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition
of this auxvar. Only set if orig_symb_id is set
(in particular, for diff and unaryOp, unset
if orig_symb_id is unset).
For diff and diffForward, since the definition of the
auxvar is a time difference, the value corresponds to the
time index of the first term of that difference. */
int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier.
int information_set; //! Argument of expectation operator. Only used for avExpectation.
expr_t expr_node; //! Auxiliary variable definition
string unary_op; //! Used with AuxUnaryOp
public:
AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id, int orig_lead_lag, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg);
AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional<int> orig_symb_id_arg, optional<int> orig_lead_lag_arg, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg);
int
get_symb_id() const
{
@ -94,12 +97,12 @@ public:
{
return static_cast<int>(type);
}
int
optional<int>
get_orig_symb_id() const
{
return orig_symb_id;
};
int
optional<int>
get_orig_lead_lag() const
{
return orig_lead_lag;
@ -349,13 +352,13 @@ public:
diffLead increases it). */
pair<int, int> unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false);
//! Adds an auxiliary variable when the diff operator is encountered
int addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id = -1, int orig_lag = 0) noexcept(false);
int addDiffAuxiliaryVar(int index, expr_t expr_arg, optional<int> orig_symb_id = nullopt, optional<int> orig_lag = nullopt) noexcept(false);
//! Takes care of timing between diff statements
int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);
//! Takes care of timing between diff statements
int addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lead) noexcept(false);
//! An Auxiliary variable for a unary op
int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id = -1, int orig_lag = 0) noexcept(false);
int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, optional<int> orig_symb_id = nullopt, optional<int> orig_lag = nullopt) noexcept(false);
//! An auxiliary variable for a pac_expectation operator
int addPacExpectationAuxiliaryVar(const string &name, expr_t expr_arg);
//! An auxiliary variable for a pac_target_nonstationary operator