No longer store symbol type in VariableNode

This facilitates switching variable types on the fly. In particular, this
allows removing the hack in DynamicModel::updateAfterVariableChange() that way
basically recreating all the nodes after the type change.
issue#70
Sébastien Villemot 2018-10-09 17:50:04 +02:00
parent c47b6e6e4c
commit 215283005e
5 changed files with 43 additions and 67 deletions

View File

@ -4827,21 +4827,6 @@ DynamicModel::writeAuxVarRecursiveDefinitions(ostream &output, ExprNodeOutputTyp
} }
} }
void
DynamicModel::updateAfterVariableChange(DynamicModel &dm)
{
variable_node_map.clear();
unary_op_node_map.clear();
binary_op_node_map.clear();
trinary_op_node_map.clear();
external_function_node_map.clear();
first_deriv_external_function_node_map.clear();
second_deriv_external_function_node_map.clear();
cloneDynamic(dm);
dm.replaceMyEquations(*this);
}
void void
DynamicModel::cloneDynamic(DynamicModel &dynamic_model) const DynamicModel::cloneDynamic(DynamicModel &dynamic_model) const
{ {

View File

@ -378,9 +378,6 @@ public:
/*! It assumes that the dynamic model given in argument has just been allocated */ /*! It assumes that the dynamic model given in argument has just been allocated */
void cloneDynamic(DynamicModel &dynamic_model) const; void cloneDynamic(DynamicModel &dynamic_model) const;
//! update equations after variable type change in model block
void updateAfterVariableChange(DynamicModel &dynamic_model);
//! Replaces model equations with derivatives of Lagrangian w.r.t. endogenous //! Replaces model equations with derivatives of Lagrangian w.r.t. endogenous
void computeRamseyPolicyFOCs(const StaticModel &static_model, const bool nopreprocessoroutput); void computeRamseyPolicyFOCs(const StaticModel &static_model, const bool nopreprocessoroutput);
//! Replaces the model equations in dynamic_model with those in this model //! Replaces the model equations in dynamic_model with those in this model

View File

@ -714,12 +714,11 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) : VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) :
ExprNode{datatree_arg, idx_arg}, ExprNode{datatree_arg, idx_arg},
symb_id{symb_id_arg}, symb_id{symb_id_arg},
type{datatree.symbol_table.getType(symb_id_arg)},
lag{lag_arg} lag{lag_arg}
{ {
// It makes sense to allow a lead/lag on parameters: during steady state calibration, endogenous and parameters can be swapped // It makes sense to allow a lead/lag on parameters: during steady state calibration, endogenous and parameters can be swapped
assert(type != SymbolType::externalFunction assert(get_type() != SymbolType::externalFunction
&& (lag == 0 || (type != SymbolType::modelLocalVariable && type != SymbolType::modFileLocalVariable))); && (lag == 0 || (get_type() != SymbolType::modelLocalVariable && get_type() != SymbolType::modFileLocalVariable)));
} }
void void
@ -731,7 +730,7 @@ VariableNode::prepareForDerivation()
preparedForDerivation = true; preparedForDerivation = true;
// Fill in non_null_derivatives // Fill in non_null_derivatives
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
case SymbolType::exogenous: case SymbolType::exogenous:
@ -765,7 +764,7 @@ VariableNode::prepareForDerivation()
expr_t expr_t
VariableNode::computeDerivative(int deriv_id) VariableNode::computeDerivative(int deriv_id)
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
case SymbolType::exogenous: case SymbolType::exogenous:
@ -806,7 +805,7 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t
auto it = temporary_terms.find(const_cast<VariableNode *>(this)); auto it = temporary_terms.find(const_cast<VariableNode *>(this));
if (it != temporary_terms.end()) if (it != temporary_terms.end())
temporary_terms_inuse.insert(idx); temporary_terms_inuse.insert(idx);
if (type == SymbolType::modelLocalVariable) if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block); datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
} }
@ -821,7 +820,7 @@ VariableNode::writeJsonAST(ostream &output) const
{ {
output << "{\"node_type\" : \"VariableNode\", " output << "{\"node_type\" : \"VariableNode\", "
<< "\"name\" : \"" << datatree.symbol_table.getName(symb_id) << "\", \"type\" : \""; << "\"name\" : \"" << datatree.symbol_table.getName(symb_id) << "\", \"type\" : \"";
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
output << "endogenous"; output << "endogenous";
@ -896,6 +895,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_idxs_t &temporary_terms_idxs, const temporary_terms_idxs_t &temporary_terms_idxs,
const deriv_node_temp_terms_t &tef_terms) const const deriv_node_temp_terms_t &tef_terms) const
{ {
auto type = get_type();
if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs)) if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
return; return;
@ -1150,7 +1150,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
expr_t expr_t
VariableNode::substituteStaticAuxiliaryVariable() const VariableNode::substituteStaticAuxiliaryVariable() const
{ {
if (type == SymbolType::endogenous) if (get_type() == SymbolType::endogenous)
{ {
try try
{ {
@ -1179,6 +1179,7 @@ VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number,
const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
const deriv_node_temp_terms_t &tef_terms) const const deriv_node_temp_terms_t &tef_terms) const
{ {
auto type = get_type();
if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable) if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms); datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
else else
@ -1255,14 +1256,14 @@ VariableNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
vector<vector<temporary_terms_t>> &v_temporary_terms, vector<vector<temporary_terms_t>> &v_temporary_terms,
int equation) const int equation) const
{ {
if (type == SymbolType::modelLocalVariable) if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation); datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
} }
void void
VariableNode::collectVARLHSVariable(set<expr_t> &result) const VariableNode::collectVARLHSVariable(set<expr_t> &result) const
{ {
if (type == SymbolType::endogenous && lag == 0) if (get_type() == SymbolType::endogenous && lag == 0)
result.insert(const_cast<VariableNode *>(this)); result.insert(const_cast<VariableNode *>(this));
else else
{ {
@ -1274,9 +1275,9 @@ VariableNode::collectVARLHSVariable(set<expr_t> &result) const
void void
VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
{ {
if (type == type_arg) if (get_type() == type_arg)
result.emplace(symb_id, lag); result.emplace(symb_id, lag);
if (type == SymbolType::modelLocalVariable) if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result); datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
} }
@ -1295,7 +1296,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
the flag is equal to 2. the flag is equal to 2.
- an expression equal to the RHS if flag = 0 and equal to NULL elsewhere - an expression equal to the RHS if flag = 0 and equal to NULL elsewhere
*/ */
if (type == SymbolType::endogenous) if (get_type() == SymbolType::endogenous)
{ {
if (datatree.symbol_table.getTypeSpecificID(symb_id) == var_endo && lag == 0) if (datatree.symbol_table.getTypeSpecificID(symb_id) == var_endo && lag == 0)
/* the endogenous variable */ /* the endogenous variable */
@ -1305,7 +1306,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
} }
else else
{ {
if (type == SymbolType::parameter) if (get_type() == SymbolType::parameter)
return { 0, datatree.AddVariable(symb_id, 0) }; return { 0, datatree.AddVariable(symb_id, 0) };
else else
return { 0, datatree.AddVariable(symb_id, lag) }; return { 0, datatree.AddVariable(symb_id, lag) };
@ -1315,7 +1316,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
expr_t expr_t
VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
case SymbolType::exogenous: case SymbolType::exogenous:
@ -1380,7 +1381,7 @@ VariableNode::toStatic(DataTree &static_datatree) const
void void
VariableNode::computeXrefs(EquationInfo &ei) const VariableNode::computeXrefs(EquationInfo &ei) const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
ei.endo.emplace(symb_id, lag); ei.endo.emplace(symb_id, lag);
@ -1409,6 +1410,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const
} }
} }
SymbolType
VariableNode::get_type() const
{
return datatree.symbol_table.getType(symb_id);
}
expr_t expr_t
VariableNode::cloneDynamic(DataTree &dynamic_datatree) const VariableNode::cloneDynamic(DataTree &dynamic_datatree) const
{ {
@ -1418,7 +1425,7 @@ VariableNode::cloneDynamic(DataTree &dynamic_datatree) const
int int
VariableNode::maxEndoLead() const VariableNode::maxEndoLead() const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
return max(lag, 0); return max(lag, 0);
@ -1432,7 +1439,7 @@ VariableNode::maxEndoLead() const
int int
VariableNode::maxExoLead() const VariableNode::maxExoLead() const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::exogenous: case SymbolType::exogenous:
return max(lag, 0); return max(lag, 0);
@ -1446,7 +1453,7 @@ VariableNode::maxExoLead() const
int int
VariableNode::maxEndoLag() const VariableNode::maxEndoLag() const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
return max(-lag, 0); return max(-lag, 0);
@ -1460,7 +1467,7 @@ VariableNode::maxEndoLag() const
int int
VariableNode::maxExoLag() const VariableNode::maxExoLag() const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::exogenous: case SymbolType::exogenous:
return max(-lag, 0); return max(-lag, 0);
@ -1474,7 +1481,7 @@ VariableNode::maxExoLag() const
int int
VariableNode::maxLead() const VariableNode::maxLead() const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
return lag; return lag;
@ -1490,7 +1497,7 @@ VariableNode::maxLead() const
int int
VariableNode::VarMinLag() const VariableNode::VarMinLag() const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
return -lag; return -lag;
@ -1509,7 +1516,7 @@ VariableNode::VarMinLag() const
int int
VariableNode::maxLag() const VariableNode::maxLag() const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
return -lag; return -lag;
@ -1595,7 +1602,7 @@ VariableNode::substitutePacExpectation(map<const PacExpectationNode *, const Bin
expr_t expr_t
VariableNode::decreaseLeadsLags(int n) const VariableNode::decreaseLeadsLags(int n) const
{ {
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
case SymbolType::exogenous: case SymbolType::exogenous:
@ -1623,7 +1630,7 @@ expr_t
VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{ {
expr_t value; expr_t value;
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
if (lag <= 1) if (lag <= 1)
@ -1648,7 +1655,7 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector
expr_t value; expr_t value;
subst_table_t::const_iterator it; subst_table_t::const_iterator it;
int cur_lag; int cur_lag;
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
if (lag >= -1) if (lag >= -1)
@ -1696,7 +1703,7 @@ expr_t
VariableNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const VariableNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{ {
expr_t value; expr_t value;
switch (type) switch (get_type())
{ {
case SymbolType::exogenous: case SymbolType::exogenous:
if (lag <= 0) if (lag <= 0)
@ -1721,7 +1728,7 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
expr_t value; expr_t value;
subst_table_t::const_iterator it; subst_table_t::const_iterator it;
int cur_lag; int cur_lag;
switch (type) switch (get_type())
{ {
case SymbolType::exogenous: case SymbolType::exogenous:
if (lag >= 0) if (lag >= 0)
@ -1775,7 +1782,7 @@ expr_t
VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{ {
expr_t value; expr_t value;
switch (type) switch (get_type())
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
assert(lag <= 1); assert(lag <= 1);
@ -1820,7 +1827,7 @@ VariableNode::isNumConstNodeEqualTo(double value) const
bool bool
VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{ {
if (type == type_arg && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id && lag == lag_arg) if (get_type() == type_arg && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id && lag == lag_arg)
return true; return true;
else else
return false; return false;
@ -1835,7 +1842,7 @@ VariableNode::containsPacExpectation(const string &pac_model_name) const
bool bool
VariableNode::containsEndogenous() const VariableNode::containsEndogenous() const
{ {
if (type == SymbolType::endogenous) if (get_type() == SymbolType::endogenous)
return true; return true;
else else
return false; return false;
@ -1844,7 +1851,7 @@ VariableNode::containsEndogenous() const
bool bool
VariableNode::containsExogenous() const VariableNode::containsExogenous() const
{ {
return (type == SymbolType::exogenous || type == SymbolType::exogenousDet); return (get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet);
} }
expr_t expr_t
@ -1947,8 +1954,8 @@ void
VariableNode::getPacNonOptimizingPart(set<pair<int, pair<pair<int, int>, double>>> VariableNode::getPacNonOptimizingPart(set<pair<int, pair<pair<int, int>, double>>>
&params_vars_and_scaling_factor) const &params_vars_and_scaling_factor) const
{ {
if (type != SymbolType::endogenous if (get_type() != SymbolType::endogenous
&& type != SymbolType::exogenous) && get_type() != SymbolType::exogenous)
{ {
cerr << "ERROR VariableNode::getPacNonOptimizingPart: Error in parsing PAC equation" cerr << "ERROR VariableNode::getPacNonOptimizingPart: Error in parsing PAC equation"
<< endl; << endl;
@ -1993,7 +2000,7 @@ void
VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{ {
string varname = datatree.symbol_table.getName(symb_id); string varname = datatree.symbol_table.getName(symb_id);
if (type == SymbolType::endogenous) if (get_type() == SymbolType::endogenous)
if (model_endos_and_lags.find(varname) == model_endos_and_lags.end()) if (model_endos_and_lags.find(varname) == model_endos_and_lags.end())
model_endos_and_lags[varname] = min(model_endos_and_lags[varname], lag); model_endos_and_lags[varname] = min(model_endos_and_lags[varname], lag);
else else

View File

@ -693,7 +693,6 @@ class VariableNode : public ExprNode
private: private:
//! Id from the symbol table //! Id from the symbol table
const int symb_id; const int symb_id;
const SymbolType type;
//! A positive value is a lead, a negative is a lag //! A positive value is a lead, a negative is a lag
const int lag; const int lag;
expr_t computeDerivative(int deriv_id) override; expr_t computeDerivative(int deriv_id) override;
@ -717,11 +716,7 @@ public:
void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override;
expr_t toStatic(DataTree &static_datatree) const override; expr_t toStatic(DataTree &static_datatree) const override;
void computeXrefs(EquationInfo &ei) const override; void computeXrefs(EquationInfo &ei) const override;
SymbolType SymbolType get_type() const;
get_type() const
{
return type;
};
int int
get_symb_id() const get_symb_id() const
{ {

View File

@ -369,14 +369,6 @@ ParsingDriver::declare_or_change_type(SymbolType new_type, const string &name)
symb_id = mod_file->symbol_table.getID(name); symb_id = mod_file->symbol_table.getID(name);
mod_file->symbol_table.changeType(symb_id, new_type); mod_file->symbol_table.changeType(symb_id, new_type);
// change in equations in ModelTree
auto dm = make_unique<DynamicModel>(mod_file->symbol_table,
mod_file->num_constants,
mod_file->external_functions_table,
mod_file->trend_component_model_table,
mod_file->var_model_table);
mod_file->dynamic_model.updateAfterVariableChange(*dm);
// remove error messages // remove error messages
undeclared_model_vars.erase(name); undeclared_model_vars.erase(name);
for (auto it = undeclared_model_variable_errors.begin(); for (auto it = undeclared_model_variable_errors.begin();