Fix bug with diff or unary ops that have same static representation

Previously, for testing whether two diff() expressions or two unary ops were
the lead/lag of each other, the preprocessor would test whether they have the
same static representation. This is ok for simple expressions (e.g.
diff(x(-1))), but not for more complex ones (e.g. diff(x-y) and diff(x(-1)-y)
should not be given the same auxiliary variable).

This commit fixes this by properly constructing the equivalence relationship
and choosing a representative within each equivalence class. See the comments
above lag_equivalence_table_t in ExprNode.hh for more details.

Closes #27
issue#70
Sébastien Villemot 2019-10-22 14:56:28 +02:00
parent c5d223a79b
commit 8a83e08e79
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
7 changed files with 208 additions and 201 deletions

View File

@ -5021,10 +5021,10 @@ VarExpectationModelStatement::VarExpectationModelStatement(string model_name_arg
}
void
VarExpectationModelStatement::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, ExprNode::subst_table_t &subst_table)
VarExpectationModelStatement::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table)
{
vector<BinaryOpNode *> neweqs;
expression = expression->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
expression = expression->substituteUnaryOpNodes(nodes, subst_table, neweqs);
if (neweqs.size() > 0)
{
cerr << "ERROR: the 'expression' option of var_expectation_model contains a variable with a unary operator that is not present in the VAR model" << endl;
@ -5033,10 +5033,10 @@ VarExpectationModelStatement::substituteUnaryOpNodes(DataTree &static_datatree,
}
void
VarExpectationModelStatement::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, ExprNode::subst_table_t &subst_table)
VarExpectationModelStatement::substituteDiff(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table)
{
vector<BinaryOpNode *> neweqs;
expression = expression->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
expression = expression->substituteDiff(nodes, subst_table, neweqs);
if (neweqs.size() > 0)
{
cerr << "ERROR: the 'expression' option of var_expectation_model contains a diff'd variable that is not present in the VAR model" << endl;

View File

@ -1204,8 +1204,8 @@ private:
public:
VarExpectationModelStatement(string model_name_arg, expr_t expression_arg, string aux_model_name_arg,
string horizon_arg, expr_t discount_arg, const SymbolTable &symbol_table_arg);
void substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, ExprNode::subst_table_t &subst_table);
void substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, ExprNode::subst_table_t &subst_table);
void substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table);
void substituteDiff(const lag_equivalence_table_t &diff_table, ExprNode::subst_table_t &subst_table);
// Analyzes the linear combination contained in the 'expression' option
/* Must be called after substituteUnaryOpNodes() and substituteDiff() (in
that order) */

View File

