diff --git a/src/DataTree.hh b/src/DataTree.hh index d84ca9df..623641c1 100644 --- a/src/DataTree.hh +++ b/src/DataTree.hh @@ -38,19 +38,7 @@ using namespace std; class DataTree { - friend class ExprNode; - friend class NumConstNode; - friend class VariableNode; - friend class UnaryOpNode; - friend class BinaryOpNode; - friend class TrinaryOpNode; - friend class AbstractExternalFunctionNode; - friend class ExternalFunctionNode; - friend class FirstDerivExternalFunctionNode; - friend class SecondDerivExternalFunctionNode; - friend class VarExpectationNode; - friend class PacExpectationNode; -protected: +public: //! A reference to the symbol table SymbolTable &symbol_table; //! Reference to numerical constants table @@ -62,6 +50,7 @@ protected: //! A reference to the VAR model table VarModelTable &var_model_table; +protected: //! num_constant_id -> NumConstNode using num_const_node_map_t = map; num_const_node_map_t num_const_node_map; @@ -126,7 +115,6 @@ private: //! The list of nodes vector> node_list; - inline expr_t AddPossiblyNegativeConstant(double val); inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0, const string &adl_param_name = "", const vector &adl_lags = vector()); inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0); inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3); @@ -157,6 +145,7 @@ public: { }; + inline expr_t AddPossiblyNegativeConstant(double val); //! Adds a non-negative numerical constant (possibly Inf or NaN) expr_t AddNonNegativeConstant(const string &value); //! Adds a variable @@ -316,6 +305,25 @@ public: { return false; }; + + class UnknownLocalVariableException + { + public: + //! Symbol ID + int id; + UnknownLocalVariableException(int id_arg) : id(id_arg) + { + } + }; + + expr_t getLocalVariable(int symb_id) const + { + auto it = local_variables_table.find(symb_id); + if (it == local_variables_table.end()) + throw UnknownLocalVariableException(symb_id); + + return it->second; + } }; inline expr_t diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 42611a1f..ea81f1d2 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -727,9 +727,9 @@ VariableNode::prepareForDerivation() non_null_derivatives.insert(datatree.getDerivID(symb_id, lag)); break; case SymbolType::modelLocalVariable: - datatree.local_variables_table[symb_id]->prepareForDerivation(); + datatree.getLocalVariable(symb_id)->prepareForDerivation(); // Non null derivatives are those of the value of the local parameter - non_null_derivatives = datatree.local_variables_table[symb_id]->non_null_derivatives; + non_null_derivatives = datatree.getLocalVariable(symb_id)->non_null_derivatives; break; case SymbolType::modFileLocalVariable: case SymbolType::statementDeclaredVariable: @@ -762,7 +762,7 @@ VariableNode::computeDerivative(int deriv_id) else return datatree.Zero; case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->getDerivative(deriv_id); + return datatree.getLocalVariable(symb_id)->getDerivative(deriv_id); case SymbolType::modFileLocalVariable: cerr << "ModFileLocalVariable is not derivable" << endl; exit(EXIT_FAILURE); @@ -791,7 +791,7 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t if (it != temporary_terms.end()) temporary_terms_inuse.insert(idx); if (type == SymbolType::modelLocalVariable) - datatree.local_variables_table[symb_id]->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block); + datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block); } bool @@ -866,7 +866,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type, || output_type == oCDynamicSteadyStateOperator) { output << "("; - datatree.local_variables_table[symb_id]->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); + datatree.getLocalVariable(symb_id)->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); output << ")"; } else @@ -1108,7 +1108,7 @@ VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number, const deriv_node_temp_terms_t &tef_terms) const { if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable) - datatree.local_variables_table[symb_id]->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms); + datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms); else { int tsid = datatree.symbol_table.getTypeSpecificID(symb_id); @@ -1184,7 +1184,7 @@ VariableNode::computeTemporaryTerms(map &reference_count, int equation) const { if (type == SymbolType::modelLocalVariable) - datatree.local_variables_table[symb_id]->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation); + datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation); } void @@ -1205,7 +1205,7 @@ VariableNode::collectDynamicVariables(SymbolType type_arg, set> & if (type == type_arg) result.emplace(symb_id, lag); if (type == SymbolType::modelLocalVariable) - datatree.local_variables_table[symb_id]->collectDynamicVariables(type_arg, result); + datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result); } pair @@ -1229,14 +1229,14 @@ VariableNode::normalizeEquation(int var_endo, vector &recur return datatree.Zero; } case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->getChainRuleDerivative(deriv_id, recursive_variables); + return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables); case SymbolType::modFileLocalVariable: cerr << "ModFileLocalVariable is not derivable" << endl; exit(EXIT_FAILURE); @@ -1351,7 +1351,7 @@ VariableNode::maxEndoLead() const case SymbolType::endogenous: return max(lag, 0); case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->maxEndoLead(); + return datatree.getLocalVariable(symb_id)->maxEndoLead(); default: return 0; } @@ -1365,7 +1365,7 @@ VariableNode::maxExoLead() const case SymbolType::exogenous: return max(lag, 0); case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->maxExoLead(); + return datatree.getLocalVariable(symb_id)->maxExoLead(); default: return 0; } @@ -1379,7 +1379,7 @@ VariableNode::maxEndoLag() const case SymbolType::endogenous: return max(-lag, 0); case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->maxEndoLag(); + return datatree.getLocalVariable(symb_id)->maxEndoLag(); default: return 0; } @@ -1393,7 +1393,7 @@ VariableNode::maxExoLag() const case SymbolType::exogenous: return max(-lag, 0); case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->maxExoLag(); + return datatree.getLocalVariable(symb_id)->maxExoLag(); default: return 0; } @@ -1409,7 +1409,7 @@ VariableNode::maxLead() const case SymbolType::exogenous: return lag; case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->maxLead(); + return datatree.getLocalVariable(symb_id)->maxLead(); default: return 0; } @@ -1428,7 +1428,7 @@ VariableNode::VarMinLag() const else return 1; // Can have contemporaneus exog in VAR case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->VarMinLag(); + return datatree.getLocalVariable(symb_id)->VarMinLag(); default: return 1; } @@ -1444,7 +1444,7 @@ VariableNode::maxLag() const case SymbolType::exogenous: return -lag; case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->maxLag(); + return datatree.getLocalVariable(symb_id)->maxLag(); default: return 0; } @@ -1532,7 +1532,7 @@ VariableNode::decreaseLeadsLags(int n) const case SymbolType::logTrend: return datatree.AddVariable(symb_id, lag-n); case SymbolType::modelLocalVariable: - return datatree.local_variables_table[symb_id]->decreaseLeadsLags(n); + return datatree.getLocalVariable(symb_id)->decreaseLeadsLags(n); default: return const_cast(this); } @@ -1559,7 +1559,7 @@ VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vecto else return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs); case SymbolType::modelLocalVariable: - value = datatree.local_variables_table[symb_id]; + value = datatree.getLocalVariable(symb_id); if (value->maxEndoLead() <= 1) return const_cast(this); else @@ -1610,7 +1610,7 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector return substexpr; case SymbolType::modelLocalVariable: - value = datatree.local_variables_table[symb_id]; + value = datatree.getLocalVariable(symb_id); if (value->maxEndoLag() <= 1) return const_cast(this); else @@ -1632,7 +1632,7 @@ VariableNode::substituteExoLead(subst_table_t &subst_table, vectormaxExoLead() == 0) return const_cast(this); else @@ -1683,7 +1683,7 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vectormaxExoLag() == 0) return const_cast(this); else @@ -1729,7 +1729,7 @@ VariableNode::differentiateForwardVars(const vector &subset, subst_table return datatree.AddPlus(datatree.AddVariable(symb_id, 0), diffvar); } case SymbolType::modelLocalVariable: - value = datatree.local_variables_table[symb_id]; + value = datatree.getLocalVariable(symb_id); if (value->maxEndoLead() <= 0) return const_cast(this); else