Add more comments in routines for aux vars for unary ops / diff operators

By the way, do some small code simplifications.
issue#70
Sébastien Villemot 2019-08-19 18:22:55 +02:00
parent 8e9f6e4c57
commit 3941278832
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
4 changed files with 49 additions and 27 deletions

View File

@ -6267,16 +6267,15 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, const vector<int> &e
diff_table_t nodes; diff_table_t nodes;
ExprNode::subst_table_t subst_table; ExprNode::subst_table_t subst_table;
// Find matching unary ops that may be outside of diffs (i.e., those with different lags) // Mark unary ops to be substituted in model local variables that appear in selected equations
set<int> used_local_vars; set<int> used_local_vars;
for (int eqnumber : eqnumbers) for (int eqnumber : eqnumbers)
equations[eqnumber]->collectVariables(SymbolType::modelLocalVariable, used_local_vars); equations[eqnumber]->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
// Only substitute unary ops in model local variables that appear in VAR equations
for (auto & it : local_variables_table) for (auto & it : local_variables_table)
if (used_local_vars.find(it.first) != used_local_vars.end()) if (used_local_vars.find(it.first) != used_local_vars.end())
it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes); it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
// Mark unary ops to be substituted in selected equations
for (int eqnumber : eqnumbers) for (int eqnumber : eqnumbers)
equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes); equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);

View File

@ -445,10 +445,10 @@ public:
//! Creates aux vars for all unary operators //! Creates aux vars for all unary operators
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model); pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model);
//! Creates aux vars for certain unary operators: originally implemented for support of VARs //! Creates aux vars for unary operators in certain equations: originally implemented for support of VARs
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model, const set<string> &eq_tags); pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model, const set<string> &eq_tags);
//! Creates aux vars for certain unary operators: originally implemented for support of VARs //! Creates aux vars for unary operators in certain equations: originally implemented for support of VARs
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model, const vector<int> &eqnumbers); pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model, const vector<int> &eqnumbers);
//! Substitutes diff operator //! Substitutes diff operator

View File

