From 2f1382fab55804b2bc707ecd8c0e99913d38ce16 Mon Sep 17 00:00:00 2001 From: sebastien Date: Wed, 30 Sep 2009 15:10:31 +0000 Subject: [PATCH] preprocessor: * In stochastic mode, now transforms the model by removing leads and lags greater or equal to 2 (creating auxiliary variables and equations in the process) * Information about these variables is in structure M_.aux_vars * Automatically add the necessary initialization for auxiliary vars after the initval block or load_params_and_steady_state git-svn-id: https://www.dynare.org/svn/dynare/trunk@3002 ac1d8469-bf42-47a9-8791-bf33cf982152 --- DataTree.cc | 44 +-- DataTree.hh | 20 +- DynamicModel.cc | 205 +++++++++--- DynamicModel.hh | 27 +- DynareMain2.cc | 3 + ExprNode.cc | 624 ++++++++++++++++++++++++++++++------- ExprNode.hh | 94 +++++- ModFile.cc | 40 ++- ModFile.hh | 2 + ModelTree.cc | 9 + ModelTree.hh | 8 +- NumericalInitialization.cc | 18 +- NumericalInitialization.hh | 2 + StaticDllModel.cc | 107 +------ StaticDllModel.hh | 26 +- StaticModel.cc | 21 +- StaticModel.hh | 5 +- SymbolTable.cc | 75 ++++- SymbolTable.hh | 29 +- 19 files changed, 995 insertions(+), 364 deletions(-) diff --git a/DataTree.cc b/DataTree.cc index 8642d2b2..5c545779 100644 --- a/DataTree.cc +++ b/DataTree.cc @@ -59,23 +59,28 @@ DataTree::AddNumConstant(const string &value) return new NumConstNode(*this, id); } -NodeID -DataTree::AddVariableInternal(const string &name, int lag) +VariableNode * +DataTree::AddVariableInternal(int symb_id, int lag) { - int symb_id = symbol_table.getID(name); - variable_node_map_type::iterator it = variable_node_map.find(make_pair(symb_id, lag)); if (it != variable_node_map.end()) return it->second; else - return new VariableNode(*this, symb_id, lag, computeDerivID(symb_id, lag)); + return new VariableNode(*this, symb_id, lag); } -NodeID +VariableNode * DataTree::AddVariable(const string &name, int lag) +{ + int symb_id = symbol_table.getID(name); + return AddVariable(symb_id, lag); +} + +VariableNode * +DataTree::AddVariable(int symb_id, int lag) { assert(lag == 0); - return AddVariableInternal(name, lag); + return AddVariableInternal(symb_id, lag); } NodeID @@ -445,25 +450,6 @@ DataTree::AddUnknownFunction(const string &function_name, const vector & return new UnknownFunctionNode(*this, id, arguments); } -void -DataTree::fillEvalContext(eval_context_type &eval_context) const -{ - for(map::const_iterator it = local_variables_table.begin(); - it != local_variables_table.end(); it++) - { - try - { - const NodeID expression = it->second; - double val = expression->eval(eval_context); - eval_context[it->first] = val; - } - catch(ExprNode::EvalException &e) - { - // Do nothing - } - } -} - bool DataTree::isSymbolUsed(int symb_id) const { @@ -478,12 +464,6 @@ DataTree::isSymbolUsed(int symb_id) const return false; } -int -DataTree::computeDerivID(int symb_id, int lag) -{ - return -1; -} - int DataTree::getDerivID(int symb_id, int lag) const throw (UnknownDerivIDException) { diff --git a/DataTree.hh b/DataTree.hh index d83fba6b..494ee6a9 100644 --- a/DataTree.hh +++ b/DataTree.hh @@ -49,26 +49,24 @@ protected: //! Reference to numerical constants table NumericalConstants &num_constants; - typedef map num_const_node_map_type; + typedef map num_const_node_map_type; num_const_node_map_type num_const_node_map; //! Pair (symbol_id, lag) used as key - typedef map, NodeID> variable_node_map_type; + typedef map, VariableNode *> variable_node_map_type; variable_node_map_type variable_node_map; - typedef map, NodeID> unary_op_node_map_type; + typedef map, UnaryOpNode *> unary_op_node_map_type; unary_op_node_map_type unary_op_node_map; - typedef map, int>, NodeID> binary_op_node_map_type; + typedef map, int>, BinaryOpNode *> binary_op_node_map_type; binary_op_node_map_type binary_op_node_map; - typedef map,NodeID>, int>, NodeID> trinary_op_node_map_type; + typedef map,NodeID>, int>, TrinaryOpNode *> trinary_op_node_map_type; trinary_op_node_map_type trinary_op_node_map; //! Stores local variables value (maps symbol ID to corresponding node) map local_variables_table; //! Internal implementation of AddVariable(), without the check on the lag - NodeID AddVariableInternal(const string &name, int lag); + VariableNode *AddVariableInternal(int symb_id, int lag); - //! Computes a new deriv_id, or returns -1 if the variable is not one w.r. to which to derive - virtual int computeDerivID(int symb_id, int lag); private: typedef list node_list_type; //! The list of nodes @@ -100,7 +98,9 @@ public: NodeID AddNumConstant(const string &value); //! Adds a variable /*! The default implementation of the method refuses any lag != 0 */ - virtual NodeID AddVariable(const string &name, int lag = 0); + virtual VariableNode *AddVariable(int symb_id, int lag = 0); + //! Adds a variable, using its symbol name + VariableNode *AddVariable(const string &name, int lag = 0); //! Adds "arg1+arg2" to model tree NodeID AddPlus(NodeID iArg1, NodeID iArg2); //! Adds "arg1-arg2" to model tree @@ -172,8 +172,6 @@ public: //! Adds an unknown function node /*! \todo Use a map to share identical nodes */ NodeID AddUnknownFunction(const string &function_name, const vector &arguments); - //! Fill eval context with values of local variables - void fillEvalContext(eval_context_type &eval_context) const; //! Checks if a given symbol is used somewhere in the data tree bool isSymbolUsed(int symb_id) const; //! Thrown when trying to access an unknown variable by deriv_id diff --git a/DynamicModel.cc b/DynamicModel.cc index 39ecf286..1451f7a9 100644 --- a/DynamicModel.cc +++ b/DynamicModel.cc @@ -48,10 +48,10 @@ DynamicModel::DynamicModel(SymbolTable &symbol_table_arg, { } -NodeID -DynamicModel::AddVariable(const string &name, int lag) +VariableNode * +DynamicModel::AddVariable(int symb_id, int lag) { - return AddVariableInternal(name, lag); + return AddVariableInternal(symb_id, lag); } void @@ -2292,7 +2292,10 @@ DynamicModel::computingPass(bool jacobianExo, bool hessian, bool thirdDerivative { assert(jacobianExo || !(hessian || thirdDerivatives || paramsDerivatives)); - // Computes dynamic jacobian columns + // Prepare for derivation + computeDerivIDs(); + + // Computes dynamic jacobian columns, must be done after computeDerivIDs() computeDynJacobianCols(jacobianExo); // Compute derivatives w.r. to all endogenous, and possibly exogenous and exogenous deterministic @@ -2408,6 +2411,11 @@ DynamicModel::toStatic(StaticModel &static_model) const for (vector::const_iterator it = equations.begin(); it != equations.end(); it++) static_model.addEquation((*it)->toStatic(static_model)); + + // Convert auxiliary equations + for (deque::const_iterator it = aux_equations.begin(); + it != aux_equations.end(); it++) + static_model.addAuxEquation((*it)->toStatic(static_model)); } void @@ -2424,60 +2432,65 @@ DynamicModel::toStaticDll(StaticDllModel &static_model) const static_model.addEquation((*it)->toStatic(static_model)); } -int -DynamicModel::computeDerivID(int symb_id, int lag) +void +DynamicModel::computeDerivIDs() { - // Setting maximum and minimum lags - if (max_lead < lag) - max_lead = lag; - else if (-max_lag > lag) - max_lag = -lag; + set > dynvars; - SymbolType type = symbol_table.getType(symb_id); + for(int i = 0; i < (int) equations.size(); i++) + equations[i]->collectVariables(eEndogenous, dynvars); - switch (type) + dynJacobianColsNbr = dynvars.size(); + + for(int i = 0; i < (int) equations.size(); i++) { - case eEndogenous: - if (max_endo_lead < lag) - max_endo_lead = lag; - else if (-max_endo_lag > lag) - max_endo_lag = -lag; - break; - case eExogenous: - if (max_exo_lead < lag) - max_exo_lead = lag; - else if (-max_exo_lag > lag) - max_exo_lag = -lag; - break; - case eExogenousDet: - if (max_exo_det_lead < lag) - max_exo_det_lead = lag; - else if (-max_exo_det_lag > lag) - max_exo_det_lag = -lag; - break; - case eParameter: - // We wan't to compute a derivation ID for parameters - break; - default: - return -1; + equations[i]->collectVariables(eExogenous, dynvars); + equations[i]->collectVariables(eExogenousDet, dynvars); + equations[i]->collectVariables(eParameter, dynvars); } - // Check if dynamic variable already has a deriv_id - pair key = make_pair(symb_id, lag); - deriv_id_table_t::const_iterator it = deriv_id_table.find(key); - if (it != deriv_id_table.end()) - return it->second; + for(set >::const_iterator it = dynvars.begin(); + it != dynvars.end(); it++) + { + int lag = it->second; + SymbolType type = symbol_table.getType(it->first); - // Create a new deriv_id - int deriv_id = deriv_id_table.size(); + // Setting maximum and minimum lags + if (max_lead < lag) + max_lead = lag; + else if (-max_lag > lag) + max_lag = -lag; - deriv_id_table[key] = deriv_id; - inv_deriv_id_table.push_back(key); + switch (type) + { + case eEndogenous: + if (max_endo_lead < lag) + max_endo_lead = lag; + else if (-max_endo_lag > lag) + max_endo_lag = -lag; + break; + case eExogenous: + if (max_exo_lead < lag) + max_exo_lead = lag; + else if (-max_exo_lag > lag) + max_exo_lag = -lag; + break; + case eExogenousDet: + if (max_exo_det_lead < lag) + max_exo_det_lead = lag; + else if (-max_exo_det_lag > lag) + max_exo_det_lag = -lag; + break; + default: + break; + } - if (type == eEndogenous) - dynJacobianColsNbr++; + // Create a new deriv_id + int deriv_id = deriv_id_table.size(); - return deriv_id; + deriv_id_table[*it] = deriv_id; + inv_deriv_id_table.push_back(*it); + } } SymbolType @@ -2814,4 +2827,100 @@ DynamicModel::hessianHelper(ostream &output, int row_nb, int col_nb, ExprNodeOut output << RIGHT_ARRAY_SUBSCRIPT(output_type); } +void +DynamicModel::substituteLeadGreaterThanTwo() +{ + ExprNode::subst_table_t subst_table; + vector neweqs; + // Substitute in model local variables + for(map::iterator it = local_variables_table.begin(); + it != local_variables_table.end(); it++) + it->second = it->second->substituteLeadGreaterThanTwo(subst_table, neweqs); + + // Substitute in equations + for(int i = 0; i < (int) equations.size(); i++) + { + BinaryOpNode *substeq = dynamic_cast(equations[i]->substituteLeadGreaterThanTwo(subst_table, neweqs)); + assert(substeq != NULL); + equations[i] = substeq; + } + + // Add new equations + for(int i = 0; i < (int) neweqs.size(); i++) + addEquation(neweqs[i]); + + // Add the new set of equations at the *beginning* of aux_equations + copy(neweqs.rbegin(), neweqs.rend(), front_inserter(aux_equations)); + + if (neweqs.size() > 0) + cout << "Substitution of leads >= 2: added " << neweqs.size() << " auxiliary variables and equations." << endl; +} + +void +DynamicModel::substituteLagGreaterThanTwo() +{ + ExprNode::subst_table_t subst_table; + vector neweqs; + + // Substitute in model local variables + for(map::iterator it = local_variables_table.begin(); + it != local_variables_table.end(); it++) + it->second = it->second->substituteLagGreaterThanTwo(subst_table, neweqs); + + // Substitute in equations + for(int i = 0; i < (int) equations.size(); i++) + { + BinaryOpNode *substeq = dynamic_cast(equations[i]->substituteLagGreaterThanTwo(subst_table, neweqs)); + assert(substeq != NULL); + equations[i] = substeq; + } + + // Add new equations + for(int i = 0; i < (int) neweqs.size(); i++) + addEquation(neweqs[i]); + + // Add the new set of equations at the *beginning* of aux_equations + copy(neweqs.rbegin(), neweqs.rend(), front_inserter(aux_equations)); + + if (neweqs.size() > 0) + cout << "Substitution of lags >= 2: added " << neweqs.size() << " auxiliary variables and equations." << endl; +} + +void +DynamicModel::fillEvalContext(eval_context_type &eval_context) const +{ + // First, auxiliary variables + for(deque::const_iterator it = aux_equations.begin(); + it != aux_equations.end(); it++) + { + assert((*it)->get_op_code() == oEqual); + VariableNode *auxvar = dynamic_cast((*it)->get_arg1()); + assert(auxvar != NULL); + try + { + double val = (*it)->get_arg2()->eval(eval_context); + eval_context[auxvar->get_symb_id()] = val; + } + catch(ExprNode::EvalException &e) + { + // Do nothing + } + } + + // Second, model local variables + for(map::const_iterator it = local_variables_table.begin(); + it != local_variables_table.end(); it++) + { + try + { + const NodeID expression = it->second; + double val = expression->eval(eval_context); + eval_context[it->first] = val; + } + catch(ExprNode::EvalException &e) + { + // Do nothing + } + } +} diff --git a/DynamicModel.hh b/DynamicModel.hh index ed8952ca..d9fa77df 100644 --- a/DynamicModel.hh +++ b/DynamicModel.hh @@ -31,7 +31,6 @@ using namespace std; //! Stores a dynamic model class DynamicModel : public ModelTree { -public: private: typedef map, int> deriv_id_table_t; //! Maps a pair (symbol_id, lag) to a deriv ID @@ -44,20 +43,20 @@ private: map dyn_jacobian_cols_table; //! Maximum lag and lead over all types of variables (positive values) - /*! Set by computeDerivID() */ + /*! Set by computeDerivIDs() */ int max_lag, max_lead; //! Maximum lag and lead over endogenous variables (positive values) - /*! Set by computeDerivID() */ + /*! Set by computeDerivIDs() */ int max_endo_lag, max_endo_lead; //! Maximum lag and lead over exogenous variables (positive values) - /*! Set by computeDerivID() */ + /*! Set by computeDerivIDs() */ int max_exo_lag, max_exo_lead; //! Maximum lag and lead over deterministic exogenous variables (positive values) - /*! Set by computeDerivID() */ + /*! Set by computeDerivIDs() */ int max_exo_det_lag, max_exo_det_lead; //! Number of columns of dynamic jacobian - /*! Set by computeDerivID() and computeDynJacobianCols() */ + /*! Set by computeDerivID()s and computeDynJacobianCols() */ int dynJacobianColsNbr; //! Derivatives of the residuals w.r. to parameters @@ -113,7 +112,6 @@ private: //! Write chain rule derivative code of an equation w.r. to a variable void compileChainRuleDerivative(ofstream &code_file, int eq, int var, int lag, map_idx_type &map_idx) const; - virtual int computeDerivID(int symb_id, int lag); //! Get the type corresponding to a derivation ID SymbolType getTypeByDerivID(int deriv_id) const throw (UnknownDerivIDException); //! Get the lag corresponding to a derivation ID @@ -131,6 +129,10 @@ private: //! Collect only the first derivatives map >, NodeID> collect_first_order_derivatives_endogenous(); + //! Allocates the derivation IDs for all dynamic variables of the model + /*! Also computes max_{endo,exo}_{lead_lag}, and initializes dynJacobianColsNbr to the number of dynamic endos */ + void computeDerivIDs(); + //! Helper for writing the Jacobian elements in MATLAB and C /*! Writes either (i+1,j+1) or [i+j*no_eq] */ void jacobianHelper(ostream &output, int eq_nb, int col_nb, ExprNodeOutputType output_type) const; @@ -147,7 +149,7 @@ public: DynamicModel(SymbolTable &symbol_table_arg, NumericalConstants &num_constants); //! Adds a variable node /*! This implementation allows for non-zero lag */ - virtual NodeID AddVariable(const string &name, int lag = 0); + virtual VariableNode *AddVariable(int symb_id, int lag = 0); //! Absolute value under which a number is considered to be zero double cutoff; //! Compute the minimum feedback set in the dynamic model: @@ -196,6 +198,15 @@ public: //! Returns true indicating that this is a dynamic model virtual bool isDynamic() const { return true; }; + + //! Transforms the model by removing all leads greater or equal than 2 + void substituteLeadGreaterThanTwo(); + + //! Transforms the model by removing all lags greater or equal than 2 + void substituteLagGreaterThanTwo(); + + //! Fills eval context with values of model local variables and auxiliary variables + void fillEvalContext(eval_context_type &eval_context) const; }; #endif diff --git a/DynareMain2.cc b/DynareMain2.cc index 5ae76a02..15049213 100644 --- a/DynareMain2.cc +++ b/DynareMain2.cc @@ -35,6 +35,9 @@ main2(stringstream &in, string &basename, bool debug, bool clear_all, bool no_tm // Run checking pass mod_file->checkPass(); + // Perform transformations on the model (creation of auxiliary vars and equations) + mod_file->transformPass(); + // Evaluate parameters initialization, initval, endval and pounds mod_file->evalAllExpressions(); diff --git a/ExprNode.cc b/ExprNode.cc index 79728da0..87f4bc39 100644 --- a/ExprNode.cc +++ b/ExprNode.cc @@ -34,7 +34,7 @@ using namespace __gnu_cxx; #include "DataTree.hh" #include "BlockTriangular.hh" -ExprNode::ExprNode(DataTree &datatree_arg) : datatree(datatree_arg) +ExprNode::ExprNode(DataTree &datatree_arg) : datatree(datatree_arg), preparedForDerivation(false) { // Add myself to datatree datatree.node_list.push_back(this); @@ -50,6 +50,9 @@ ExprNode::~ExprNode() NodeID ExprNode::getDerivative(int deriv_id) { + if (!preparedForDerivation) + prepareForDerivation(); + // Return zero if derivative is necessarily null (using symbolic a priori) set::const_iterator it = non_null_derivatives.find(deriv_id); if (it == non_null_derivatives.end()) @@ -144,6 +147,42 @@ ExprNode::writeOutput(ostream &output) writeOutput(output, oMatlabOutsideModel, temporary_terms_type()); } +VariableNode * +ExprNode::createLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector &neweqs) const +{ + int n = maxEndoLead(); + assert(n >= 2); + + subst_table_t::const_iterator it = subst_table.find(this); + if (it != subst_table.end()) + return const_cast(it->second); + + NodeID substexpr = decreaseLeadsLags(n-1); + int lag = n-2; + + // Each iteration tries to create an auxvar such that auxvar(+1)=expr(-lag) + // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to expr(-lag-1) (resp. expr(-lag)) + while(lag >= 0) + { + NodeID orig_expr = decreaseLeadsLags(lag); + it = subst_table.find(orig_expr); + if (it == subst_table.end()) + { + int symb_id = datatree.symbol_table.addLeadAuxiliaryVar(orig_expr->idx); + neweqs.push_back(dynamic_cast(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr))); + substexpr = datatree.AddVariable(symb_id, +1); + assert(dynamic_cast(substexpr) != NULL); + subst_table[orig_expr] = dynamic_cast(substexpr); + } + else + substexpr = const_cast(it->second); + + lag--; + } + + return dynamic_cast(substexpr); +} + NumConstNode::NumConstNode(DataTree &datatree_arg, int id_arg) : ExprNode(datatree_arg), @@ -151,7 +190,12 @@ NumConstNode::NumConstNode(DataTree &datatree_arg, int id_arg) : { // Add myself to the num const map datatree.num_const_node_map[id] = this; +} +void +NumConstNode::prepareForDerivation() +{ + preparedForDerivation = true; // All derivatives are null, so non_null_derivatives is left empty } @@ -221,19 +265,51 @@ NumConstNode::toStatic(DataTree &static_datatree) const return static_datatree.AddNumConstant(datatree.num_constants.get(id)); } +int +NumConstNode::maxEndoLead() const +{ + return 0; +} -VariableNode::VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, int deriv_id_arg) : +NodeID +NumConstNode::decreaseLeadsLags(int n) const +{ + return const_cast(this); +} + +NodeID +NumConstNode::substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + return const_cast(this); +} + +NodeID +NumConstNode::substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + return const_cast(this); +} + +VariableNode::VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg) : ExprNode(datatree_arg), symb_id(symb_id_arg), type(datatree.symbol_table.getType(symb_id_arg)), - lag(lag_arg), - deriv_id(deriv_id_arg) + lag(lag_arg) { // Add myself to the variable map datatree.variable_node_map[make_pair(symb_id, lag)] = this; // It makes sense to allow a lead/lag on parameters: during steady state calibration, endogenous and parameters can be swapped - assert(lag == 0 || (type != eModelLocalVariable && type != eModFileLocalVariable && type != eUnknownFunction)); + assert(type != eUnknownFunction + && (lag == 0 || (type != eModelLocalVariable && type != eModFileLocalVariable))); +} + +void +VariableNode::prepareForDerivation() +{ + if (preparedForDerivation) + return; + + preparedForDerivation = true; // Fill in non_null_derivatives switch (type) @@ -243,9 +319,10 @@ VariableNode::VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, case eExogenousDet: case eParameter: // For a variable or a parameter, the only non-null derivative is with respect to itself - non_null_derivatives.insert(deriv_id); + non_null_derivatives.insert(datatree.getDerivID(symb_id, lag)); break; case eModelLocalVariable: + datatree.local_variables_table[symb_id]->prepareForDerivation(); // Non null derivatives are those of the value of the local parameter non_null_derivatives = datatree.local_variables_table[symb_id]->non_null_derivatives; break; @@ -253,13 +330,13 @@ VariableNode::VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, // Such a variable is never derived break; case eUnknownFunction: - cerr << "Attempt to construct a VariableNode with an unknown function name" << endl; + cerr << "VariableNode::prepareForDerivation: impossible case" << endl; exit(EXIT_FAILURE); } } NodeID -VariableNode::computeDerivative(int deriv_id_arg) +VariableNode::computeDerivative(int deriv_id) { switch (type) { @@ -267,12 +344,12 @@ VariableNode::computeDerivative(int deriv_id_arg) case eExogenous: case eExogenousDet: case eParameter: - if (deriv_id == deriv_id_arg) + if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; else return datatree.Zero; case eModelLocalVariable: - return datatree.local_variables_table[symb_id]->getDerivative(deriv_id_arg); + return datatree.local_variables_table[symb_id]->getDerivative(deriv_id); case eModFileLocalVariable: cerr << "ModFileLocalVariable is not derivable" << endl; exit(EXIT_FAILURE); @@ -359,7 +436,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type, { case oMatlabDynamicModel: case oCDynamicModel: - i = datatree.getDynJacobianCol(deriv_id) + ARRAY_SUBSCRIPT_OFFSET(output_type); + i = datatree.getDynJacobianCol(datatree.getDerivID(symb_id, lag)) + ARRAY_SUBSCRIPT_OFFSET(output_type); output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type); break; case oMatlabStaticModel: @@ -590,19 +667,19 @@ VariableNode::normalizeEquation(int var_endo, vector &recursive_variables) +VariableNode::getChainRuleDerivative(int deriv_id, const map &recursive_variables) { switch (type) { @@ -610,15 +687,15 @@ VariableNode::getChainRuleDerivative(int deriv_id_arg, const map &r case eExogenous: case eExogenousDet: case eParameter: - if (deriv_id == deriv_id_arg) + if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; else { //if there is in the equation a recursive variable we could use a chaine rule derivation - map::const_iterator it = recursive_variables.find(deriv_id); + map::const_iterator it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); if (it != recursive_variables.end()) { - map::const_iterator it2 = derivatives.find(deriv_id_arg); + map::const_iterator it2 = derivatives.find(deriv_id); if (it2 != derivatives.end()) return it2->second; else @@ -626,9 +703,9 @@ VariableNode::getChainRuleDerivative(int deriv_id_arg, const map &r map recursive_vars2(recursive_variables); recursive_vars2.erase(it->first); //NodeID c = datatree.AddNumConstant("1"); - NodeID d = datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id_arg, recursive_vars2)); + NodeID d = datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2)); //d = datatree.AddTimes(c, d); - derivatives[deriv_id_arg] = d; + derivatives[deriv_id] = d; return d; } } @@ -636,7 +713,7 @@ VariableNode::getChainRuleDerivative(int deriv_id_arg, const map &r return datatree.Zero; } case eModelLocalVariable: - return datatree.local_variables_table[symb_id]->getChainRuleDerivative(deriv_id_arg, recursive_variables); + return datatree.local_variables_table[symb_id]->getChainRuleDerivative(deriv_id, recursive_variables); case eModFileLocalVariable: cerr << "ModFileLocalVariable is not derivable" << endl; exit(EXIT_FAILURE); @@ -656,6 +733,103 @@ VariableNode::toStatic(DataTree &static_datatree) const return static_datatree.AddVariable(datatree.symbol_table.getName(symb_id)); } +int +VariableNode::maxEndoLead() const +{ + switch(type) + { + case eEndogenous: + return max(lag, 0); + case eModelLocalVariable: + return datatree.local_variables_table[symb_id]->maxEndoLead(); + default: + return 0; + } +} + +NodeID +VariableNode::decreaseLeadsLags(int n) const +{ + switch(type) + { + case eEndogenous: + case eExogenous: + case eExogenousDet: + return datatree.AddVariable(symb_id, lag-n); + case eModelLocalVariable: + return datatree.local_variables_table[symb_id]->decreaseLeadsLags(n); + default: + return const_cast(this); + } +} + +NodeID +VariableNode::substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + NodeID value; + switch(type) + { + case eEndogenous: + if (lag <= 1) + return const_cast(this); + else + return createLeadAuxiliaryVarForMyself(subst_table, neweqs); + case eModelLocalVariable: + value = datatree.local_variables_table[symb_id]; + if (value->maxEndoLead() <= 1) + return const_cast(this); + else + return value->substituteLeadGreaterThanTwo(subst_table, neweqs); + default: + return const_cast(this); + } +} + +NodeID +VariableNode::substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + VariableNode *substexpr; + subst_table_t::const_iterator it; + int cur_lag; + switch(type) + { + case eEndogenous: + if (lag >= -1) + return const_cast(this); + + it = subst_table.find(this); + if (it != subst_table.end()) + return const_cast(it->second); + + substexpr = datatree.AddVariable(symb_id, -1); + cur_lag = -2; + + // Each iteration tries to create an auxvar such that auxvar(-1)=curvar(cur_lag) + // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to curvar(cur_lag+1) (resp. curvar(cur_lag)) + while(cur_lag >= lag) + { + VariableNode *orig_expr = datatree.AddVariable(symb_id, cur_lag); + it = subst_table.find(orig_expr); + if (it == subst_table.end()) + { + int aux_symb_id = datatree.symbol_table.addLagAuxiliaryVar(symb_id, cur_lag+1); + neweqs.push_back(dynamic_cast(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr))); + substexpr = datatree.AddVariable(aux_symb_id, -1); + subst_table[orig_expr] = substexpr; + } + else + substexpr = const_cast(it->second); + + cur_lag--; + } + return substexpr; + + case eModelLocalVariable: + return datatree.local_variables_table[symb_id]->substituteLagGreaterThanTwo(subst_table, neweqs); + default: + return const_cast(this); + } +} UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg) : ExprNode(datatree_arg), @@ -664,6 +838,17 @@ UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const { // Add myself to the unary op map datatree.unary_op_node_map[make_pair(arg, op_code)] = this; +} + +void +UnaryOpNode::prepareForDerivation() +{ + if (preparedForDerivation) + return; + + preparedForDerivation = true; + + arg->prepareForDerivation(); // Non-null derivatives are those of the argument non_null_derivatives = arg->non_null_derivatives; @@ -1216,52 +1401,94 @@ UnaryOpNode::getChainRuleDerivative(int deriv_id, const map &recurs } NodeID -UnaryOpNode::toStatic(DataTree &static_datatree) const - { - NodeID sarg = arg->toStatic(static_datatree); - switch (op_code) - { - case oUminus: - return static_datatree.AddUMinus(sarg); - case oExp: - return static_datatree.AddExp(sarg); - case oLog: - return static_datatree.AddLog(sarg); - case oLog10: - return static_datatree.AddLog10(sarg); - case oCos: - return static_datatree.AddCos(sarg); - case oSin: - return static_datatree.AddSin(sarg); - case oTan: - return static_datatree.AddTan(sarg); - case oAcos: - return static_datatree.AddAcos(sarg); - case oAsin: - return static_datatree.AddAsin(sarg); - case oAtan: - return static_datatree.AddAtan(sarg); - case oCosh: - return static_datatree.AddCosh(sarg); - case oSinh: - return static_datatree.AddSinh(sarg); - case oTanh: - return static_datatree.AddTanh(sarg); - case oAcosh: - return static_datatree.AddAcosh(sarg); - case oAsinh: - return static_datatree.AddAsinh(sarg); - case oAtanh: - return static_datatree.AddAtanh(sarg); - case oSqrt: - return static_datatree.AddSqrt(sarg); - case oSteadyState: - return static_datatree.AddSteadyState(sarg); - } - // Suppress GCC warning - exit(EXIT_FAILURE); - } +UnaryOpNode::buildSimilarUnaryOpNode(NodeID alt_arg, DataTree &alt_datatree) const +{ + switch (op_code) + { + case oUminus: + return alt_datatree.AddUMinus(alt_arg); + case oExp: + return alt_datatree.AddExp(alt_arg); + case oLog: + return alt_datatree.AddLog(alt_arg); + case oLog10: + return alt_datatree.AddLog10(alt_arg); + case oCos: + return alt_datatree.AddCos(alt_arg); + case oSin: + return alt_datatree.AddSin(alt_arg); + case oTan: + return alt_datatree.AddTan(alt_arg); + case oAcos: + return alt_datatree.AddAcos(alt_arg); + case oAsin: + return alt_datatree.AddAsin(alt_arg); + case oAtan: + return alt_datatree.AddAtan(alt_arg); + case oCosh: + return alt_datatree.AddCosh(alt_arg); + case oSinh: + return alt_datatree.AddSinh(alt_arg); + case oTanh: + return alt_datatree.AddTanh(alt_arg); + case oAcosh: + return alt_datatree.AddAcosh(alt_arg); + case oAsinh: + return alt_datatree.AddAsinh(alt_arg); + case oAtanh: + return alt_datatree.AddAtanh(alt_arg); + case oSqrt: + return alt_datatree.AddSqrt(alt_arg); + case oSteadyState: + return alt_datatree.AddSteadyState(alt_arg); + } + // Suppress GCC warning + exit(EXIT_FAILURE); +} +NodeID +UnaryOpNode::toStatic(DataTree &static_datatree) const +{ + NodeID sarg = arg->toStatic(static_datatree); + return buildSimilarUnaryOpNode(sarg, static_datatree); +} + +int +UnaryOpNode::maxEndoLead() const +{ + return arg->maxEndoLead(); +} + +NodeID +UnaryOpNode::decreaseLeadsLags(int n) const +{ + NodeID argsubst = arg->decreaseLeadsLags(n); + return buildSimilarUnaryOpNode(argsubst, datatree); +} + +NodeID +UnaryOpNode::substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + if (op_code == oUminus) + { + NodeID argsubst = arg->substituteLeadGreaterThanTwo(subst_table, neweqs); + return buildSimilarUnaryOpNode(argsubst, datatree); + } + else + { + if (maxEndoLead() >= 2) + return createLeadAuxiliaryVarForMyself(subst_table, neweqs); + else + return const_cast(this); + } +} + +NodeID +UnaryOpNode::substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + NodeID argsubst = arg->substituteLagGreaterThanTwo(subst_table, neweqs); + return buildSimilarUnaryOpNode(argsubst, datatree); +} BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, BinaryOpcode op_code_arg, const NodeID arg2_arg) : @@ -1271,6 +1498,18 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, op_code(op_code_arg) { datatree.binary_op_node_map[make_pair(make_pair(arg1, arg2), op_code)] = this; +} + +void +BinaryOpNode::prepareForDerivation() +{ + if (preparedForDerivation) + return; + + preparedForDerivation = true; + + arg1->prepareForDerivation(); + arg2->prepareForDerivation(); // Non-null derivatives are the union of those of the arguments // Compute set union of arg1->non_null_derivatives and arg2->non_null_derivatives @@ -2062,45 +2301,107 @@ BinaryOpNode::getChainRuleDerivative(int deriv_id, const map &recur } NodeID -BinaryOpNode::toStatic(DataTree &static_datatree) const - { - NodeID sarg1 = arg1->toStatic(static_datatree); - NodeID sarg2 = arg2->toStatic(static_datatree); - switch (op_code) - { - case oPlus: - return static_datatree.AddPlus(sarg1, sarg2); - case oMinus: - return static_datatree.AddMinus(sarg1, sarg2); - case oTimes: - return static_datatree.AddTimes(sarg1, sarg2); - case oDivide: - return static_datatree.AddDivide(sarg1, sarg2); - case oPower: - return static_datatree.AddPower(sarg1, sarg2); - case oEqual: - return static_datatree.AddEqual(sarg1, sarg2); - case oMax: - return static_datatree.AddMax(sarg1, sarg2); - case oMin: - return static_datatree.AddMin(sarg1, sarg2); - case oLess: - return static_datatree.AddLess(sarg1, sarg2); - case oGreater: - return static_datatree.AddGreater(sarg1, sarg2); - case oLessEqual: - return static_datatree.AddLessEqual(sarg1, sarg2); - case oGreaterEqual: - return static_datatree.AddGreaterEqual(sarg1, sarg2); - case oEqualEqual: - return static_datatree.AddEqualEqual(sarg1, sarg2); - case oDifferent: - return static_datatree.AddDifferent(sarg1, sarg2); - } - // Suppress GCC warning - exit(EXIT_FAILURE); - } +BinaryOpNode::buildSimilarBinaryOpNode(NodeID alt_arg1, NodeID alt_arg2, DataTree &alt_datatree) const +{ + switch (op_code) + { + case oPlus: + return alt_datatree.AddPlus(alt_arg1, alt_arg2); + case oMinus: + return alt_datatree.AddMinus(alt_arg1, alt_arg2); + case oTimes: + return alt_datatree.AddTimes(alt_arg1, alt_arg2); + case oDivide: + return alt_datatree.AddDivide(alt_arg1, alt_arg2); + case oPower: + return alt_datatree.AddPower(alt_arg1, alt_arg2); + case oEqual: + return alt_datatree.AddEqual(alt_arg1, alt_arg2); + case oMax: + return alt_datatree.AddMax(alt_arg1, alt_arg2); + case oMin: + return alt_datatree.AddMin(alt_arg1, alt_arg2); + case oLess: + return alt_datatree.AddLess(alt_arg1, alt_arg2); + case oGreater: + return alt_datatree.AddGreater(alt_arg1, alt_arg2); + case oLessEqual: + return alt_datatree.AddLessEqual(alt_arg1, alt_arg2); + case oGreaterEqual: + return alt_datatree.AddGreaterEqual(alt_arg1, alt_arg2); + case oEqualEqual: + return alt_datatree.AddEqualEqual(alt_arg1, alt_arg2); + case oDifferent: + return alt_datatree.AddDifferent(alt_arg1, alt_arg2); + } + // Suppress GCC warning + exit(EXIT_FAILURE); +} +NodeID +BinaryOpNode::toStatic(DataTree &static_datatree) const +{ + NodeID sarg1 = arg1->toStatic(static_datatree); + NodeID sarg2 = arg2->toStatic(static_datatree); + return buildSimilarBinaryOpNode(sarg1, sarg2, static_datatree); +} + +int +BinaryOpNode::maxEndoLead() const +{ + return max(arg1->maxEndoLead(), arg2->maxEndoLead()); +} + +NodeID +BinaryOpNode::decreaseLeadsLags(int n) const +{ + NodeID arg1subst = arg1->decreaseLeadsLags(n); + NodeID arg2subst = arg2->decreaseLeadsLags(n); + return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree); +} + +NodeID +BinaryOpNode::substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + NodeID arg1subst, arg2subst; + int maxlead1 = arg1->maxEndoLead(), maxlead2 = arg2->maxEndoLead(); + + if (maxlead1 < 2 && maxlead2 < 2) + return const_cast(this); + + switch(op_code) + { + case oPlus: + case oMinus: + case oEqual: + arg1subst = maxlead1 >= 2 ? arg1->substituteLeadGreaterThanTwo(subst_table, neweqs) : arg1; + arg2subst = maxlead2 >= 2 ? arg2->substituteLeadGreaterThanTwo(subst_table, neweqs) : arg2; + return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree); + case oTimes: + case oDivide: + if (maxlead1 >= 2 && maxlead2 == 0) + { + arg1subst = arg1->substituteLeadGreaterThanTwo(subst_table, neweqs); + return buildSimilarBinaryOpNode(arg1subst, arg2, datatree); + } + if (maxlead1 == 0 && maxlead2 >= 2 && op_code == oTimes) + { + arg2subst = arg2->substituteLeadGreaterThanTwo(subst_table, neweqs); + return buildSimilarBinaryOpNode(arg1, arg2subst, datatree); + } + return createLeadAuxiliaryVarForMyself(subst_table, neweqs); + default: + return createLeadAuxiliaryVarForMyself(subst_table, neweqs); + } +} + +NodeID +BinaryOpNode::substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + NodeID arg1subst = arg1->substituteLagGreaterThanTwo(subst_table, neweqs); + NodeID arg2subst = arg2->substituteLagGreaterThanTwo(subst_table, neweqs); + return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree); +} TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, TrinaryOpcode op_code_arg, const NodeID arg2_arg, const NodeID arg3_arg) : @@ -2111,6 +2412,19 @@ TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, op_code(op_code_arg) { datatree.trinary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), arg3), op_code)] = this; +} + +void +TrinaryOpNode::prepareForDerivation() +{ + if (preparedForDerivation) + return; + + preparedForDerivation = true; + + arg1->prepareForDerivation(); + arg2->prepareForDerivation(); + arg3->prepareForDerivation(); // Non-null derivatives are the union of those of the arguments // Compute set union of arg{1,2,3}->non_null_derivatives @@ -2412,20 +2726,58 @@ TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map &recu } NodeID -TrinaryOpNode::toStatic(DataTree &static_datatree) const - { - NodeID sarg1 = arg1->toStatic(static_datatree); - NodeID sarg2 = arg2->toStatic(static_datatree); - NodeID sarg3 = arg3->toStatic(static_datatree); - switch (op_code) - { - case oNormcdf: - return static_datatree.AddNormcdf(sarg1, sarg2, sarg3); - } - // Suppress GCC warning - exit(EXIT_FAILURE); - } +TrinaryOpNode::buildSimilarTrinaryOpNode(NodeID alt_arg1, NodeID alt_arg2, NodeID alt_arg3, DataTree &alt_datatree) const +{ + switch (op_code) + { + case oNormcdf: + return alt_datatree.AddNormcdf(alt_arg1, alt_arg2, alt_arg3); + } + // Suppress GCC warning + exit(EXIT_FAILURE); +} +NodeID +TrinaryOpNode::toStatic(DataTree &static_datatree) const +{ + NodeID sarg1 = arg1->toStatic(static_datatree); + NodeID sarg2 = arg2->toStatic(static_datatree); + NodeID sarg3 = arg3->toStatic(static_datatree); + return buildSimilarTrinaryOpNode(sarg1, sarg2, sarg3, static_datatree); +} + +int +TrinaryOpNode::maxEndoLead() const +{ + return max(arg1->maxEndoLead(), max(arg2->maxEndoLead(), arg3->maxEndoLead())); +} + +NodeID +TrinaryOpNode::decreaseLeadsLags(int n) const +{ + NodeID arg1subst = arg1->decreaseLeadsLags(n); + NodeID arg2subst = arg2->decreaseLeadsLags(n); + NodeID arg3subst = arg3->decreaseLeadsLags(n); + return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree); +} + +NodeID +TrinaryOpNode::substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + if (maxEndoLead() < 2) + return const_cast(this); + else + return createLeadAuxiliaryVarForMyself(subst_table, neweqs); +} + +NodeID +TrinaryOpNode::substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + NodeID arg1subst = arg1->substituteLagGreaterThanTwo(subst_table, neweqs); + NodeID arg2subst = arg2->substituteLagGreaterThanTwo(subst_table, neweqs); + NodeID arg3subst = arg3->substituteLagGreaterThanTwo(subst_table, neweqs); + return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree); +} UnknownFunctionNode::UnknownFunctionNode(DataTree &datatree_arg, int symb_id_arg, @@ -2436,6 +2788,13 @@ UnknownFunctionNode::UnknownFunctionNode(DataTree &datatree_arg, { } +void +UnknownFunctionNode::prepareForDerivation() +{ + cerr << "UnknownFunctionNode::prepareForDerivation: operation impossible!" << endl; + exit(EXIT_FAILURE); +} + NodeID UnknownFunctionNode::computeDerivative(int deriv_id) { @@ -2550,3 +2909,34 @@ UnknownFunctionNode::toStatic(DataTree &static_datatree) const static_arguments.push_back((*it)->toStatic(static_datatree)); return static_datatree.AddUnknownFunction(datatree.symbol_table.getName(symb_id), static_arguments); } + +int +UnknownFunctionNode::maxEndoLead() const +{ + int val = 0; + for(vector::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + val = max(val, (*it)->maxEndoLead()); + return val; +} + +NodeID +UnknownFunctionNode::decreaseLeadsLags(int n) const +{ + cerr << "UnknownFunctionNode::decreaseLeadsLags: not implemented!" << endl; + exit(EXIT_FAILURE); +} + +NodeID +UnknownFunctionNode::substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + cerr << "UnknownFunctionNode::substituteLeadGreaterThanTwo: not implemented!" << endl; + exit(EXIT_FAILURE); +} + +NodeID +UnknownFunctionNode::substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const +{ + cerr << "UnknownFunctionNode::substituteLagGreaterThanTwo: not implemented!" << endl; + exit(EXIT_FAILURE); +} diff --git a/ExprNode.hh b/ExprNode.hh index 2e0e4653..08bafae8 100644 --- a/ExprNode.hh +++ b/ExprNode.hh @@ -31,6 +31,8 @@ using namespace std; #include "CodeInterpreter.hh" class DataTree; +class VariableNode; +class BinaryOpNode; typedef class ExprNode *NodeID; @@ -120,6 +122,9 @@ protected: //! Index number int idx; + //! Is the data member non_null_derivatives initialized ? + bool preparedForDerivation; + //! Set of derivation IDs with respect to which the derivative is potentially non-null set non_null_derivatives; @@ -134,6 +139,9 @@ public: ExprNode(DataTree &datatree_arg); virtual ~ExprNode(); + //! Initializes data member non_null_derivatives + virtual void prepareForDerivation() = 0; + //! Returns derivative w.r. to derivation ID /*! Uses a symbolic a priori to pre-detect null derivatives, and caches the result for other derivatives (to avoid computing it several times) For an equal node, returns the derivative of lhs minus rhs */ @@ -216,6 +224,50 @@ public: virtual NodeID toStatic(DataTree &static_datatree) const = 0; //! Try to normalize an equation linear in its endogenous variable virtual pair normalizeEquation(int symb_id_endo, vector > > &List_of_Op_RHS) const = 0; + + //! Returns the maximum lead of endogenous in this expression + /*! Always returns a non-negative value */ + virtual int maxEndoLead() const = 0; + + //! Returns a new expression where all the leads/lags have been shifted backwards by the same amount + /*! + Only acts on endogenous, exogenous, exogenous det + \param[in] n The number of lags by which to shift + \return The same expression except that leads/lags have been shifted backwards + */ + virtual NodeID decreaseLeadsLags(int n) const = 0; + + //! Type for the substitution map used in the process of creating auxiliary vars for leads >= 2 + typedef map subst_table_t; + + //! Creates auxiliary lead variables corresponding to this expression + /*! + If maximum endogenous lead >= 3, this method will also create intermediary auxiliary var, and will add the equations of the form aux1 = aux2(+1) to the substitution table. + \pre This expression is assumed to have maximum endogenous lead >= 2 + \param[in,out] subst_table The table to which new auxiliary variables and their correspondance will be added + \return The new variable node corresponding to the current expression + */ + VariableNode *createLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector &neweqs) const; + + //! Constructs a new expression where sub-expressions with max endo lead >= 2 have been replaced by auxiliary variables + /*! + \param[in,out] subst_table Map used to store expressions that have already be substituted and their corresponding variable, in order to avoid creating two auxiliary variables for the same sub-expr. + \param[out] neweqs Equations to be added to the model to match the creation of auxiliary variables. + + If the method detects a sub-expr which needs to be substituted, two cases are possible: + - if this expr is in the table, then it will use the corresponding variable and return the substituted expression + - if this expr is not in the table, then it will create an auxiliary endogenous variable, add the substitution in the table and return the substituted expression + + \return A new equivalent expression where sub-expressions with max endo lead >= 2 have been replaced by auxiliary variables + */ + virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const = 0; + + //! Constructs a new expression where endo variables with max endo lag >= 2 have been replaced by auxiliary variables + /*! + \param[in,out] subst_table Map used to store expressions that have already be substituted and their corresponding variable, in order to avoid creating two auxiliary variables for the same sub-expr. + \param[out] neweqs Equations to be added to the model to match the creation of auxiliary variables. + */ + virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const = 0; }; //! Object used to compare two nodes (using their indexes) @@ -237,6 +289,7 @@ private: virtual NodeID computeDerivative(int deriv_id); public: NumConstNode(DataTree &datatree_arg, int id_arg); + virtual void prepareForDerivation(); virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void collectVariables(SymbolType type_arg, set > &result) const; virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const; @@ -245,6 +298,10 @@ public: virtual NodeID toStatic(DataTree &static_datatree) const; virtual pair normalizeEquation(int symb_id_endo, vector > > &List_of_Op_RHS) const; virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); + virtual int maxEndoLead() const; + virtual NodeID decreaseLeadsLags(int n) const; + virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; + virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; }; //! Symbol or variable node @@ -255,11 +312,10 @@ private: const int symb_id; const SymbolType type; const int lag; - //! Derivation ID - const int deriv_id; - virtual NodeID computeDerivative(int deriv_id_arg); + virtual NodeID computeDerivative(int deriv_id); public: - VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, int deriv_id_arg); + VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg); + virtual void prepareForDerivation(); virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void collectVariables(SymbolType type_arg, set > &result) const; virtual void computeTemporaryTerms(map &reference_count, @@ -276,6 +332,10 @@ public: int get_symb_id() const { return symb_id; }; virtual pair normalizeEquation(int symb_id_endo, vector > > &List_of_Op_RHS) const; virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); + virtual int maxEndoLead() const; + virtual NodeID decreaseLeadsLags(int n) const; + virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; + virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; }; //! Unary operator node @@ -290,6 +350,7 @@ private: NodeID composeDerivatives(NodeID darg); public: UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg); + virtual void prepareForDerivation(); virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void computeTemporaryTerms(map &reference_count, @@ -311,6 +372,12 @@ public: virtual NodeID toStatic(DataTree &static_datatree) const; virtual pair normalizeEquation(int symb_id_endo, vector > > &List_of_Op_RHS) const; virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); + virtual int maxEndoLead() const; + virtual NodeID decreaseLeadsLags(int n) const; + virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; + //! Creates another UnaryOpNode with the same opcode, but with a possibly different datatree and argument + NodeID buildSimilarUnaryOpNode(NodeID alt_arg, DataTree &alt_datatree) const; + virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; }; //! Binary operator node @@ -326,6 +393,7 @@ private: public: BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, BinaryOpcode op_code_arg, const NodeID arg2_arg); + virtual void prepareForDerivation(); virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; @@ -351,6 +419,12 @@ public: virtual NodeID toStatic(DataTree &static_datatree) const; virtual pair normalizeEquation(int symb_id_endo, vector > > &List_of_Op_RHS) const; virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); + virtual int maxEndoLead() const; + virtual NodeID decreaseLeadsLags(int n) const; + virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; + //! Creates another BinaryOpNode with the same opcode, but with a possibly different datatree and arguments + NodeID buildSimilarBinaryOpNode(NodeID alt_arg1, NodeID alt_arg2, DataTree &alt_datatree) const; + virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; }; //! Trinary operator node @@ -367,6 +441,7 @@ private: public: TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, TrinaryOpcode op_code_arg, const NodeID arg2_arg, const NodeID arg3_arg); + virtual void prepareForDerivation(); virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; @@ -385,6 +460,12 @@ public: virtual NodeID toStatic(DataTree &static_datatree) const; virtual pair normalizeEquation(int symb_id_endo, vector > > &List_of_Op_RHS) const; virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); + virtual int maxEndoLead() const; + virtual NodeID decreaseLeadsLags(int n) const; + virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; + //! Creates another TrinaryOpNode with the same opcode, but with a possibly different datatree and arguments + NodeID buildSimilarTrinaryOpNode(NodeID alt_arg1, NodeID alt_arg2, NodeID alt_arg3, DataTree &alt_datatree) const; + virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; }; //! Unknown function node @@ -397,6 +478,7 @@ private: public: UnknownFunctionNode(DataTree &datatree_arg, int symb_id_arg, const vector &arguments_arg); + virtual void prepareForDerivation(); virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void computeTemporaryTerms(map &reference_count, @@ -413,6 +495,10 @@ public: virtual NodeID toStatic(DataTree &static_datatree) const; virtual pair normalizeEquation(int symb_id_endo, vector > > &List_of_Op_RHS) const; virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); + virtual int maxEndoLead() const; + virtual NodeID decreaseLeadsLags(int n) const; + virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; + virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; }; #endif diff --git a/ModFile.cc b/ModFile.cc index cadb916d..155eed70 100644 --- a/ModFile.cc +++ b/ModFile.cc @@ -129,6 +129,21 @@ ModFile::checkPass() cerr << "ERROR: In 'model' block, can't use option 'bytecode' without option 'block'" << endl; exit(EXIT_FAILURE); } +} + +void +ModFile::transformPass() +{ + // In stochastic models, create auxiliary vars for leads and lags greater than 2 + if (mod_file_struct.stoch_simul_present + || mod_file_struct.estimation_present + || mod_file_struct.forecast_present + || mod_file_struct.osr_present + || mod_file_struct.ramsey_policy_present) + { + dynamic_model.substituteLeadGreaterThanTwo(); + dynamic_model.substituteLagGreaterThanTwo(); + } // Freeze the symbol table symbol_table.freeze(); @@ -294,7 +309,30 @@ ModFile::writeOutputFiles(const string &basename, bool clear_all) const // Print statements for(vector::const_iterator it = statements.begin(); it != statements.end(); it++) - (*it)->writeOutput(mOutputFile, basename); + { + (*it)->writeOutput(mOutputFile, basename); + + // Special treatment for initval block: insert initial values for the auxiliary variables + InitValStatement *ivs = dynamic_cast(*it); + if (ivs != NULL) + { + if (!byte_code) + static_model.writeAuxVarInitval(mOutputFile); + else + static_dll_model.writeAuxVarInitval(mOutputFile); + ivs->writeOutputPostInit(mOutputFile); + } + + // Special treatment for load params and steady state statement: insert initial values for the auxiliary variables + LoadParamsAndSteadyStateStatement *lpass = dynamic_cast(*it); + if (lpass) + { + if (!byte_code) + static_model.writeAuxVarInitval(mOutputFile); + else + static_dll_model.writeAuxVarInitval(mOutputFile); + } + } // Remove path for block option with M-files if (block && !byte_code) diff --git a/ModFile.hh b/ModFile.hh index f9e69a25..6a7f3baa 100644 --- a/ModFile.hh +++ b/ModFile.hh @@ -81,6 +81,8 @@ public: //! Do some checking and fills mod_file_struct /*! \todo add check for number of equations and endogenous if ramsey_policy is present */ void checkPass(); + //! Perform some transformations on the model (creation of auxiliary vars and equations) + void transformPass(); //! Execute computations /*! \param no_tmp_terms if true, no temporary terms will be computed in the static and dynamic files */ void computingPass(bool no_tmp_terms); diff --git a/ModelTree.cc b/ModelTree.cc index d2a88e30..ac19cc4c 100644 --- a/ModelTree.cc +++ b/ModelTree.cc @@ -285,3 +285,12 @@ ModelTree::addEquationTags(int i, const string &key, const string &value) { equation_tags.push_back(make_pair(i, make_pair(key, value))); } + +void +ModelTree::addAuxEquation(NodeID eq) +{ + BinaryOpNode *beq = dynamic_cast(eq); + assert(beq != NULL && beq->get_op_code() == oEqual); + + aux_equations.push_back(beq); +} diff --git a/ModelTree.hh b/ModelTree.hh index 29e10f76..aa2e4c18 100644 --- a/ModelTree.hh +++ b/ModelTree.hh @@ -24,6 +24,7 @@ using namespace std; #include #include +#include #include #include @@ -33,9 +34,12 @@ using namespace std; class ModelTree : public DataTree { protected: - //! Stores declared equations + //! Stores declared and generated auxiliary equations vector equations; + //! Only stores generated auxiliary equations, in an order meaningful for evaluation + deque aux_equations; + //! Stores equation tags vector > > equation_tags; @@ -102,6 +106,8 @@ public: void addEquation(NodeID eq); //! Adds tags to equation number i void addEquationTags(int i, const string &key, const string &value); + //! Declare a node as an auxiliary equation of the model, adding it at the end of the list of auxiliary equations + void addAuxEquation(NodeID eq); //! Returns the number of equations in the model int equation_number() const; }; diff --git a/NumericalInitialization.cc b/NumericalInitialization.cc index 2557bb83..64a14b49 100644 --- a/NumericalInitialization.cc +++ b/NumericalInitialization.cc @@ -120,14 +120,18 @@ InitValStatement::writeOutput(ostream &output, const string &basename) const output << "options_.initval_file = 0;" << endl; writeInitValues(output); +} - output << "oo_.endo_simul=[oo_.steady_state*ones(1,M_.maximum_lag)];\n"; - output << "if M_.exo_nbr > 0;\n"; - output << "\too_.exo_simul = [ones(M_.maximum_lag,1)*oo_.exo_steady_state'];\n"; - output <<"end;\n"; - output << "if M_.exo_det_nbr > 0;\n"; - output << "\too_.exo_det_simul = [ones(M_.maximum_lag,1)*oo_.exo_det_steady_state'];\n"; - output <<"end;\n"; +void +InitValStatement::writeOutputPostInit(ostream &output) const +{ + output << "oo_.endo_simul=[oo_.steady_state*ones(1,M_.maximum_lag)];" << endl + << "if M_.exo_nbr > 0;" << endl + << "\too_.exo_simul = [ones(M_.maximum_lag,1)*oo_.exo_steady_state'];" << endl + <<"end;" << endl + << "if M_.exo_det_nbr > 0;" << endl + << "\too_.exo_det_simul = [ones(M_.maximum_lag,1)*oo_.exo_det_steady_state'];" << endl + <<"end;" << endl; } diff --git a/NumericalInitialization.hh b/NumericalInitialization.hh index c444fc62..83799418 100644 --- a/NumericalInitialization.hh +++ b/NumericalInitialization.hh @@ -70,6 +70,8 @@ public: InitValStatement(const init_values_type &init_values_arg, const SymbolTable &symbol_table_arg); virtual void writeOutput(ostream &output, const string &basename) const; + //! Writes initializations for oo_.endo_simul, oo_.exo_simul and oo_.exo_det_simul + void writeOutputPostInit(ostream &output) const; }; class EndValStatement : public InitOrEndValStatement diff --git a/StaticDllModel.cc b/StaticDllModel.cc index 56457b53..b135122c 100644 --- a/StaticDllModel.cc +++ b/StaticDllModel.cc @@ -37,11 +37,6 @@ StaticDllModel::StaticDllModel(SymbolTable &symbol_table_arg, NumericalConstants &num_constants_arg) : ModelTree(symbol_table_arg, num_constants_arg), - max_lag(0), max_lead(0), - max_endo_lag(0), max_endo_lead(0), - max_exo_lag(0), max_exo_lead(0), - max_exo_det_lag(0), max_exo_det_lead(0), - dynJacobianColsNbr(0), cutoff(1e-15), mfs(0), block_triangular(symbol_table_arg, num_constants_arg) @@ -799,9 +794,6 @@ StaticDllModel::computingPass(const eval_context_type &eval_context, bool no_tmp { assert(block); - // Computes static jacobian columns - computeStatJacobianCols(); - // Compute derivatives w.r. to all endogenous, and possibly exogenous and exogenous deterministic set vars; for (deriv_id_table_t::const_iterator it = deriv_id_table.begin(); @@ -875,29 +867,6 @@ StaticDllModel::writeStaticFile(const string &basename, bool block) const block_triangular.incidencematrix.Free_IM(); } -int -StaticDllModel::computeDerivID(int symb_id, int lag) -{ - // Check if static variable already has a deriv_id - pair key = make_pair(symb_id, lag); - deriv_id_table_t::const_iterator it = deriv_id_table.find(key); - if (it != deriv_id_table.end()) - return it->second; - - // Create a new deriv_id - int deriv_id = deriv_id_table.size(); - - deriv_id_table[key] = deriv_id; - inv_deriv_id_table.push_back(key); - - SymbolType type = symbol_table.getType(symb_id); - - if (type == eEndogenous) - dynJacobianColsNbr++; - - return deriv_id; -} - SymbolType StaticDllModel::getTypeByDerivID(int deriv_id) const throw (UnknownDerivIDException) { @@ -925,72 +894,12 @@ StaticDllModel::getSymbIDByDerivID(int deriv_id) const throw (UnknownDerivIDExce int StaticDllModel::getDerivID(int symb_id, int lag) const throw (UnknownDerivIDException) { - deriv_id_table_t::const_iterator it = deriv_id_table.find(make_pair(symb_id, lag)); - if (it == deriv_id_table.end()) - throw UnknownDerivIDException(); + if (symbol_table.getType(symb_id) == eEndogenous) + return symb_id; else - return it->second; + return -1; } -void -StaticDllModel::computeStatJacobianCols() -{ - /* Sort the static endogenous variables by lexicographic order over (lag, type_specific_symbol_id) - and fill the static columns for exogenous and exogenous deterministic */ - map, int> ordered_dyn_endo; - - for (deriv_id_table_t::const_iterator it = deriv_id_table.begin(); - it != deriv_id_table.end(); it++) - { - const int &symb_id = it->first.first; - const int &lag = it->first.second; - const int &deriv_id = it->second; - SymbolType type = symbol_table.getType(symb_id); - int tsid = symbol_table.getTypeSpecificID(symb_id); - - switch (type) - { - case eEndogenous: - ordered_dyn_endo[make_pair(lag, tsid)] = deriv_id; - break; - case eExogenous: - // At this point, dynJacobianColsNbr contains the number of static endogenous - break; - case eExogenousDet: - // At this point, dynJacobianColsNbr contains the number of static endogenous - break; - case eParameter: - // We don't assign a static jacobian column to parameters - break; - case eModelLocalVariable: - // We don't assign a static jacobian column to model local variables - break; - default: - // Shut up GCC - cerr << "StaticDllModel::computeStatJacobianCols: impossible case" << endl; - exit(EXIT_FAILURE); - } - } - - // Fill in static jacobian columns for endogenous - int sorted_id = 0; - for (map, int>::const_iterator it = ordered_dyn_endo.begin(); - it != ordered_dyn_endo.end(); it++) - dyn_jacobian_cols_table[it->second] = sorted_id++; - -} - -int -StaticDllModel::getDynJacobianCol(int deriv_id) const throw (UnknownDerivIDException) -{ - map::const_iterator it = dyn_jacobian_cols_table.find(deriv_id); - if (it == dyn_jacobian_cols_table.end()) - throw UnknownDerivIDException(); - else - return it->second; -} - - void StaticDllModel::computeChainRuleJacobian(Model_Block *ModelBlock) { @@ -1110,4 +1019,12 @@ StaticDllModel::hessianHelper(ostream &output, int row_nb, int col_nb, ExprNodeO output << RIGHT_ARRAY_SUBSCRIPT(output_type); } - +void +StaticDllModel::writeAuxVarInitval(ostream &output) const +{ + for(int i = 0; i < (int) aux_equations.size(); i++) + { + dynamic_cast(aux_equations[i])->writeOutput(output); + output << ";" << endl; + } +} diff --git a/StaticDllModel.hh b/StaticDllModel.hh index 26f71250..8bfc4330 100644 --- a/StaticDllModel.hh +++ b/StaticDllModel.hh @@ -37,27 +37,6 @@ private: //! Maps a deriv ID to a pair (symbol_id, lag) vector > inv_deriv_id_table; - //! Maps a deriv_id to the column index of the static Jacobian - /*! Contains only endogenous, exogenous and exogenous deterministic */ - map dyn_jacobian_cols_table; - - //! Maximum lag and lead over all types of variables (positive values) - /*! Set by computeDerivID() */ - int max_lag, max_lead; - //! Maximum lag and lead over endogenous variables (positive values) - /*! Set by computeDerivID() */ - int max_endo_lag, max_endo_lead; - //! Maximum lag and lead over exogenous variables (positive values) - /*! Set by computeDerivID() */ - int max_exo_lag, max_exo_lead; - //! Maximum lag and lead over deterministic exogenous variables (positive values) - /*! Set by computeDerivID() */ - int max_exo_det_lag, max_exo_det_lead; - - //! Number of columns of static jacobian - /*! Set by computeDerivID() and computeDynJacobianCols() */ - int dynJacobianColsNbr; - //! Temporary terms for the file containing parameters dervicatives temporary_terms_type params_derivs_temporary_terms; @@ -91,7 +70,6 @@ private: //! Write chain rule derivative code of an equation w.r. to a variable void compileChainRuleDerivative(ofstream &code_file, int eq, int var, int lag, map_idx_type &map_idx) const; - virtual int computeDerivID(int symb_id, int lag); //! Get the type corresponding to a derivation ID SymbolType getTypeByDerivID(int deriv_id) const throw (UnknownDerivIDException); //! Get the lag corresponding to a derivation ID @@ -150,8 +128,10 @@ public: //! Writes LaTeX file with the equations of the static model void writeLatexFile(const string &basename) const; + //! Writes initializations in oo_.steady_state for the auxiliary variables + void writeAuxVarInitval(ostream &output) const; + virtual int getDerivID(int symb_id, int lag) const throw (UnknownDerivIDException); - virtual int getDynJacobianCol(int deriv_id) const throw (UnknownDerivIDException); }; #endif diff --git a/StaticModel.cc b/StaticModel.cc index bda4f1ec..d848b739 100644 --- a/StaticModel.cc +++ b/StaticModel.cc @@ -210,22 +210,13 @@ StaticModel::computingPass(bool block, bool hessian, bool no_tmp_terms) computeTemporaryTerms(true); } -int -StaticModel::computeDerivID(int symb_id, int lag) -{ - if (symbol_table.getType(symb_id) == eEndogenous) - return symb_id; - else - return -1; -} - int StaticModel::getDerivID(int symb_id, int lag) const throw (UnknownDerivIDException) { if (symbol_table.getType(symb_id) == eEndogenous) return symb_id; else - throw UnknownDerivIDException(); + return -1; } void @@ -663,3 +654,13 @@ StaticModel::writeStaticBlockMFSFile(ostream &output, const string &func_name) c output << " end" << endl << "end" << endl; } + +void +StaticModel::writeAuxVarInitval(ostream &output) const +{ + for(int i = 0; i < (int) aux_equations.size(); i++) + { + dynamic_cast(aux_equations[i])->writeOutput(output); + output << ";" << endl; + } +} diff --git a/StaticModel.hh b/StaticModel.hh index ddc5f0ea..a55e8cad 100644 --- a/StaticModel.hh +++ b/StaticModel.hh @@ -53,8 +53,6 @@ private: //! Writes static model file (block+MFS version) void writeStaticBlockMFSFile(ostream &output, const string &func_name) const; - virtual int computeDerivID(int symb_id, int lag); - //! Computes normalization of the static model void computeNormalization(); @@ -103,6 +101,9 @@ public: //! Writes LaTeX file with the equations of the static model void writeLatexFile(const string &basename) const; + //! Writes initializations in oo_.steady_state for the auxiliary variables + void writeAuxVarInitval(ostream &output) const; + virtual int getDerivID(int symb_id, int lag) const throw (UnknownDerivIDException); }; diff --git a/SymbolTable.cc b/SymbolTable.cc index e73b3771..d82dc9ff 100644 --- a/SymbolTable.cc +++ b/SymbolTable.cc @@ -19,6 +19,7 @@ #include #include +#include #include "SymbolTable.hh" @@ -26,7 +27,7 @@ SymbolTable::SymbolTable() : frozen(false), size(0) { } -void +int SymbolTable::addSymbol(const string &name, SymbolType type, const string &tex_name) throw (AlreadyDeclaredException, FrozenException) { if (frozen) @@ -46,9 +47,11 @@ SymbolTable::addSymbol(const string &name, SymbolType type, const string &tex_na type_table.push_back(type); name_table.push_back(name); tex_name_table.push_back(tex_name); + + return id; } -void +int SymbolTable::addSymbol(const string &name, SymbolType type) throw (AlreadyDeclaredException, FrozenException) { // Construct "tex_name" by prepending an antislash to all underscores in "name" @@ -59,7 +62,7 @@ SymbolTable::addSymbol(const string &name, SymbolType type) throw (AlreadyDeclar tex_name.insert(pos, "\\"); pos += 2; } - addSymbol(name, type, tex_name); + return addSymbol(name, type, tex_name); } void @@ -197,4 +200,70 @@ SymbolTable::writeOutput(ostream &output) const throw (NotYetFrozenException) << "M_.param_nbr = " << param_nbr() << ";" << endl; output << "M_.Sigma_e = zeros(" << exo_nbr() << ", " << exo_nbr() << ");" << endl; + + // Write the auxiliary variable table + for(int i = 0; i < (int) aux_vars.size(); i++) + { + output << "M_.aux_vars(" << i+1 << ").endo_index = " << getTypeSpecificID(aux_vars[i].symb_id)+1 << ";" << endl + << "M_.aux_vars(" << i+1 << ").type = " << aux_vars[i].type << ";" << endl; + switch(aux_vars[i].type) + { + case avLead: + break; + case avLag: + output << "M_.aux_vars(" << i+1 << ").orig_endo_index = " << getTypeSpecificID(aux_vars[i].orig_symb_id)+1 << ";" << endl + << "M_.aux_vars(" << i+1 << ").orig_lag = " << aux_vars[i].orig_lag << ";" << endl; + break; + } + } +} + +int +SymbolTable::addLeadAuxiliaryVar(int index) throw (FrozenException) +{ + ostringstream varname; + varname << "AUXLEAD_" << index; + int symb_id; + try + { + symb_id = addSymbol(varname.str(), eEndogenous); + } + catch(AlreadyDeclaredException &e) + { + cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl; + exit(EXIT_FAILURE); + } + + AuxVarInfo avi; + avi.symb_id = symb_id; + avi.type = avLead; + aux_vars.push_back(avi); + + return symb_id; +} + +int +SymbolTable::addLagAuxiliaryVar(int orig_symb_id, int orig_lag) throw (FrozenException) +{ + ostringstream varname; + varname << "AUXLAG_" << orig_symb_id << "_" << -orig_lag; + int symb_id; + try + { + symb_id = addSymbol(varname.str(), eEndogenous); + } + catch(AlreadyDeclaredException &e) + { + cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl; + exit(EXIT_FAILURE); + } + + AuxVarInfo avi; + avi.symb_id = symb_id; + avi.type = avLag; + avi.orig_symb_id = orig_symb_id; + avi.orig_lag = orig_lag; + aux_vars.push_back(avi); + + return symb_id; } diff --git a/SymbolTable.hh b/SymbolTable.hh index 11a8ba22..baed9a62 100644 --- a/SymbolTable.hh +++ b/SymbolTable.hh @@ -29,6 +29,22 @@ using namespace std; #include "CodeInterpreter.hh" +//! Types of auxiliary variables +enum aux_var_t + { + avLead = 0, //!< Substitute for leads >= 2 + avLag = 1 //!< Substitute for lags >= 2 + }; + +//! Information on some auxiliary variables +struct AuxVarInfo +{ + int symb_id; //!< Symbol ID of the auxiliary variable + aux_var_t type; //!< Its type + int orig_symb_id; //!< Symbol ID of the endo of the original model represented by this aux var. Only for avLag + int orig_lag; //!< Lag of the endo of the original model represented by this aux var. Only for avLag +}; + //! Stores the symbol table /*! A symbol is given by its name, and is internally represented by a unique integer. @@ -70,6 +86,8 @@ private: vector exo_det_ids; //! Maps type specific IDs of parameters to symbol IDs vector param_ids; + //! Information about auxiliary variables + vector aux_vars; public: SymbolTable(); //! Thrown when trying to access an unknown symbol (by name) @@ -115,9 +133,16 @@ public: { }; //! Add a symbol - void addSymbol(const string &name, SymbolType type, const string &tex_name) throw (AlreadyDeclaredException, FrozenException); + /*! Returns the symbol ID */ + int addSymbol(const string &name, SymbolType type, const string &tex_name) throw (AlreadyDeclaredException, FrozenException); //! Add a symbol without its TeX name (will be equal to its name) - void addSymbol(const string &name, SymbolType type) throw (AlreadyDeclaredException, FrozenException); + /*! Returns the symbol ID */ + int addSymbol(const string &name, SymbolType type) throw (AlreadyDeclaredException, FrozenException); + //! Adds an auxiliary variable for leads >=2 + /*! Uses the given argument to construct the variable name. + Will exit the preprocessor with an error message if the variable name already declared by the user. Returns the symbol ID. */ + int addLeadAuxiliaryVar(int index) throw (FrozenException); + int addLagAuxiliaryVar(int orig_symb_id, int orig_lag) throw (FrozenException); //! Tests if symbol already exists inline bool exists(const string &name) const; //! Get symbol name (by ID)