@ -6265,26 +6265,26 @@ DynamicModel::findPacExpectationEquationNumbers(vector<int> &eqnumbers) const
}
}
pair<diff_table_t, ExprNode::subst_table_t>
DynamicModel::substituteUnaryOps(StaticModel &static_model)
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
DynamicModel::substituteUnaryOps()
{
vector<int> eqnumbers(equations.size());
iota(eqnumbers.begin(), eqnumbers.end(), 0);
return substituteUnaryOps(static_model, eqnumbers);
return substituteUnaryOps(eqnumbers);
}
pair<diff_table_t, ExprNode::subst_table_t>
DynamicModel::substituteUnaryOps(StaticModel &static_model, const set<string> &var_model_eqtags)
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
DynamicModel::substituteUnaryOps(const set<string> &var_model_eqtags)
{
vector<int> eqnumbers = getEquationNumbersFromTags(var_model_eqtags);
findPacExpectationEquationNumbers(eqnumbers);
return substituteUnaryOps(static_model, eqnumbers);
return substituteUnaryOps(eqnumbers);
}
pair<diff_table_t, ExprNode::subst_table_t>
DynamicModel::substituteUnaryOps(StaticModel &static_model, const vector<int> &eqnumbers)
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
DynamicModel::substituteUnaryOps(const vector<int> &eqnumbers)
{
diff_table_t nodes;
lag_equivalence_table_t nodes;
ExprNode::subst_table_t subst_table;
// Mark unary ops to be substituted in model local variables that appear in selected equations
@ -6293,22 +6293,22 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, const vector<int> &e
equations[eqnumber]->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
for (auto & it : local_variables_table)
if (used_local_vars.find(it.first) != used_local_vars.end())
it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
it.second->findUnaryOpNodesForAuxVarCreation(nodes);
// Mark unary ops to be substituted in selected equations
for (int eqnumber : eqnumbers)
equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(nodes);
// Substitute in model local variables
vector<BinaryOpNode *> neweqs;
for (auto & it : local_variables_table)
it.second = it.second->substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs);
it.second = it.second->substituteUnaryOpNodes(nodes, subst_table, neweqs);
// Substitute in equations
for (auto & equation : equations)
{
auto *substeq = dynamic_cast<BinaryOpNode *>(equation->
substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs));
substituteUnaryOpNodes(nodes, subst_table, neweqs));
assert(substeq != nullptr);
equation = substeq;
}
@ -6325,15 +6325,15 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, const vector<int> &e
return { nodes, subst_table };
}
pair<diff_table_t, ExprNode::subst_table_t>
DynamicModel::substituteDiff(StaticModel &static_model, vector<expr_t> &pac_growth)
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
DynamicModel::substituteDiff(vector<expr_t> &pac_growth)
{
/* Note: at this point, we know that there is no diff operator with a lead,
because they have been expanded by DataTree::AddDiff().
Hence we can go forward with the substitution without worrying about the
expectation operator. */
diff_table_t diff_table;
lag_equivalence_table_t diff_nodes;
ExprNode::subst_table_t diff_subst_table;
// Mark diff operators to be substituted in model local variables
@ -6342,52 +6342,51 @@ DynamicModel::substituteDiff(StaticModel &static_model, vector<expr_t> &pac_grow
equation->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
for (auto & it : local_variables_table)
if (used_local_vars.find(it.first) != used_local_vars.end())
it.second->findDiffNodes(static_model, diff_table);
it.second->findDiffNodes(diff_nodes);
// Mark diff operators to be substituted in equations
for (const auto & equation : equations)
equation->findDiffNodes(static_model, diff_table);
equation->findDiffNodes(diff_nodes);
/* Ensure that all diff operators appear once with their argument at current
period (i.e. maxLag=0).
period (i.e. index 0 in the equivalence class, see comment above
lag_equivalence_table_t in ExprNode.hh for details on the concepts).
If it is not the case, generate the corresponding expressions.
This is necessary to avoid lags of more than one in the auxiliary
equation, which would then be modified by subsequent transformations
(removing lags > 1), which in turn would break the recursive ordering
of auxiliary equations. See issue McModelTeam/McModelProject#95 */
for (auto &it : diff_table)
for (auto &it : diff_nodes)
{
auto iterator_arg_max_lag = it.second.rbegin();
int arg_max_lag = iterator_arg_max_lag->first;
expr_t arg_max_expr = iterator_arg_max_lag->second;
auto iterator_max_index = it.second.rbegin();
int max_index = iterator_max_index->first;
expr_t max_index_expr = iterator_max_index->second;
/* We compare arg_max_lag with the result of countDiffs(), in order to
properly handle nested diffs. See issue McModelTeam/McModelProject#97 */
while (arg_max_lag < 1 - it.first->countDiffs())
while (max_index < 0)
{
arg_max_lag++;
arg_max_expr = arg_max_expr->decreaseLeadsLags(-1);
it.second[arg_max_lag] = arg_max_expr;
max_index++;
max_index_expr = max_index_expr->decreaseLeadsLags(-1);
it.second[max_index] = max_index_expr;
}
}
// Substitute in model local variables
vector<BinaryOpNode *> neweqs;
for (auto & it : local_variables_table)
it.second = it.second->substituteDiff(static_model, diff_table, diff_subst_table, neweqs);
it.second = it.second->substituteDiff(diff_nodes, diff_subst_table, neweqs);
// Substitute in equations
for (auto & equation : equations)
{
auto *substeq = dynamic_cast<BinaryOpNode *>(equation->
substituteDiff(static_model, diff_table, diff_subst_table, neweqs));
substituteDiff(diff_nodes, diff_subst_table, neweqs));
assert(substeq != nullptr);
equation = substeq;
}
for (auto & it : pac_growth)
if (it != nullptr)
it = it->substituteDiff(static_model, diff_table, diff_subst_table, neweqs);
it = it->substituteDiff(diff_nodes, diff_subst_table, neweqs);
// Add new equations
for (auto & neweq : neweqs)
@ -6398,7 +6397,7 @@ DynamicModel::substituteDiff(StaticModel &static_model, vector<expr_t> &pac_grow
if (diff_subst_table.size() > 0)
cout << "Substitution of Diff operator: added " << neweqs.size() << " auxiliary variables and equations." << endl;
return { diff_table, diff_subst_table };
return { diff_nodes, diff_subst_table };
}
void

View File

@ -438,16 +438,16 @@ public:
void substituteAdl();
//! Creates aux vars for all unary operators
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model);
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps();
//! Creates aux vars for unary operators in certain equations: originally implemented for support of VARs
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model, const set<string> &eq_tags);
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(const set<string> &eq_tags);
//! Creates aux vars for unary operators in certain equations: originally implemented for support of VARs
pair<diff_table_t, ExprNode::subst_table_t> substituteUnaryOps(StaticModel &static_model, const vector<int> &eqnumbers);
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(const vector<int> &eqnumbers);
//! Substitutes diff operator
pair<diff_table_t, ExprNode::subst_table_t> substituteDiff(StaticModel &static_model, vector<expr_t> &pac_growth);
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteDiff(vector<expr_t> &pac_growth);
//! Substitute VarExpectation operators
void substituteVarExpectation(const map<string, expr_t> &subst_table);

