No longer store symbol type in VariableNode

This facilitates switching variable types on the fly. In particular, this
allows removing the hack in DynamicModel::updateAfterVariableChange() that way
basically recreating all the nodes after the type change.
issue#70
Sébastien Villemot 2018-10-09 17:50:04 +02:00
parent c47b6e6e4c
commit 215283005e
5 changed files with 43 additions and 67 deletions

View File

@ -4827,21 +4827,6 @@ DynamicModel::writeAuxVarRecursiveDefinitions(ostream &output, ExprNodeOutputTyp
}
}
void
DynamicModel::updateAfterVariableChange(DynamicModel &dm)
{
variable_node_map.clear();
unary_op_node_map.clear();
binary_op_node_map.clear();
trinary_op_node_map.clear();
external_function_node_map.clear();
first_deriv_external_function_node_map.clear();
second_deriv_external_function_node_map.clear();
cloneDynamic(dm);
dm.replaceMyEquations(*this);
}
void
DynamicModel::cloneDynamic(DynamicModel &dynamic_model) const
{

View File

@ -378,9 +378,6 @@ public:
/*! It assumes that the dynamic model given in argument has just been allocated */
void cloneDynamic(DynamicModel &dynamic_model) const;
//! update equations after variable type change in model block
void updateAfterVariableChange(DynamicModel &dynamic_model);
//! Replaces model equations with derivatives of Lagrangian w.r.t. endogenous
void computeRamseyPolicyFOCs(const StaticModel &static_model, const bool nopreprocessoroutput);
//! Replaces the model equations in dynamic_model with those in this model

View File

@ -714,12 +714,11 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) :
ExprNode{datatree_arg, idx_arg},
symb_id{symb_id_arg},
type{datatree.symbol_table.getType(symb_id_arg)},
lag{lag_arg}
{
// It makes sense to allow a lead/lag on parameters: during steady state calibration, endogenous and parameters can be swapped
assert(type != SymbolType::externalFunction
&& (lag == 0 || (type != SymbolType::modelLocalVariable && type != SymbolType::modFileLocalVariable)));
assert(get_type() != SymbolType::externalFunction
&& (lag == 0 || (get_type() != SymbolType::modelLocalVariable && get_type() != SymbolType::modFileLocalVariable)));
}
void
@ -731,7 +730,7 @@ VariableNode::prepareForDerivation()
preparedForDerivation = true;
// Fill in non_null_derivatives
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
case SymbolType::exogenous:
@ -765,7 +764,7 @@ VariableNode::prepareForDerivation()
expr_t
VariableNode::computeDerivative(int deriv_id)
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
case SymbolType::exogenous:
@ -806,7 +805,7 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t
auto it = temporary_terms.find(const_cast<VariableNode *>(this));
if (it != temporary_terms.end())
temporary_terms_inuse.insert(idx);
if (type == SymbolType::modelLocalVariable)
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
}
@ -821,7 +820,7 @@ VariableNode::writeJsonAST(ostream &output) const
{
output << "{\"node_type\" : \"VariableNode\", "
<< "\"name\" : \"" << datatree.symbol_table.getName(symb_id) << "\", \"type\" : \"";
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
output << "endogenous";
@ -896,6 +895,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_idxs_t &temporary_terms_idxs,
const deriv_node_temp_terms_t &tef_terms) const
{
auto type = get_type();
if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
return;
@ -1150,7 +1150,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
expr_t
VariableNode::substituteStaticAuxiliaryVariable() const
{
if (type == SymbolType::endogenous)
if (get_type() == SymbolType::endogenous)
{
try
{
@ -1179,6 +1179,7 @@ VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number,
const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
const deriv_node_temp_terms_t &tef_terms) const
{
auto type = get_type();
if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
else
@ -1255,14 +1256,14 @@ VariableNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
vector<vector<temporary_terms_t>> &v_temporary_terms,
int equation) const
{
if (type == SymbolType::modelLocalVariable)
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
}
void
VariableNode::collectVARLHSVariable(set<expr_t> &result) const
{
if (type == SymbolType::endogenous && lag == 0)
if (get_type() == SymbolType::endogenous && lag == 0)
result.insert(const_cast<VariableNode *>(this));
else
{
@ -1274,9 +1275,9 @@ VariableNode::collectVARLHSVariable(set<expr_t> &result) const
void
VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
{
if (type == type_arg)
if (get_type() == type_arg)
result.emplace(symb_id, lag);
if (type == SymbolType::modelLocalVariable)
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
}
@ -1295,7 +1296,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
the flag is equal to 2.
- an expression equal to the RHS if flag = 0 and equal to NULL elsewhere
*/
if (type == SymbolType::endogenous)
if (get_type() == SymbolType::endogenous)
{
if (datatree.symbol_table.getTypeSpecificID(symb_id) == var_endo && lag == 0)
/* the endogenous variable */
@ -1305,7 +1306,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
}
else
{
if (type == SymbolType::parameter)
if (get_type() == SymbolType::parameter)
return { 0, datatree.AddVariable(symb_id, 0) };
else
return { 0, datatree.AddVariable(symb_id, lag) };
@ -1315,7 +1316,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
expr_t
VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
case SymbolType::exogenous:
@ -1380,7 +1381,7 @@ VariableNode::toStatic(DataTree &static_datatree) const
void
VariableNode::computeXrefs(EquationInfo &ei) const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
ei.endo.emplace(symb_id, lag);
@ -1409,6 +1410,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const
}
}
SymbolType
VariableNode::get_type() const
{
return datatree.symbol_table.getType(symb_id);
}
expr_t
VariableNode::cloneDynamic(DataTree &dynamic_datatree) const
{
@ -1418,7 +1425,7 @@ VariableNode::cloneDynamic(DataTree &dynamic_datatree) const
int
VariableNode::maxEndoLead() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return max(lag, 0);
@ -1432,7 +1439,7 @@ VariableNode::maxEndoLead() const
int
VariableNode::maxExoLead() const
{
switch (type)
switch (get_type())
{
case SymbolType::exogenous:
return max(lag, 0);
@ -1446,7 +1453,7 @@ VariableNode::maxExoLead() const
int
VariableNode::maxEndoLag() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return max(-lag, 0);
@ -1460,7 +1467,7 @@ VariableNode::maxEndoLag() const
int
VariableNode::maxExoLag() const
{
switch (type)
switch (get_type())
{
case SymbolType::exogenous:
return max(-lag, 0);
@ -1474,7 +1481,7 @@ VariableNode::maxExoLag() const
int
VariableNode::maxLead() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return lag;
@ -1490,7 +1497,7 @@ VariableNode::maxLead() const
int
VariableNode::VarMinLag() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return -lag;
@ -1509,7 +1516,7 @@ VariableNode::VarMinLag() const
int
VariableNode::maxLag() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return -lag;
@ -1595,7 +1602,7 @@ VariableNode::substitutePacExpectation(map<const PacExpectationNode *, const Bin
expr_t
VariableNode::decreaseLeadsLags(int n) const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
case SymbolType::exogenous:
@ -1623,7 +1630,7 @@ expr_t
VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
expr_t value;
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
if (lag <= 1)
@ -1648,7 +1655,7 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector
expr_t value;
subst_table_t::const_iterator it;
int cur_lag;
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
if (lag >= -1)
@ -1696,7 +1703,7 @@ expr_t
VariableNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
expr_t value;
switch (type)
switch (get_type())
{
case SymbolType::exogenous:
if (lag <= 0)
@ -1721,7 +1728,7 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
expr_t value;
subst_table_t::const_iterator it;
int cur_lag;
switch (type)
switch (get_type())
{
case SymbolType::exogenous:
if (lag >= 0)
@ -1775,7 +1782,7 @@ expr_t
VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
expr_t value;
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
assert(lag <= 1);
@ -1820,7 +1827,7 @@ VariableNode::isNumConstNodeEqualTo(double value) const
bool
VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{
if (type == type_arg && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id && lag == lag_arg)
if (get_type() == type_arg && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id && lag == lag_arg)
return true;
else
return false;
@ -1835,7 +1842,7 @@ VariableNode::containsPacExpectation(const string &pac_model_name) const
bool
VariableNode::containsEndogenous() const
{
if (type == SymbolType::endogenous)
if (get_type() == SymbolType::endogenous)
return true;
else
return false;
@ -1844,7 +1851,7 @@ VariableNode::containsEndogenous() const
bool
VariableNode::containsExogenous() const
{
return (type == SymbolType::exogenous || type == SymbolType::exogenousDet);
return (get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet);
}
expr_t
@ -1947,8 +1954,8 @@ void
VariableNode::getPacNonOptimizingPart(set<pair<int, pair<pair<int, int>, double>>>
&params_vars_and_scaling_factor) const
{
if (type != SymbolType::endogenous
&& type != SymbolType::exogenous)
if (get_type() != SymbolType::endogenous
&& get_type() != SymbolType::exogenous)
{
cerr << "ERROR VariableNode::getPacNonOptimizingPart: Error in parsing PAC equation"
<< endl;
@ -1993,7 +2000,7 @@ void
VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{
string varname = datatree.symbol_table.getName(symb_id);
if (type == SymbolType::endogenous)
if (get_type() == SymbolType::endogenous)
if (model_endos_and_lags.find(varname) == model_endos_and_lags.end())
model_endos_and_lags[varname] = min(model_endos_and_lags[varname], lag);
else

View File

@ -693,7 +693,6 @@ class VariableNode : public ExprNode
private:
//! Id from the symbol table
const int symb_id;
const SymbolType type;
//! A positive value is a lead, a negative is a lag
const int lag;
expr_t computeDerivative(int deriv_id) override;
@ -717,11 +716,7 @@ public:
void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override;
expr_t toStatic(DataTree &static_datatree) const override;
void computeXrefs(EquationInfo &ei) const override;
SymbolType
get_type() const
{
return type;
};
SymbolType get_type() const;
int
get_symb_id() const
{

View File

@ -369,14 +369,6 @@ ParsingDriver::declare_or_change_type(SymbolType new_type, const string &name)
symb_id = mod_file->symbol_table.getID(name);
mod_file->symbol_table.changeType(symb_id, new_type);
// change in equations in ModelTree
auto dm = make_unique<DynamicModel>(mod_file->symbol_table,
mod_file->num_constants,
mod_file->external_functions_table,
mod_file->trend_component_model_table,
mod_file->var_model_table);
mod_file->dynamic_model.updateAfterVariableChange(*dm);
// remove error messages
undeclared_model_vars.erase(name);
for (auto it = undeclared_model_variable_errors.begin();