Refactor code for collecting variables appearing in expressions

- rename ExprNode::collectVariables in ExprNode::collectDynamicVariables
- new ExprNode::collectVariables: same as above, but without lag information
- remove ExprNode::findUnusedEndogenous: essentially redundant with the above)
- remove ExprNode::collectModelLocalVariables: idem
time-shift
Sébastien Villemot 2013-11-29 15:32:49 +01:00
parent f7cdc39ff2
commit ed2f6d62c1
9 changed files with 74 additions and 123 deletions

View File

@ -3475,7 +3475,7 @@ DynamicModel::computeRamseyPolicyFOCs(const StaticModel &static_model)
int max_eq_lead = 0;
int max_eq_lag = 0;
for (int i = 0; i < (int) equations.size(); i++)
equations[i]->collectVariables(eEndogenous, dynvars);
equations[i]->collectDynamicVariables(eEndogenous, dynvars);
for (set<pair<int, int> >::const_iterator it = dynvars.begin();
it != dynvars.end(); it++)
@ -3572,11 +3572,17 @@ DynamicModel::toStatic(StaticModel &static_model) const
static_model.addAuxEquation((*it)->toStatic(static_model));
}
void
DynamicModel::findUnusedEndogenous(set<int> &unusedEndogs)
set<int>
DynamicModel::findUnusedEndogenous()
{
set<int> usedEndo, unusedEndo;
for (int i = 0; i < (int) equations.size(); i++)
equations[i]->findUnusedEndogenous(unusedEndogs);
equations[i]->collectVariables(eEndogenous, usedEndo);
set<int> allEndo = symbol_table.getEndogenous();
set_difference(allEndo.begin(), allEndo.end(),
usedEndo.begin(), usedEndo.end(),
inserter(unusedEndo, unusedEndo.begin()));
return unusedEndo;
}
void
@ -3585,17 +3591,17 @@ DynamicModel::computeDerivIDs()
set<pair<int, int> > dynvars;
for (int i = 0; i < (int) equations.size(); i++)
equations[i]->collectVariables(eEndogenous, dynvars);
equations[i]->collectDynamicVariables(eEndogenous, dynvars);
dynJacobianColsNbr = dynvars.size();
for (int i = 0; i < (int) equations.size(); i++)
{
equations[i]->collectVariables(eExogenous, dynvars);
equations[i]->collectVariables(eExogenousDet, dynvars);
equations[i]->collectVariables(eParameter, dynvars);
equations[i]->collectVariables(eTrend, dynvars);
equations[i]->collectVariables(eLogTrend, dynvars);
equations[i]->collectDynamicVariables(eExogenous, dynvars);
equations[i]->collectDynamicVariables(eExogenousDet, dynvars);
equations[i]->collectDynamicVariables(eParameter, dynvars);
equations[i]->collectDynamicVariables(eTrend, dynvars);
equations[i]->collectDynamicVariables(eLogTrend, dynvars);
}
for (set<pair<int, int> >::const_iterator it = dynvars.begin();
@ -3993,7 +3999,7 @@ DynamicModel::substituteLeadLagInternal(aux_var_t type, bool deterministic_model
// Substitute in used model local variables
set<int> used_local_vars;
for (size_t i = 0; i < equations.size(); i++)
equations[i]->collectModelLocalVariables(used_local_vars);
equations[i]->collectVariables(eModelLocalVariable, used_local_vars);
for (set<int>::const_iterator it = used_local_vars.begin();
it != used_local_vars.end(); ++it)
@ -4230,7 +4236,7 @@ DynamicModel::isModelLocalVariableUsed() const
size_t i = 0;
while (i < equations.size() && used_local_vars.size() == 0)
{
equations[i]->collectModelLocalVariables(used_local_vars);
equations[i]->collectVariables(eModelLocalVariable, used_local_vars);
i++;
}
return used_local_vars.size() > 0;

View File

@ -223,7 +223,7 @@ public:
void toStatic(StaticModel &static_model) const;
//! Find endogenous variables not used in model
void findUnusedEndogenous(set<int> &unusedEndogs);
set<int> findUnusedEndogenous();
//! Copies a dynamic model (only the equations)
/*! It assumes that the dynamic model given in argument has just been allocated */

View File

@ -80,11 +80,20 @@ ExprNode::cost(const temporary_terms_t &temporary_terms, bool is_matlab) const
return 0;
}
void
ExprNode::collectVariables(SymbolType type, set<int> &result) const
{
set<pair<int, int> > symbs_lags;
collectDynamicVariables(type, symbs_lags);
transform(symbs_lags.begin(), symbs_lags.end(), inserter(result, result.begin()),
boost::bind(&pair<int,int>::first,_1));
}
void
ExprNode::collectEndogenous(set<pair<int, int> > &result) const
{
set<pair<int, int> > symb_ids;
collectVariables(eEndogenous, symb_ids);
collectDynamicVariables(eEndogenous, symb_ids);
for (set<pair<int, int> >::const_iterator it = symb_ids.begin();
it != symb_ids.end(); it++)
result.insert(make_pair(datatree.symbol_table.getTypeSpecificID(it->first), it->second));
@ -94,21 +103,12 @@ void
ExprNode::collectExogenous(set<pair<int, int> > &result) const
{
set<pair<int, int> > symb_ids;
collectVariables(eExogenous, symb_ids);
collectDynamicVariables(eExogenous, symb_ids);
for (set<pair<int, int> >::const_iterator it = symb_ids.begin();
it != symb_ids.end(); it++)
result.insert(make_pair(datatree.symbol_table.getTypeSpecificID(it->first), it->second));
}
void
ExprNode::collectModelLocalVariables(set<int> &result) const
{
set<pair<int, int> > symb_ids;
collectVariables(eModelLocalVariable, symb_ids);
transform(symb_ids.begin(), symb_ids.end(), inserter(result, result.begin()),
boost::bind(&pair<int,int>::first,_1));
}
void
ExprNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
temporary_terms_t &temporary_terms,
@ -325,12 +325,7 @@ NumConstNode::compile(ostream &CompileCode, unsigned int &instruction_number,
}
void
NumConstNode::collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
}
void
NumConstNode::findUnusedEndogenous(set<int> &unusedEndogs) const
NumConstNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
}
@ -875,20 +870,12 @@ VariableNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
}
void
VariableNode::collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const
VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
if (type == type_arg)
result.insert(make_pair(symb_id, lag));
if (type == eModelLocalVariable)
datatree.local_variables_table[symb_id]->collectVariables(type_arg, result);
}
void
VariableNode::findUnusedEndogenous(set<int> &unusedEndogs) const
{
set<int>::iterator it = unusedEndogs.find(symb_id);
if (it != unusedEndogs.end())
unusedEndogs.erase(it);
datatree.local_variables_table[symb_id]->collectDynamicVariables(type_arg, result);
}
pair<int, expr_t>
@ -2003,15 +1990,9 @@ UnaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
}
void
UnaryOpNode::collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const
UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
arg->collectVariables(type_arg, result);
}
void
UnaryOpNode::findUnusedEndogenous(set<int> &unusedEndogs) const
{
arg->findUnusedEndogenous(unusedEndogs);
arg->collectDynamicVariables(type_arg, result);
}
pair<int, expr_t>
@ -3080,17 +3061,10 @@ BinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &
}
void
BinaryOpNode::collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const
BinaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
arg1->collectVariables(type_arg, result);
arg2->collectVariables(type_arg, result);
}
void
BinaryOpNode::findUnusedEndogenous(set<int> &unusedEndogs) const
{
arg1->findUnusedEndogenous(unusedEndogs);
arg2->findUnusedEndogenous(unusedEndogs);
arg1->collectDynamicVariables(type_arg, result);
arg2->collectDynamicVariables(type_arg, result);
}
expr_t
@ -4057,19 +4031,11 @@ TrinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int
}
void
TrinaryOpNode::collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const
TrinaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
arg1->collectVariables(type_arg, result);
arg2->collectVariables(type_arg, result);
arg3->collectVariables(type_arg, result);
}
void
TrinaryOpNode::findUnusedEndogenous(set<int> &unusedEndogs) const
{
arg1->findUnusedEndogenous(unusedEndogs);
arg2->findUnusedEndogenous(unusedEndogs);
arg3->findUnusedEndogenous(unusedEndogs);
arg1->collectDynamicVariables(type_arg, result);
arg2->collectDynamicVariables(type_arg, result);
arg3->collectDynamicVariables(type_arg, result);
}
pair<int, expr_t>
@ -4625,19 +4591,11 @@ ExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
}
void
ExternalFunctionNode::collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const
ExternalFunctionNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
for (vector<expr_t>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
(*it)->collectVariables(type_arg, result);
}
void
ExternalFunctionNode::findUnusedEndogenous(set<int> &unusedEndogs) const
{
for (vector<expr_t>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
(*it)->findUnusedEndogenous(unusedEndogs);
(*it)->collectDynamicVariables(type_arg, result);
}
void

View File

@ -202,14 +202,23 @@ public:
const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
deriv_node_temp_terms_t &tef_terms) const;
//! Computes the set of all variables of a given symbol type in the expression
//! Computes the set of all variables of a given symbol type in the expression (with information on lags)
/*!
Variables are stored as integer pairs of the form (symb_id, lag).
They are added to the set given in argument.
Note that model local variables are substituted by their expression in the computation
(and added if type_arg = ModelLocalVariable).
*/
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const = 0;
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const = 0;
//! Computes the set of all variables of a given symbol type in the expression (without information on lags)
/*!
Variables are stored as symb_id.
They are added to the set given in argument.
Note that model local variables are substituted by their expression in the computation
(and added if type_arg = ModelLocalVariable).
*/
void collectVariables(SymbolType type_arg, set<int> &result) const;
//! Computes the set of endogenous variables in the expression
/*!
@ -227,18 +236,8 @@ public:
*/
virtual void collectExogenous(set<pair<int, int> > &result) const;
//! Computes the set of model local variables in the expression
/*!
Symbol IDs of these model local variables are added to the set given in argument.
Note that this method is called recursively on the expressions associated to the model local variables detected.
*/
virtual void collectModelLocalVariables(set<int> &result) const;
virtual void collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const = 0;
//! Removes used endogenous variables from the provided list of endogs
virtual void findUnusedEndogenous(set<int> &unusedEndogs) const = 0;
virtual void computeTemporaryTerms(map<expr_t, int> &reference_count,
temporary_terms_t &temporary_terms,
map<expr_t, pair<int, int> > &first_occurence,
@ -438,8 +437,7 @@ public:
};
virtual void prepareForDerivation();
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, deriv_node_temp_terms_t &tef_terms) const;
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void findUnusedEndogenous(set<int> &unusedEndogs) const;
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const;
virtual double eval(const eval_context_t &eval_context) const throw (EvalException, EvalExternalFunctionException);
virtual void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const;
@ -484,8 +482,7 @@ public:
VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg);
virtual void prepareForDerivation();
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, deriv_node_temp_terms_t &tef_terms) const;
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void findUnusedEndogenous(set<int> &unusedEndogs) const;
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void computeTemporaryTerms(map<expr_t, int> &reference_count,
temporary_terms_t &temporary_terms,
map<expr_t, pair<int, int> > &first_occurence,
@ -564,8 +561,7 @@ public:
int Curr_block,
vector< vector<temporary_terms_t> > &v_temporary_terms,
int equation) const;
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void findUnusedEndogenous(set<int> &unusedEndogs) const;
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const;
static double eval_opcode(UnaryOpcode op_code, double v) throw (EvalException, EvalExternalFunctionException);
virtual double eval(const eval_context_t &eval_context) const throw (EvalException, EvalExternalFunctionException);
@ -643,8 +639,7 @@ public:
int Curr_block,
vector< vector<temporary_terms_t> > &v_temporary_terms,
int equation) const;
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void findUnusedEndogenous(set<int> &unusedEndogs) const;
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const;
static double eval_opcode(double v1, BinaryOpcode op_code, double v2, int derivOrder) throw (EvalException, EvalExternalFunctionException);
virtual double eval(const eval_context_t &eval_context) const throw (EvalException, EvalExternalFunctionException);
@ -738,8 +733,7 @@ public:
int Curr_block,
vector< vector<temporary_terms_t> > &v_temporary_terms,
int equation) const;
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void findUnusedEndogenous(set<int> &unusedEndogs) const;
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const;
static double eval_opcode(double v1, TrinaryOpcode op_code, double v2, double v3) throw (EvalException, EvalExternalFunctionException);
virtual double eval(const eval_context_t &eval_context) const throw (EvalException, EvalExternalFunctionException);
@ -810,8 +804,7 @@ public:
int Curr_block,
vector< vector<temporary_terms_t> > &v_temporary_terms,
int equation) const;
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void findUnusedEndogenous(set<int> &unusedEndogs) const;
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const;
virtual double eval(const eval_context_t &eval_context) const throw (EvalException, EvalExternalFunctionException);
unsigned int compileExternalFunctionArguments(ostream &CompileCode, unsigned int &instruction_number,

View File

@ -293,8 +293,7 @@ ModFile::transformPass(bool nostrict)
{
if (nostrict)
{
set<int> unusedEndogs = symbol_table.getEndogenous();
dynamic_model.findUnusedEndogenous(unusedEndogs);
set<int> unusedEndogs = dynamic_model.findUnusedEndogenous();
for (set<int>::iterator it = unusedEndogs.begin(); it != unusedEndogs.end(); it++)
{
symbol_table.changeType(*it, eUnusedEndogenous);

View File

@ -1179,7 +1179,7 @@ ModelTree::writeModelLocalVariables(ostream &output, ExprNodeOutputType output_t
const temporary_terms_t tt;
for (size_t i = 0; i < equations.size(); i++)
equations[i]->collectModelLocalVariables(used_local_vars);
equations[i]->collectVariables(eModelLocalVariable, used_local_vars);
for (set<int>::const_iterator it = used_local_vars.begin();
it != used_local_vars.end(); ++it)

View File

@ -374,10 +374,10 @@ ParsingDriver::end_nonstationary_var(bool log_deflator, expr_t deflator)
error("Variable " + e.name + " was listed more than once as following a trend.");
}
set<pair<int, int> > r;
set<int> r;
deflator->collectVariables(eEndogenous, r);
for (set<pair<int, int> >::const_iterator it = r.begin(); it != r.end(); ++it)
if (dynamic_model->isNonstationary(it->first))
for (set<int>::const_iterator it = r.begin(); it != r.end(); ++it)
if (dynamic_model->isNonstationary(*it))
error("The deflator contains a non-stationary endogenous variable. This is not allowed. Please use only stationary endogenous and/or {log_}trend_vars.");
declared_nonstationary_vars.clear();

View File

@ -286,23 +286,18 @@ ShocksStatement::checkPass(ModFileStructure &mod_file_struct, WarningConsolidati
}
// Fill in mod_file_struct.parameters_with_shocks_values (related to #469)
set<pair<int, int> > params_lags;
for (var_and_std_shocks_t::const_iterator it = var_shocks.begin();
it != var_shocks.end(); ++it)
it->second->collectVariables(eParameter, params_lags);
it->second->collectVariables(eParameter, mod_file_struct.parameters_within_shocks_values);
for (var_and_std_shocks_t::const_iterator it = std_shocks.begin();
it != std_shocks.end(); ++it)
it->second->collectVariables(eParameter, params_lags);
it->second->collectVariables(eParameter, mod_file_struct.parameters_within_shocks_values);
for (covar_and_corr_shocks_t::const_iterator it = covar_shocks.begin();
it != covar_shocks.end(); ++it)
it->second->collectVariables(eParameter, params_lags);
it->second->collectVariables(eParameter, mod_file_struct.parameters_within_shocks_values);
for (covar_and_corr_shocks_t::const_iterator it = corr_shocks.begin();
it != corr_shocks.end(); ++it)
it->second->collectVariables(eParameter, params_lags);
for (set<pair<int, int> >::const_iterator it = params_lags.begin();
it != params_lags.end(); ++it)
mod_file_struct.parameters_within_shocks_values.insert(it->first);
it->second->collectVariables(eParameter, mod_file_struct.parameters_within_shocks_values);
}
MShocksStatement::MShocksStatement(const det_shocks_t &det_shocks_arg,

View File

@ -78,16 +78,16 @@ SteadyStateModel::checkPass(bool ramsey_policy) const
// Check that expression has no undefined symbol
if (!ramsey_policy)
{
set<pair<int, int> > used_symbols;
set<int> used_symbols;
expr_t expr = def_table.find(symb_ids)->second;
expr->collectVariables(eEndogenous, used_symbols);
expr->collectVariables(eModFileLocalVariable, used_symbols);
for (set<pair<int, int> >::const_iterator it = used_symbols.begin();
for (set<int>::const_iterator it = used_symbols.begin();
it != used_symbols.end(); ++it)
if (find(so_far_defined.begin(), so_far_defined.end(), it->first)
if (find(so_far_defined.begin(), so_far_defined.end(), *it)
== so_far_defined.end())
{
cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(it->first)
cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(*it)
<< "' is undefined in the declaration of variable '" << symbol_table.getName(symb_ids[0]) << "'" << endl;
exit(EXIT_FAILURE);
}