View File

@ -116,6 +116,17 @@ ExprNode::checkIfTemporaryTermThenWrite(ostream &output, ExprNodeOutputType outp
return true;
}
pair<expr_t, int>
ExprNode::getLagEquivalenceClass() const
{
int index = maxLead();
if (index == numeric_limits<int>::min())
index = 0; // If no variable in the expression, the equivalence class has size 1
return { decreaseLeadsLags(index), index };
}
void
ExprNode::collectVariables(SymbolType type, set<int> &result) const
{
@ -660,12 +671,12 @@ NumConstNode::substituteVarExpectation(const map<string, expr_t> &subst_table) c
}
void
NumConstNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
NumConstNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
}
void
NumConstNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
NumConstNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
}
@ -676,13 +687,13 @@ NumConstNode::findTargetVariable(int lhs_symb_id) const
}
expr_t
NumConstNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
NumConstNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
return const_cast<NumConstNode *>(this);
}
expr_t
NumConstNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
NumConstNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
return const_cast<NumConstNode *>(this);
}
@ -1662,12 +1673,12 @@ VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) c
}
void
VariableNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
}
void
VariableNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
}
@ -1678,14 +1689,14 @@ VariableNode::findTargetVariable(int lhs_symb_id) const
}
expr_t
VariableNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
return const_cast<VariableNode *>(this);
}
expr_t
VariableNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
return const_cast<VariableNode *>(this);
}
@ -3451,49 +3462,27 @@ UnaryOpNode::createAuxVarForUnaryOpNode() const
}
void
UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
UnaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
arg->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
arg->findUnaryOpNodesForAuxVarCreation(nodes);
if (!this->createAuxVarForUnaryOpNode())
return;
expr_t sthis = this->toStatic(static_datatree);
int arg_max_lag = -arg->maxLagWithDiffsExpanded();
// TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes
auto it = nodes.find(sthis);
if (it != nodes.end())
{
for (const auto &it1 : it->second)
if (arg == it1.second)
return;
it->second[arg_max_lag] = const_cast<UnaryOpNode *>(this);
}
else
nodes[sthis][arg_max_lag] = const_cast<UnaryOpNode *>(this);
auto [lag_equiv_repr, index] = getLagEquivalenceClass();
nodes[lag_equiv_repr][index] = const_cast<UnaryOpNode *>(this);
}
void
UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
UnaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
arg->findDiffNodes(static_datatree, diff_table);
arg->findDiffNodes(nodes);
if (op_code != UnaryOpcode::diff)
return;
expr_t sthis = this->toStatic(static_datatree);
int arg_max_lag = -arg->maxLagWithDiffsExpanded();
// TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes
auto it = diff_table.find(sthis);
if (it != diff_table.end())
{
for (const auto &it1 : it->second)
if (arg == it1.second)
return;
it->second[arg_max_lag] = const_cast<UnaryOpNode *>(this);
}
else
diff_table[sthis][arg_max_lag] = const_cast<UnaryOpNode *>(this);
auto [lag_equiv_repr, index] = getLagEquivalenceClass();
nodes[lag_equiv_repr][index] = const_cast<UnaryOpNode *>(this);
}
int
@ -3503,11 +3492,11 @@ UnaryOpNode::findTargetVariable(int lhs_symb_id) const
}
expr_t
UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
UnaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
// If this is not a diff node, then substitute recursively and return
expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
expr_t argsubst = arg->substituteDiff(nodes, subst_table, neweqs);
if (op_code != UnaryOpcode::diff)
return buildSimilarUnaryOpNode(argsubst, datatree);
@ -3515,9 +3504,10 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
if (sit != subst_table.end())
return const_cast<VariableNode *>(sit->second);
expr_t sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
auto it = diff_table.find(sthis);
if (it == diff_table.end() || it->second[-arg->maxLagWithDiffsExpanded()] != this)
auto [lag_equiv_repr, index] = getLagEquivalenceClass();
auto it = nodes.find(lag_equiv_repr);
if (it == nodes.end() || it->second.find(index) == it->second.end()
|| it->second.at(index) != this)
{
/* diff does not appear in VAR equations, so simply create aux var and return.
Once the comparison of expression nodes works, come back and remove
@ -3535,12 +3525,12 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
must be substituted. We create the auxiliary variable and fill the
substitution table for all those similar nodes, in an iteration going from
leads to lags. */
int last_arg_max_lag = 0;
int last_index = 0;
VariableNode *last_aux_var = nullptr;
for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
{
expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)->
arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
arg->substituteDiff(nodes, subst_table, neweqs);
auto *vn = dynamic_cast<VariableNode *>(argsubst);
int symb_id;
if (rit == it->second.rbegin())
@ -3551,7 +3541,7 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst);
// make originating aux var & equation
last_arg_max_lag = rit->first;
last_index = rit->first;
last_aux_var = datatree.AddVariable(symb_id, 0);
//ORIG_AUX_DIFF = argsubst - argsubst(-1)
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(last_aux_var,
@ -3563,9 +3553,9 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
{
// just add equation of form: AUX_DIFF = LAST_AUX_VAR(-1)
VariableNode *new_aux_var = nullptr;
for (int i = last_arg_max_lag; i > rit->first; i--)
for (int i = last_index; i > rit->first; i--)
{
if (i == last_arg_max_lag)
if (i == last_index)
symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(argsubst->idx, argsubst,
last_aux_var->symb_id, last_aux_var->lag);
else
@ -3578,24 +3568,24 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
last_aux_var = new_aux_var;
}
subst_table[rit->second] = dynamic_cast<VariableNode *>(new_aux_var);
last_arg_max_lag = rit->first;
last_index = rit->first;
}
}
return const_cast<VariableNode *>(subst_table[this]);
}
expr_t
UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
UnaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
auto sit = subst_table.find(this);
if (sit != subst_table.end())
return const_cast<VariableNode *>(sit->second);
/* If (the static equivalent of) this node is not marked for substitution,
/* If the equivalence class of this node is not marked for substitution,
then substitute recursively and return. */
auto *sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
auto it = nodes.find(sthis);
expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
auto [lag_equiv_repr, index] = getLagEquivalenceClass();
auto it = nodes.find(lag_equiv_repr);
expr_t argsubst = arg->substituteUnaryOpNodes(nodes, subst_table, neweqs);
if (it == nodes.end())
return buildSimilarUnaryOpNode(argsubst, datatree);
@ -3671,7 +3661,7 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
must be substituted. We create the auxiliary variable and fill the
substitution table for all those similar nodes, in an iteration going from
leads to lags. */
int base_aux_lag = 0;
int base_index = 0;
VariableNode *aux_var = nullptr;
for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
if (rit == it->second.rbegin())
@ -3697,10 +3687,10 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
dynamic_cast<UnaryOpNode *>(rit->second))));
subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var);
base_aux_lag = rit->first;
base_index = rit->first;
}
else
subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(base_aux_lag - rit->first));
subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(base_index - rit->first));
sit = subst_table.find(this);
return const_cast<VariableNode *>(sit->second);
@ -5455,33 +5445,33 @@ BinaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) c
}
void
BinaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
BinaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
arg1->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
arg2->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
arg1->findUnaryOpNodesForAuxVarCreation(nodes);
arg2->findUnaryOpNodesForAuxVarCreation(nodes);
}
void
BinaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
BinaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
arg1->findDiffNodes(static_datatree, diff_table);
arg2->findDiffNodes(static_datatree, diff_table);
arg1->findDiffNodes(nodes);
arg2->findDiffNodes(nodes);
}
expr_t
BinaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
BinaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
expr_t arg1subst = arg1->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
expr_t arg2subst = arg2->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
expr_t arg1subst = arg1->substituteDiff(nodes, subst_table, neweqs);
expr_t arg2subst = arg2->substituteDiff(nodes, subst_table, neweqs);
return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
}
expr_t
BinaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
BinaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
expr_t arg1subst = arg1->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
expr_t arg2subst = arg2->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
expr_t arg1subst = arg1->substituteUnaryOpNodes(nodes, subst_table, neweqs);
expr_t arg2subst = arg2->substituteUnaryOpNodes(nodes, subst_table, neweqs);
return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
}
@ -6770,19 +6760,19 @@ TrinaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table)
}
void
TrinaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
TrinaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
arg1->findDiffNodes(static_datatree, diff_table);
arg2->findDiffNodes(static_datatree, diff_table);
arg3->findDiffNodes(static_datatree, diff_table);
arg1->findDiffNodes(nodes);
arg2->findDiffNodes(nodes);
arg3->findDiffNodes(nodes);
}
void
TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
arg1->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
arg2->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
arg3->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
arg1->findUnaryOpNodesForAuxVarCreation(nodes);
arg2->findUnaryOpNodesForAuxVarCreation(nodes);
arg3->findUnaryOpNodesForAuxVarCreation(nodes);
}
int
@ -6797,21 +6787,21 @@ TrinaryOpNode::findTargetVariable(int lhs_symb_id) const
}
expr_t
TrinaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
TrinaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
expr_t arg1subst = arg1->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
expr_t arg2subst = arg2->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
expr_t arg3subst = arg3->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
expr_t arg1subst = arg1->substituteDiff(nodes, subst_table, neweqs);
expr_t arg2subst = arg2->substituteDiff(nodes, subst_table, neweqs);
expr_t arg3subst = arg3->substituteDiff(nodes, subst_table, neweqs);
return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
}
expr_t
TrinaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
TrinaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
expr_t arg1subst = arg1->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
expr_t arg2subst = arg2->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
expr_t arg3subst = arg3->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
expr_t arg1subst = arg1->substituteUnaryOpNodes(nodes, subst_table, neweqs);
expr_t arg2subst = arg2->substituteUnaryOpNodes(nodes, subst_table, neweqs);
expr_t arg3subst = arg3->substituteUnaryOpNodes(nodes, subst_table, neweqs);
return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
}
@ -7241,17 +7231,17 @@ AbstractExternalFunctionNode::substituteVarExpectation(const map<string, expr_t>
}
void
AbstractExternalFunctionNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
AbstractExternalFunctionNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
for (auto argument : arguments)
argument->findDiffNodes(static_datatree, diff_table);
argument->findDiffNodes(nodes);
}
void
AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
for (auto argument : arguments)
argument->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
argument->findUnaryOpNodesForAuxVarCreation(nodes);
}
int
@ -7267,21 +7257,21 @@ AbstractExternalFunctionNode::findTargetVariable(int lhs_symb_id) const
}
expr_t
AbstractExternalFunctionNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
AbstractExternalFunctionNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
vector<expr_t> arguments_subst;
for (auto argument : arguments)
arguments_subst.push_back(argument->substituteDiff(static_datatree, diff_table, subst_table, neweqs));
arguments_subst.push_back(argument->substituteDiff(nodes, subst_table, neweqs));
return buildSimilarExternalFunctionNode(arguments_subst, datatree);
}
expr_t
AbstractExternalFunctionNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
AbstractExternalFunctionNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
vector<expr_t> arguments_subst;
for (auto argument : arguments)
arguments_subst.push_back(argument->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs));
arguments_subst.push_back(argument->substituteUnaryOpNodes(nodes, subst_table, neweqs));
return buildSimilarExternalFunctionNode(arguments_subst, datatree);
}
@ -8926,12 +8916,12 @@ VarExpectationNode::substituteVarExpectation(const map<string, expr_t> &subst_ta
}
void
VarExpectationNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
VarExpectationNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
}
void
VarExpectationNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
VarExpectationNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
}
@ -8942,14 +8932,14 @@ VarExpectationNode::findTargetVariable(int lhs_symb_id) const
}
expr_t
VarExpectationNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
VarExpectationNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
return const_cast<VarExpectationNode *>(this);
}
expr_t
VarExpectationNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
VarExpectationNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
return const_cast<VarExpectationNode *>(this);
}
@ -9349,12 +9339,12 @@ PacExpectationNode::substituteVarExpectation(const map<string, expr_t> &subst_ta
}
void
PacExpectationNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
PacExpectationNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
}
void
PacExpectationNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
PacExpectationNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
}
@ -9365,14 +9355,14 @@ PacExpectationNode::findTargetVariable(int lhs_symb_id) const
}
expr_t
PacExpectationNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
PacExpectationNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
return const_cast<PacExpectationNode *>(this);
}
expr_t
PacExpectationNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
PacExpectationNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
return const_cast<PacExpectationNode *>(this);
}

