Change prototype of DataTree::AddEqual()

This permits some simplifications.
issue#70
Sébastien Villemot 2020-04-02 19:10:59 +02:00
parent fb72472ee0
commit e88c05e3b8
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
4 changed files with 29 additions and 27 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright © 2003-2019 Dynare Team * Copyright © 2003-2020 Dynare Team
* *
* This file is part of Dynare. * This file is part of Dynare.
* *
@ -687,10 +687,12 @@ DataTree::AddPacExpectation(const string &model_name)
return p; return p;
} }
expr_t BinaryOpNode *
DataTree::AddEqual(expr_t iArg1, expr_t iArg2) DataTree::AddEqual(expr_t iArg1, expr_t iArg2)
{ {
return AddBinaryOp(iArg1, BinaryOpcode::equal, iArg2); /* We know that we can safely cast to BinaryOpNode because
BinaryOpCode::equal can never be reduced to a constant. */
return dynamic_cast<BinaryOpNode *>(AddBinaryOp(iArg1, BinaryOpcode::equal, iArg2));
} }
void void

View File

@ -1,5 +1,5 @@
/* /*
* Copyright © 2003-2019 Dynare Team * Copyright © 2003-2020 Dynare Team
* *
* This file is part of Dynare. * This file is part of Dynare.
* *
@ -244,7 +244,7 @@ public:
//! Add 2nd derivative of steady state w.r.t. parameter to model tree //! Add 2nd derivative of steady state w.r.t. parameter to model tree
expr_t AddSteadyStateParam2ndDeriv(expr_t iArg1, int param1_symb_id, int param2_symb_id); expr_t AddSteadyStateParam2ndDeriv(expr_t iArg1, int param1_symb_id, int param2_symb_id);
//! Adds "arg1=arg2" to model tree //! Adds "arg1=arg2" to model tree
expr_t AddEqual(expr_t iArg1, expr_t iArg2); BinaryOpNode *AddEqual(expr_t iArg1, expr_t iArg2);
//! Adds "var_expectation(model_name)" to model tree //! Adds "var_expectation(model_name)" to model tree
expr_t AddVarExpectation(const string &model_name); expr_t AddVarExpectation(const string &model_name);
//! Adds pac_expectation command to model tree //! Adds pac_expectation command to model tree

View File

@ -4621,9 +4621,9 @@ DynamicModel::addPacModelConsistentExpectationEquation(const string &name, int d
{ {
int symb_id = symbol_table.addDiffAuxiliaryVar(diff_node_to_search->idx, diff_node_to_search); int symb_id = symbol_table.addDiffAuxiliaryVar(diff_node_to_search->idx, diff_node_to_search);
target_base_diff_node = AddVariable(symb_id); target_base_diff_node = AddVariable(symb_id);
addEquation(dynamic_cast<BinaryOpNode *>(AddEqual(const_cast<VariableNode *>(target_base_diff_node), addEquation(AddEqual(const_cast<VariableNode *>(target_base_diff_node),
AddMinus(AddVariable(pac_target_symb_id), AddMinus(AddVariable(pac_target_symb_id),
AddVariable(pac_target_symb_id, -1)))), -1); AddVariable(pac_target_symb_id, -1))), -1);
neqs++; neqs++;
} }
@ -4635,8 +4635,8 @@ DynamicModel::addPacModelConsistentExpectationEquation(const string &name, int d
int symb_id = symbol_table.addDiffLeadAuxiliaryVar(this_diff_node->idx, this_diff_node, int symb_id = symbol_table.addDiffLeadAuxiliaryVar(this_diff_node->idx, this_diff_node,
last_aux_var->symb_id, last_aux_var->lag); last_aux_var->symb_id, last_aux_var->lag);
VariableNode *current_aux_var = AddVariable(symb_id); VariableNode *current_aux_var = AddVariable(symb_id);
addEquation(dynamic_cast<BinaryOpNode *>(AddEqual(current_aux_var, addEquation(AddEqual(current_aux_var,
AddVariable(last_aux_var->symb_id, 1))), -1); AddVariable(last_aux_var->symb_id, 1)), -1);
last_aux_var = current_aux_var; last_aux_var = current_aux_var;
target_aux_var_to_add[i] = current_aux_var; target_aux_var_to_add[i] = current_aux_var;
} }

View File

@ -252,7 +252,7 @@ ExprNode::createEndoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector
if (auto it = subst_table.find(orig_expr); it == subst_table.end()) if (auto it = subst_table.find(orig_expr); it == subst_table.end())
{ {
int symb_id = datatree.symbol_table.addEndoLeadAuxiliaryVar(orig_expr->idx, substexpr); int symb_id = datatree.symbol_table.addEndoLeadAuxiliaryVar(orig_expr->idx, substexpr);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr))); neweqs.push_back(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr));
substexpr = datatree.AddVariable(symb_id, +1); substexpr = datatree.AddVariable(symb_id, +1);
assert(dynamic_cast<VariableNode *>(substexpr)); assert(dynamic_cast<VariableNode *>(substexpr));
subst_table[orig_expr] = dynamic_cast<VariableNode *>(substexpr); subst_table[orig_expr] = dynamic_cast<VariableNode *>(substexpr);
@ -287,7 +287,7 @@ ExprNode::createExoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector<
if (auto it = subst_table.find(orig_expr); it == subst_table.end()) if (auto it = subst_table.find(orig_expr); it == subst_table.end())
{ {
int symb_id = datatree.symbol_table.addExoLeadAuxiliaryVar(orig_expr->idx, substexpr); int symb_id = datatree.symbol_table.addExoLeadAuxiliaryVar(orig_expr->idx, substexpr);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr))); neweqs.push_back(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr));
substexpr = datatree.AddVariable(symb_id, +1); substexpr = datatree.AddVariable(symb_id, +1);
assert(dynamic_cast<VariableNode *>(substexpr)); assert(dynamic_cast<VariableNode *>(substexpr));
subst_table[orig_expr] = dynamic_cast<VariableNode *>(substexpr); subst_table[orig_expr] = dynamic_cast<VariableNode *>(substexpr);
@ -1771,7 +1771,7 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector
if (auto it = subst_table.find(orig_expr); it == subst_table.end()) if (auto it = subst_table.find(orig_expr); it == subst_table.end())
{ {
int aux_symb_id = datatree.symbol_table.addEndoLagAuxiliaryVar(symb_id, cur_lag+1, substexpr); int aux_symb_id = datatree.symbol_table.addEndoLagAuxiliaryVar(symb_id, cur_lag+1, substexpr);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr))); neweqs.push_back(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr));
substexpr = datatree.AddVariable(aux_symb_id, -1); substexpr = datatree.AddVariable(aux_symb_id, -1);
subst_table[orig_expr] = substexpr; subst_table[orig_expr] = substexpr;
} }
@ -1837,7 +1837,7 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
if (auto it = subst_table.find(orig_expr); it == subst_table.end()) if (auto it = subst_table.find(orig_expr); it == subst_table.end())
{ {
int aux_symb_id = datatree.symbol_table.addExoLagAuxiliaryVar(symb_id, cur_lag+1, substexpr); int aux_symb_id = datatree.symbol_table.addExoLagAuxiliaryVar(symb_id, cur_lag+1, substexpr);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr))); neweqs.push_back(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr));
substexpr = datatree.AddVariable(aux_symb_id, -1); substexpr = datatree.AddVariable(aux_symb_id, -1);
subst_table[orig_expr] = substexpr; subst_table[orig_expr] = substexpr;
} }
@ -1884,8 +1884,8 @@ VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table
{ {
int aux_symb_id = datatree.symbol_table.addDiffForwardAuxiliaryVar(symb_id, datatree.AddMinus(datatree.AddVariable(symb_id, 0), int aux_symb_id = datatree.symbol_table.addDiffForwardAuxiliaryVar(symb_id, datatree.AddMinus(datatree.AddVariable(symb_id, 0),
datatree.AddVariable(symb_id, -1))); datatree.AddVariable(symb_id, -1)));
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), datatree.AddMinus(datatree.AddVariable(symb_id, 0), neweqs.push_back(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), datatree.AddMinus(datatree.AddVariable(symb_id, 0),
datatree.AddVariable(symb_id, -1))))); datatree.AddVariable(symb_id, -1))));
diffvar = datatree.AddVariable(aux_symb_id, 1); diffvar = datatree.AddVariable(aux_symb_id, 1);
subst_table[this] = diffvar; subst_table[this] = diffvar;
} }
@ -3518,9 +3518,9 @@ UnaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t
this part, folding into the next loop. */ this part, folding into the next loop. */
int symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, const_cast<UnaryOpNode *>(this)); int symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, const_cast<UnaryOpNode *>(this));
VariableNode *aux_var = datatree.AddVariable(symb_id, 0); VariableNode *aux_var = datatree.AddVariable(symb_id, 0);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, neweqs.push_back(datatree.AddEqual(aux_var,
datatree.AddMinus(argsubst, datatree.AddMinus(argsubst,
argsubst->decreaseLeadsLags(1))))); argsubst->decreaseLeadsLags(1))));
subst_table[this] = dynamic_cast<VariableNode *>(aux_var); subst_table[this] = dynamic_cast<VariableNode *>(aux_var);
return const_cast<VariableNode *>(subst_table[this]); return const_cast<VariableNode *>(subst_table[this]);
} }
@ -3548,9 +3548,9 @@ UnaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t
last_index = rit->first; last_index = rit->first;
last_aux_var = datatree.AddVariable(symb_id, 0); last_aux_var = datatree.AddVariable(symb_id, 0);
//ORIG_AUX_DIFF = argsubst - argsubst(-1) //ORIG_AUX_DIFF = argsubst - argsubst(-1)
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(last_aux_var, neweqs.push_back(datatree.AddEqual(last_aux_var,
datatree.AddMinus(argsubst, datatree.AddMinus(argsubst,
argsubst->decreaseLeadsLags(1))))); argsubst->decreaseLeadsLags(1))));
subst_table[rit->second] = dynamic_cast<VariableNode *>(last_aux_var); subst_table[rit->second] = dynamic_cast<VariableNode *>(last_aux_var);
} }
else else
@ -3567,8 +3567,8 @@ UnaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t
last_aux_var->symb_id, last_aux_var->lag); last_aux_var->symb_id, last_aux_var->lag);
new_aux_var = datatree.AddVariable(symb_id, 0); new_aux_var = datatree.AddVariable(symb_id, 0);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(new_aux_var, neweqs.push_back(datatree.AddEqual(new_aux_var,
last_aux_var->decreaseLeadsLags(1)))); last_aux_var->decreaseLeadsLags(1)));
last_aux_var = new_aux_var; last_aux_var = new_aux_var;
} }
subst_table[rit->second] = dynamic_cast<VariableNode *>(new_aux_var); subst_table[rit->second] = dynamic_cast<VariableNode *>(new_aux_var);
@ -3688,8 +3688,8 @@ UnaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_
symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), unary_op, symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), unary_op,
vn->symb_id, vn->lag); vn->symb_id, vn->lag);
aux_var = datatree.AddVariable(symb_id, 0); aux_var = datatree.AddVariable(symb_id, 0);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, neweqs.push_back(datatree.AddEqual(aux_var,
dynamic_cast<UnaryOpNode *>(rit->second)))); dynamic_cast<UnaryOpNode *>(rit->second)));
subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var); subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var);
base_index = rit->first; base_index = rit->first;
} }
@ -3794,7 +3794,7 @@ UnaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNo
//arg(lag-period) (holds entire subtree of arg(lag-period) //arg(lag-period) (holds entire subtree of arg(lag-period)
expr_t substexpr = (arg->substituteExpectation(subst_table, neweqs, partial_information_model))->decreaseLeadsLags(expectation_information_set); expr_t substexpr = (arg->substituteExpectation(subst_table, neweqs, partial_information_model))->decreaseLeadsLags(expectation_information_set);
assert(substexpr); assert(substexpr);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(newAuxE, substexpr))); //AUXE_period_arg.idx = arg(lag-period) neweqs.push_back(datatree.AddEqual(newAuxE, substexpr)); //AUXE_period_arg.idx = arg(lag-period)
newAuxE = datatree.AddVariable(symb_id, expectation_information_set); newAuxE = datatree.AddVariable(symb_id, expectation_information_set);
assert(dynamic_cast<VariableNode *>(newAuxE)); assert(dynamic_cast<VariableNode *>(newAuxE));