Remove constructor and accessors for AuxVarInfo

Rather make all data members public and const, and use aggregate-initialization.
master
Sébastien Villemot 2022-07-20 14:32:57 +02:00
parent f0629555a5
commit 50d5b916e2
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
4 changed files with 87 additions and 141 deletions

View File

@ -2713,23 +2713,23 @@ DynamicModel::getVARDerivIDs(int lhs_symb_id, int lead_lag) const
continue;
}
if (avi->get_type() == AuxVarType::endoLag && avi->get_orig_symb_id().value() == lhs_symb_id
&& avi->get_orig_lead_lag().value() + lead_lag2 == lead_lag)
if (avi->type == AuxVarType::endoLag && avi->orig_symb_id.value() == lhs_symb_id
&& avi->orig_lead_lag.value() + lead_lag2 == lead_lag)
deriv_ids.push_back(deriv_id2);
// Handle diff lag auxvar, possibly nested several times
int diff_lag_depth = 0;
while (avi->get_type() == AuxVarType::diffLag)
while (avi->type == AuxVarType::diffLag)
{
diff_lag_depth++;
if (avi->get_orig_symb_id() == lhs_symb_id && lead_lag2 - diff_lag_depth == lead_lag)
if (avi->orig_symb_id == lhs_symb_id && lead_lag2 - diff_lag_depth == lead_lag)
{
deriv_ids.push_back(deriv_id2);
break;
}
try
{
avi = &symbol_table.getAuxVarInfo(avi->get_orig_symb_id().value());
avi = &symbol_table.getAuxVarInfo(avi->orig_symb_id.value());
}
catch (SymbolTable::UnknownSymbolIDException)
{

View File

@ -9101,14 +9101,14 @@ ExprNode::matchParamTimesTargetMinusVariable(int symb_id) const
return false;
if (datatree.symbol_table.isAuxiliaryVariable(target->symb_id))
{
auto avi = datatree.symbol_table.getAuxVarInfo(target->symb_id);
if (avi.get_type() == AuxVarType::pacTargetNonstationary && target->lag == -1)
auto &avi = datatree.symbol_table.getAuxVarInfo(target->symb_id);
if (avi.type == AuxVarType::pacTargetNonstationary && target->lag == -1)
return true;
return (avi.get_type() == AuxVarType::unaryOp
&& avi.get_unary_op() == "log"
&& avi.get_orig_symb_id()
&& !datatree.symbol_table.isAuxiliaryVariable(*avi.get_orig_symb_id())
&& target->lag + avi.get_orig_lead_lag().value() == -1);
return (avi.type == AuxVarType::unaryOp
&& avi.unary_op == "log"
&& avi.orig_symb_id
&& !datatree.symbol_table.isAuxiliaryVariable(*avi.orig_symb_id)
&& target->lag + avi.orig_lead_lag.value() == -1);
}
else
return target->lag == -1;

View File

@ -29,20 +29,6 @@
#include "SymbolTable.hh"
AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional<int> orig_symb_id_arg,
optional<int> orig_lead_lag_arg, int equation_number_for_multiplier_arg,
int information_set_arg, expr_t expr_node_arg, string unary_op_arg) :
symb_id{symb_id_arg},
type{type_arg},
orig_symb_id{orig_symb_id_arg},
orig_lead_lag{orig_lead_lag_arg},
equation_number_for_multiplier{equation_number_for_multiplier_arg},
information_set{information_set_arg},
expr_node{expr_node_arg},
unary_op{move(unary_op_arg)}
{
}
int
SymbolTable::addSymbol(const string &name, SymbolType type, const string &tex_name, const vector<pair<string, string>> &partition_value) noexcept(false)
{
@ -343,9 +329,9 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
else
for (int i = 0; i < static_cast<int>(aux_vars.size()); i++)
{
output << "M_.aux_vars(" << i+1 << ").endo_index = " << getTypeSpecificID(aux_vars[i].get_symb_id())+1 << ";" << endl
output << "M_.aux_vars(" << i+1 << ").endo_index = " << getTypeSpecificID(aux_vars[i].symb_id)+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").type = " << aux_vars[i].get_type_id() << ";" << endl;
switch (aux_vars[i].get_type())
switch (aux_vars[i].type)
{
case AuxVarType::endoLead:
case AuxVarType::exoLead:
@ -359,23 +345,23 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
case AuxVarType::diffLag:
case AuxVarType::diffLead:
case AuxVarType::diffForward:
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl;
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].orig_symb_id.value())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].orig_lead_lag.value() << ";" << endl;
break;
case AuxVarType::unaryOp:
output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl;
output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].unary_op << "';" << endl;
[[fallthrough]];
case AuxVarType::diff:
if (aux_vars[i].get_orig_symb_id())
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl;
if (aux_vars[i].orig_symb_id)
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(*aux_vars[i].orig_symb_id)+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].orig_lead_lag.value() << ";" << endl;
break;
case AuxVarType::multiplier:
output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl;
output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].equation_number_for_multiplier + 1 << ";" << endl;
break;
}
if (expr_t orig_expr = aux_vars[i].get_expr_node();
if (expr_t orig_expr = aux_vars[i].expr_node;
orig_expr)
{
output << "M_.aux_vars(" << i+1 << ").orig_expr = '";
@ -682,40 +668,40 @@ int
SymbolTable::searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false)
{
for (const auto &aux_var : aux_vars)
if ((aux_var.get_type() == AuxVarType::endoLag || aux_var.get_type() == AuxVarType::exoLag)
&& aux_var.get_orig_symb_id() == orig_symb_id && aux_var.get_orig_lead_lag() == orig_lead_lag)
return aux_var.get_symb_id();
if ((aux_var.type == AuxVarType::endoLag || aux_var.type == AuxVarType::exoLag)
&& aux_var.orig_symb_id == orig_symb_id && aux_var.orig_lead_lag == orig_lead_lag)
return aux_var.symb_id;
throw SearchFailedException(orig_symb_id, orig_lead_lag);
}
int
SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false)
SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id_arg) const noexcept(false)
{
for (const auto &aux_var : aux_vars)
if ((aux_var.get_type() == AuxVarType::endoLag
|| aux_var.get_type() == AuxVarType::exoLag
|| aux_var.get_type() == AuxVarType::diff
|| aux_var.get_type() == AuxVarType::diffLag
|| aux_var.get_type() == AuxVarType::diffLead
|| aux_var.get_type() == AuxVarType::diffForward
|| aux_var.get_type() == AuxVarType::unaryOp)
&& aux_var.get_symb_id() == aux_var_symb_id)
if (optional<int> r = aux_var.get_orig_symb_id(); r)
if ((aux_var.type == AuxVarType::endoLag
|| aux_var.type == AuxVarType::exoLag
|| aux_var.type == AuxVarType::diff
|| aux_var.type == AuxVarType::diffLag
|| aux_var.type == AuxVarType::diffLead
|| aux_var.type == AuxVarType::diffForward
|| aux_var.type == AuxVarType::unaryOp)
&& aux_var.symb_id == aux_var_symb_id_arg)
if (optional<int> r = aux_var.orig_symb_id; r)
return *r;
else
throw UnknownSymbolIDException(aux_var_symb_id); // Some diff and unaryOp auxvars have orig_symb_id unset
throw UnknownSymbolIDException(aux_var_symb_id);
throw UnknownSymbolIDException(aux_var_symb_id_arg); // Some diff and unaryOp auxvars have orig_symb_id unset
throw UnknownSymbolIDException(aux_var_symb_id_arg);
}
pair<int, int>
SymbolTable::unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false)
{
for (const auto &aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id)
if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead)
if (aux_var.symb_id == symb_id)
if (aux_var.type == AuxVarType::diffLag || aux_var.type == AuxVarType::diffLead)
{
auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id().value(), lag);
return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag().value() };
auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.orig_symb_id.value(), lag);
return { orig_symb_id, orig_lag + aux_var.orig_lead_lag.value() };
}
return { symb_id, lag };
}
@ -725,8 +711,8 @@ SymbolTable::getAuxiliaryVarsExprNode(int symb_id) const noexcept(false)
// throw exception if it is a Lagrange multiplier
{
for (const auto &aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id)
if (expr_t expr_node = aux_var.get_expr_node();
if (aux_var.symb_id == symb_id)
if (expr_t expr_node = aux_var.expr_node;
expr_node)
return expr_node;
else
@ -874,7 +860,7 @@ bool
SymbolTable::isAuxiliaryVariable(int symb_id) const
{
for (const auto &aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id)
if (aux_var.symb_id == symb_id)
return true;
return false;
}
@ -883,7 +869,7 @@ bool
SymbolTable::isAuxiliaryVariableButNotMultiplier(int symb_id) const
{
for (const auto &aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id && aux_var.get_type() != AuxVarType::multiplier)
if (aux_var.symb_id == symb_id && aux_var.type != AuxVarType::multiplier)
return true;
return false;
}
@ -892,10 +878,10 @@ bool
SymbolTable::isDiffAuxiliaryVariable(int symb_id) const
{
for (const auto &aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id
&& (aux_var.get_type() == AuxVarType::diff
|| aux_var.get_type() == AuxVarType::diffLag
|| aux_var.get_type() == AuxVarType::diffLead))
if (aux_var.symb_id == symb_id
&& (aux_var.type == AuxVarType::diff
|| aux_var.type == AuxVarType::diffLag
|| aux_var.type == AuxVarType::diffLead))
return true;
return false;
}
@ -977,9 +963,9 @@ SymbolTable::writeJsonOutput(ostream &output) const
{
if (i != 0)
output << ", ";
output << R"({"endo_index": )" << getTypeSpecificID(aux_vars[i].get_symb_id())+1
output << R"({"endo_index": )" << getTypeSpecificID(aux_vars[i].symb_id)+1
<< R"(, "type": )" << aux_vars[i].get_type_id();
switch (aux_vars[i].get_type())
switch (aux_vars[i].type)
{
case AuxVarType::endoLead:
case AuxVarType::exoLead:
@ -993,23 +979,23 @@ SymbolTable::writeJsonOutput(ostream &output) const
case AuxVarType::diffLag:
case AuxVarType::diffLead:
case AuxVarType::diffForward:
output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value();
output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].orig_symb_id.value())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].orig_lead_lag.value();
break;
case AuxVarType::unaryOp:
output << R"(, "unary_op": ")" << aux_vars[i].get_unary_op() << R"(")";
output << R"(, "unary_op": ")" << aux_vars[i].unary_op << R"(")";
[[fallthrough]];
case AuxVarType::diff:
if (aux_vars[i].get_orig_symb_id())
output << R"(, "orig_index": )" << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value();
if (aux_vars[i].orig_symb_id)
output << R"(, "orig_index": )" << getTypeSpecificID(*aux_vars[i].orig_symb_id)+1
<< R"(, "orig_lead_lag": )" << aux_vars[i].orig_lead_lag.value();
break;
case AuxVarType::multiplier:
output << R"(, "eq_nbr": )" << aux_vars[i].get_equation_number_for_multiplier() + 1;
output << R"(, "eq_nbr": )" << aux_vars[i].equation_number_for_multiplier + 1;
break;
}
if (expr_t orig_expr = aux_vars[i].get_expr_node();
if (expr_t orig_expr = aux_vars[i].expr_node;
orig_expr)
{
output << R"(, "orig_expr": ")";
@ -1058,8 +1044,8 @@ optional<int>
SymbolTable::getEquationNumberForMultiplier(int symb_id) const
{
for (const auto &aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id && aux_var.get_type() == AuxVarType::multiplier)
return aux_var.get_equation_number_for_multiplier();
if (aux_var.symb_id == symb_id && aux_var.type == AuxVarType::multiplier)
return aux_var.equation_number_for_multiplier;
return nullopt;
}

View File

@ -57,76 +57,36 @@ enum class AuxVarType
};
//! Information on some auxiliary variables
class AuxVarInfo
struct AuxVarInfo
{
private:
int symb_id; //!< Symbol ID of the auxiliary variable
AuxVarType type; //!< Its type
optional<int> orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of
the definition of this auxvar.
Used by endoLag, exoLag, diffForward, logTransform, diff, diffLag,
diffLead and unaryOp.
For diff and unaryOp, if the argument expression is more complex
than than a simple variable, this value is unset
(hence the need for std::optional). */
optional<int> orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition
of this auxvar. Only set if orig_symb_id is set
(in particular, for diff and unaryOp, unset
if orig_symb_id is unset).
For diff and diffForward, since the definition of the
auxvar is a time difference, the value corresponds to the
time index of the first term of that difference. */
int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier.
int information_set; //! Argument of expectation operator. Only used for avExpectation.
expr_t expr_node; //! Auxiliary variable definition
string unary_op; //! Used with AuxUnaryOp
public:
AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional<int> orig_symb_id_arg, optional<int> orig_lead_lag_arg, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg);
int
get_symb_id() const
{
return symb_id;
};
AuxVarType
get_type() const
{
return type;
};
const int symb_id; // Symbol ID of the auxiliary variable
const AuxVarType type; // Its type
const optional<int> orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of
the definition of this auxvar.
Used by endoLag, exoLag, diffForward, logTransform, diff,
diffLag, diffLead and unaryOp.
For diff and unaryOp, if the argument expression is more complex
than than a simple variable, this value is unset
(hence the need for std::optional). */
const optional<int> orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the
definition of this auxvar. Only set if orig_symb_id is set
(in particular, for diff and unaryOp, unset
if orig_symb_id is unset).
For diff and diffForward, since the definition of the
auxvar is a time difference, the value corresponds to the
time index of the first term of that difference. */
const int equation_number_for_multiplier; /* Stores the original constraint equation number
associated with this aux var. Only used for
avMultiplier. */
const int information_set; // Argument of expectation operator. Only used for avExpectation.
const expr_t expr_node; // Auxiliary variable definition
const string unary_op; // Used with AuxUnaryOp
int
get_type_id() const
{
return static_cast<int>(type);
}
optional<int>
get_orig_symb_id() const
{
return orig_symb_id;
};
optional<int>
get_orig_lead_lag() const
{
return orig_lead_lag;
};
int
get_equation_number_for_multiplier() const
{
return equation_number_for_multiplier;
};
int
get_information_set() const
{
return information_set;
};
expr_t
get_expr_node() const
{
return expr_node;
};
const string &
get_unary_op() const
{
return unary_op;
};
};
//! Stores the symbol table
@ -318,7 +278,7 @@ public:
this auxvar (either because its of the wrong type, or because there is
no such orig var for this specific auxvar, in case of complex expressions
in diff or unaryOp). */
int getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false);
int getOrigSymbIdForAuxVar(int aux_var_symb_id_arg) const noexcept(false);
/* Unrolls a chain of diffLag or diffLead aux vars until it founds a (regular) diff aux
var. In other words:
- if the arg is a (regu) diff aux var, returns the arg
@ -583,7 +543,7 @@ inline const AuxVarInfo &
SymbolTable::getAuxVarInfo(int symb_id) const
{
for (const auto &aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id)
if (aux_var.symb_id == symb_id)
return aux_var;
throw UnknownSymbolIDException(symb_id);
}