View File

@ -63,9 +63,27 @@ using eval_context_t = map<int, double>;
//! Type for tracking first/second derivative functions that have already been written as temporary terms
using deriv_node_temp_terms_t = map<pair<int, vector<expr_t>>, int>;
//! Type for the substitution map used in the process of substitutitng diff expressions
//! diff_table[static_expr_t][lag] -> [dynamic_expr_t]
using diff_table_t = map<expr_t, map<int, expr_t>>;
//! Type for the substitution map used for creating aux. vars for diff and unary_ops
/*! Let ≅ be the equivalence relationship such that two expressions e₁ and e₂
are equivalent iff e can be obtained from e by shifting all leads/lags by
the same number of periods (e.g. x+yx+y).
For each equivalence class, we select a representative element, which is
the class member which has no lead and a variable appearing at current
period (in the previous example, it would be x+y). (Obviously, if there
is no variable in the expression, then there is only one element in the
class, and that one is the representative)
Each member of an equivalence class is represented by an integer,
corresponding to its distance to the representative element (e.g. x+y
has index 2 and x+y has index 4). The representative element has index 0
by definition.
The keys in the std::map are the representative elements of the various
equivalence classes. The values are themselves std::map that describe the
equivalence class: they associate indices of class members to the
expressions with which they should be substituted. */
using lag_equivalence_table_t = map<expr_t, map<int, expr_t>>;
//! Possible types of output when writing ExprNode(s)
enum class ExprNodeOutputType
@ -210,6 +228,11 @@ class ExprNode
// Internal helper for matchVariableTimesConstantTimesParam()
virtual void matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const;
/* Computes the representative element and the index under the
lag-equivalence relationship. See the comment above
lag_equivalence_table_t for an explanation of these concepts. */
pair<expr_t, int> getLagEquivalenceClass() const;
public:
ExprNode(DataTree &datatree_arg, int idx_arg);
virtual ~ExprNode() = default;
@ -561,26 +584,20 @@ class ExprNode
virtual expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const = 0;
//! Mark diff nodes to be substituted
/*! The various nodes that are equivalent from a static point of view are
grouped together in the nodes table, referenced by their maximum
lag.
TODO: This is technically wrong for complex expressions, and
should be improved by grouping together only those nodes that are
equivalent up to a shift in all leads/lags. */
virtual void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const = 0;
/*! The various nodes that are equivalent up to a shift of leads/lags are
grouped together in the nodes table. See the comment above
lag_equivalence_table_t for more details. */
virtual void findDiffNodes(lag_equivalence_table_t &nodes) const = 0;
//! Substitute diff operator
virtual expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
virtual expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
//! Mark unary ops nodes to be substituted
/*! The various nodes that are equivalent from a static point of view are
grouped together in the nodes table, referenced by their maximum
lag.
TODO: This is technically wrong for complex expressions, and
should be improved by grouping together only those nodes that are
equivalent up to a shift in all leads/lags. */
virtual void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const = 0;
/*! The various nodes that are equivalent up to a shift of leads/lags are
grouped together in the nodes table. See the comment above
lag_equivalence_table_t for more details. */
virtual void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const = 0;
//! Substitute unary ops nodes
virtual expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
virtual expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
//! Substitute pac_expectation operator
virtual expr_t substitutePacExpectation(const string & name, expr_t subexpr) = 0;
@ -725,11 +742,11 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string & name, expr_t subexpr) override;
expr_t decreaseLeadsLagsPredeterminedVariables() const override;
expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
@ -808,11 +825,11 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string & name, expr_t subexpr) override;
expr_t decreaseLeadsLagsPredeterminedVariables() const override;
expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
@ -918,12 +935,12 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
bool createAuxVarForUnaryOpNode() const;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string & name, expr_t subexpr) override;
expr_t decreaseLeadsLagsPredeterminedVariables() const override;
expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
@ -1031,13 +1048,13 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
bool findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const;
int findTargetVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string & name, expr_t subexpr) override;
expr_t decreaseLeadsLagsPredeterminedVariables() const override;
expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
@ -1161,11 +1178,11 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string & name, expr_t subexpr) override;
expr_t decreaseLeadsLagsPredeterminedVariables() const override;
expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
@ -1279,11 +1296,11 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string & name, expr_t subexpr) override;
virtual expr_t buildSimilarExternalFunctionNode(vector<expr_t> &alt_args, DataTree &alt_datatree) const = 0;
expr_t decreaseLeadsLagsPredeterminedVariables() const override;
@ -1482,11 +1499,11 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string & name, expr_t subexpr) override;
pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
void compile(ostream &CompileCode, unsigned int &instruction_number,
@ -1563,11 +1580,11 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string & name, expr_t subexpr) override;
pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
void compile(ostream &CompileCode, unsigned int &instruction_number,

View File

@ -440,16 +440,17 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
for (auto & it1 : it.second)
eqtags.insert(it1);
diff_table_t unary_ops_nodes;
// Create auxiliary variables and equations for unary ops
lag_equivalence_table_t unary_ops_nodes;
ExprNode::subst_table_t unary_ops_subst_table;
if (transform_unary_ops)
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(diff_static_model);
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps();
else
// substitute only those unary ops that appear in auxiliary model equations
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(diff_static_model, eqtags);
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(eqtags);
// Create auxiliary variable and equations for Diff operators
auto [diff_table, diff_subst_table] = dynamic_model.substituteDiff(diff_static_model, pac_growth);
auto [diff_nodes, diff_subst_table] = dynamic_model.substituteDiff(pac_growth);
// Fill Trend Component Model Table
dynamic_model.fillTrendComponentModelTable();
@ -587,8 +588,8 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
/* Substitute unary and diff operators in the 'expression' option, then
match the linear combination in the expression option */
vems->substituteUnaryOpNodes(diff_static_model, unary_ops_nodes, unary_ops_subst_table);
vems->substituteDiff(diff_static_model, diff_table, diff_subst_table);
vems->substituteUnaryOpNodes(unary_ops_nodes, unary_ops_subst_table);
vems->substituteDiff(diff_nodes, diff_subst_table);
vems->matchExpression();
/* Create auxiliary parameters and the expression to be substituted into