Add more comments in routines for aux vars for unary ops / diff operators
By the way, do some small code simplifications.issue#70
parent
8e9f6e4c57
commit
3941278832
|
@ -6267,16 +6267,15 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, const vector<int> &e
|
|||
diff_table_t nodes;
|
||||
ExprNode::subst_table_t subst_table;
|
||||
|
||||
// Find matching unary ops that may be outside of diffs (i.e., those with different lags)
|
||||
// Mark unary ops to be substituted in model local variables that appear in selected equations
|
||||
set<int> used_local_vars;
|
||||
for (int eqnumber : eqnumbers)
|
||||
equations[eqnumber]->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
|
||||
|
||||
// Only substitute unary ops in model local variables that appear in VAR equations
|
||||
for (auto & it : local_variables_table)
|
||||
if (used_local_vars.find(it.first) != used_local_vars.end())
|
||||
it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
|
||||
|
||||
// Mark unary ops to be substituted in selected equations
|
||||
for (int eqnumber : eqnumbers)
|
||||
equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
|
||||
|
||||
|
|
|
@ -445,10 +445,10 @@ public:
|
|||
//! Creates aux vars for all unary operators
|
||||
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model);
|
||||
|
||||
//! Creates aux vars for certain unary operators: originally implemented for support of VARs
|
||||
//! Creates aux vars for unary operators in certain equations: originally implemented for support of VARs
|
||||
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model, const set<string> &eq_tags);
|
||||
|
||||
//! Creates aux vars for certain unary operators: originally implemented for support of VARs
|
||||
//! Creates aux vars for unary operators in certain equations: originally implemented for support of VARs
|
||||
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model, const vector<int> &eqnumbers);
|
||||
|
||||
//! Substitutes diff operator
|
||||
|
|
|
@ -3464,9 +3464,8 @@ UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_t
|
|||
auto it = nodes.find(sthis);
|
||||
if (it != nodes.end())
|
||||
{
|
||||
for (map<int, expr_t>::const_iterator it1 = it->second.begin();
|
||||
it1 != it->second.end(); it1++)
|
||||
if (arg == it1->second)
|
||||
for (const auto &it1 : it->second)
|
||||
if (arg == it1.second)
|
||||
return;
|
||||
it->second[arg_max_lag] = const_cast<UnaryOpNode *>(this);
|
||||
}
|
||||
|
@ -3488,9 +3487,8 @@ UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table)
|
|||
auto it = diff_table.find(sthis);
|
||||
if (it != diff_table.end())
|
||||
{
|
||||
for (map<int, expr_t>::const_iterator it1 = it->second.begin();
|
||||
it1 != it->second.end(); it1++)
|
||||
if (arg == it1->second)
|
||||
for (const auto &it1 : it->second)
|
||||
if (arg == it1.second)
|
||||
return;
|
||||
it->second[arg_max_lag] = const_cast<UnaryOpNode *>(this);
|
||||
}
|
||||
|
@ -3508,11 +3506,12 @@ expr_t
|
|||
UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
|
||||
vector<BinaryOpNode *> &neweqs) const
|
||||
{
|
||||
// If this is not a diff node, then substitute recursively and return
|
||||
expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
|
||||
if (op_code != UnaryOpcode::diff)
|
||||
return buildSimilarUnaryOpNode(argsubst, datatree);
|
||||
|
||||
subst_table_t::const_iterator sit = subst_table.find(this);
|
||||
auto sit = subst_table.find(this);
|
||||
if (sit != subst_table.end())
|
||||
return const_cast<VariableNode *>(sit->second);
|
||||
|
||||
|
@ -3521,9 +3520,9 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
|
|||
int symb_id;
|
||||
if (it == diff_table.end() || it->second[-arg->maxLagWithDiffsExpanded()] != this)
|
||||
{
|
||||
// diff does not appear in VAR equations
|
||||
// so simply create aux var and return
|
||||
// Once the comparison of expression nodes works, come back and remove this part, folding into the next loop.
|
||||
/* diff does not appear in VAR equations, so simply create aux var and return.
|
||||
Once the comparison of expression nodes works, come back and remove
|
||||
this part, folding into the next loop. */
|
||||
symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst);
|
||||
VariableNode *aux_var = datatree.AddVariable(symb_id, 0);
|
||||
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
|
||||
|
@ -3533,10 +3532,13 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
|
|||
return const_cast<VariableNode *>(subst_table[this]);
|
||||
}
|
||||
|
||||
/* At this point, we know that this node (and its lagged/leaded brothers)
|
||||
must be substituted. We create the auxiliary variable and fill the
|
||||
substitution table for all those similar nodes, in an iteration going from
|
||||
leads to lags. */
|
||||
int last_arg_max_lag = 0;
|
||||
VariableNode *last_aux_var = nullptr;
|
||||
for (auto rit = it->second.rbegin();
|
||||
rit != it->second.rend(); rit++)
|
||||
for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
|
||||
{
|
||||
expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)->
|
||||
arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
|
||||
|
@ -3585,10 +3587,12 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
|
|||
expr_t
|
||||
UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
|
||||
{
|
||||
subst_table_t::const_iterator sit = subst_table.find(this);
|
||||
auto sit = subst_table.find(this);
|
||||
if (sit != subst_table.end())
|
||||
return const_cast<VariableNode *>(sit->second);
|
||||
|
||||
/* If (the static equivalent of) this node is not marked for substitution,
|
||||
then substitute recursively and return. */
|
||||
auto *sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
|
||||
auto it = nodes.find(sthis);
|
||||
expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
|
||||
|
@ -3601,7 +3605,7 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
|
|||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
string unary_op = "";
|
||||
string unary_op;
|
||||
switch (op_code)
|
||||
{
|
||||
case UnaryOpcode::exp:
|
||||
|
@ -3665,15 +3669,17 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
|
|||
unary_op = "erf";
|
||||
break;
|
||||
default:
|
||||
{
|
||||
cerr << "UnaryOpNode::substituteUnaryOpNodes: Shouldn't arrive here" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
cerr << "UnaryOpNode::substituteUnaryOpNodes: Shouldn't arrive here" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
/* At this point, we know that this node (and its lagged/leaded brothers)
|
||||
must be substituted. We create the auxiliary variable and fill the
|
||||
substitution table for all those similar nodes, in an iteration going from
|
||||
leads to lags. */
|
||||
int base_aux_lag = 0;
|
||||
VariableNode *aux_var = nullptr;
|
||||
for (auto rit = it->second.rbegin(); rit != it->second.rend(); rit++)
|
||||
for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
|
||||
if (rit == it->second.rbegin())
|
||||
{
|
||||
int symb_id;
|
||||
|
|
|
@ -558,16 +558,33 @@ class ExprNode
|
|||
//! Substitute VarExpectation nodes
|
||||
virtual expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const = 0;
|
||||
|
||||
//! Substitute diff operator
|
||||
//! Mark diff nodes to be substituted
|
||||
/*! The various nodes that are equivalent from a static point of view are
|
||||
grouped together in the “nodes” table, referenced by their maximum
|
||||
lag.
|
||||
TODO: This is technically wrong for complex expressions, and
|
||||
should be improved by grouping together only those nodes that are
|
||||
equivalent up to a shift in all leads/lags. */
|
||||
virtual void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const = 0;
|
||||
virtual void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const = 0;
|
||||
virtual int findTargetVariable(int lhs_symb_id) const = 0;
|
||||
//! Substitute diff operator
|
||||
virtual expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
|
||||
|
||||
//! Mark unary ops nodes to be substituted
|
||||
/*! The various nodes that are equivalent from a static point of view are
|
||||
grouped together in the “nodes” table, referenced by their maximum
|
||||
lag.
|
||||
TODO: This is technically wrong for complex expressions, and
|
||||
should be improved by grouping together only those nodes that are
|
||||
equivalent up to a shift in all leads/lags. */
|
||||
virtual void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const = 0;
|
||||
//! Substitute unary ops nodes
|
||||
virtual expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
|
||||
|
||||
//! Substitute pac_expectation operator
|
||||
virtual expr_t substitutePacExpectation(const string & name, expr_t subexpr) = 0;
|
||||
|
||||
virtual int findTargetVariable(int lhs_symb_id) const = 0;
|
||||
|
||||
//! Add ExprNodes to the provided datatree
|
||||
virtual expr_t clone(DataTree &datatree) const = 0;
|
||||
|
||||
|
|
Loading…
Reference in New Issue