@ -3464,9 +3464,8 @@ UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_t
auto it = nodes.find(sthis); auto it = nodes.find(sthis);
if (it != nodes.end()) if (it != nodes.end())
{ {
for (map<int, expr_t>::const_iterator it1 = it->second.begin(); for (const auto &it1 : it->second)
it1 != it->second.end(); it1++) if (arg == it1.second)
if (arg == it1->second)
return; return;
it->second[arg_max_lag] = const_cast<UnaryOpNode *>(this); it->second[arg_max_lag] = const_cast<UnaryOpNode *>(this);
} }
@ -3488,9 +3487,8 @@ UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table)
auto it = diff_table.find(sthis); auto it = diff_table.find(sthis);
if (it != diff_table.end()) if (it != diff_table.end())
{ {
for (map<int, expr_t>::const_iterator it1 = it->second.begin(); for (const auto &it1 : it->second)
it1 != it->second.end(); it1++) if (arg == it1.second)
if (arg == it1->second)
return; return;
it->second[arg_max_lag] = const_cast<UnaryOpNode *>(this); it->second[arg_max_lag] = const_cast<UnaryOpNode *>(this);
} }
@ -3508,11 +3506,12 @@ expr_t
UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const vector<BinaryOpNode *> &neweqs) const
{ {
// If this is not a diff node, then substitute recursively and return
expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs); expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
if (op_code != UnaryOpcode::diff) if (op_code != UnaryOpcode::diff)
return buildSimilarUnaryOpNode(argsubst, datatree); return buildSimilarUnaryOpNode(argsubst, datatree);
subst_table_t::const_iterator sit = subst_table.find(this); auto sit = subst_table.find(this);
if (sit != subst_table.end()) if (sit != subst_table.end())
return const_cast<VariableNode *>(sit->second); return const_cast<VariableNode *>(sit->second);
@ -3521,9 +3520,9 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
int symb_id; int symb_id;
if (it == diff_table.end() || it->second[-arg->maxLagWithDiffsExpanded()] != this) if (it == diff_table.end() || it->second[-arg->maxLagWithDiffsExpanded()] != this)
{ {
// diff does not appear in VAR equations /* diff does not appear in VAR equations, so simply create aux var and return.
// so simply create aux var and return Once the comparison of expression nodes works, come back and remove
// Once the comparison of expression nodes works, come back and remove this part, folding into the next loop. this part, folding into the next loop. */
symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst); symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst);
VariableNode *aux_var = datatree.AddVariable(symb_id, 0); VariableNode *aux_var = datatree.AddVariable(symb_id, 0);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
@ -3533,10 +3532,13 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
return const_cast<VariableNode *>(subst_table[this]); return const_cast<VariableNode *>(subst_table[this]);
} }
/* At this point, we know that this node (and its lagged/leaded brothers)
must be substituted. We create the auxiliary variable and fill the
substitution table for all those similar nodes, in an iteration going from
leads to lags. */
int last_arg_max_lag = 0; int last_arg_max_lag = 0;
VariableNode *last_aux_var = nullptr; VariableNode *last_aux_var = nullptr;
for (auto rit = it->second.rbegin(); for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
rit != it->second.rend(); rit++)
{ {
expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)-> expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)->
arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs); arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
@ -3585,10 +3587,12 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
expr_t expr_t
UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{ {
subst_table_t::const_iterator sit = subst_table.find(this); auto sit = subst_table.find(this);
if (sit != subst_table.end()) if (sit != subst_table.end())
return const_cast<VariableNode *>(sit->second); return const_cast<VariableNode *>(sit->second);
/* If (the static equivalent of) this node is not marked for substitution,
then substitute recursively and return. */
auto *sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree)); auto *sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
auto it = nodes.find(sthis); auto it = nodes.find(sthis);
expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs); expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
@ -3601,7 +3605,7 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
string unary_op = ""; string unary_op;
switch (op_code) switch (op_code)
{ {
case UnaryOpcode::exp: case UnaryOpcode::exp:
@ -3665,15 +3669,17 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
unary_op = "erf"; unary_op = "erf";
break; break;
default: default:
{ cerr << "UnaryOpNode::substituteUnaryOpNodes: Shouldn't arrive here" << endl;
cerr << "UnaryOpNode::substituteUnaryOpNodes: Shouldn't arrive here" << endl; exit(EXIT_FAILURE);
exit(EXIT_FAILURE);
}
} }
/* At this point, we know that this node (and its lagged/leaded brothers)
must be substituted. We create the auxiliary variable and fill the
substitution table for all those similar nodes, in an iteration going from
leads to lags. */
int base_aux_lag = 0; int base_aux_lag = 0;
VariableNode *aux_var = nullptr; VariableNode *aux_var = nullptr;
for (auto rit = it->second.rbegin(); rit != it->second.rend(); rit++) for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
if (rit == it->second.rbegin()) if (rit == it->second.rbegin())
{ {
int symb_id; int symb_id;

View File

@ -558,16 +558,33 @@ class ExprNode
//! Substitute VarExpectation nodes //! Substitute VarExpectation nodes
virtual expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const = 0; virtual expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const = 0;
//! Substitute diff operator //! Mark diff nodes to be substituted
/*! The various nodes that are equivalent from a static point of view are
grouped together in the nodes table, referenced by their maximum
lag.
TODO: This is technically wrong for complex expressions, and
should be improved by grouping together only those nodes that are
equivalent up to a shift in all leads/lags. */
virtual void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const = 0; virtual void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const = 0;
virtual void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const = 0; //! Substitute diff operator
virtual int findTargetVariable(int lhs_symb_id) const = 0;
virtual expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0; virtual expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
//! Mark unary ops nodes to be substituted
/*! The various nodes that are equivalent from a static point of view are
grouped together in the nodes table, referenced by their maximum
lag.
TODO: This is technically wrong for complex expressions, and
should be improved by grouping together only those nodes that are
equivalent up to a shift in all leads/lags. */
virtual void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const = 0;
//! Substitute unary ops nodes
virtual expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0; virtual expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
//! Substitute pac_expectation operator //! Substitute pac_expectation operator
virtual expr_t substitutePacExpectation(const string & name, expr_t subexpr) = 0; virtual expr_t substitutePacExpectation(const string & name, expr_t subexpr) = 0;
virtual int findTargetVariable(int lhs_symb_id) const = 0;
//! Add ExprNodes to the provided datatree //! Add ExprNodes to the provided datatree
virtual expr_t clone(DataTree &datatree) const = 0; virtual expr_t clone(DataTree &datatree) const = 0;