From 1b952a12e68f6b8946b34d580b9550d616edf72e Mon Sep 17 00:00:00 2001 From: Houtan Bastani Date: Thu, 31 May 2018 15:34:25 +0200 Subject: [PATCH] fix bug in var max lag and simplify code --- src/DynamicModel.cc | 20 ++++---- src/ExprNode.cc | 109 ++++++++++++++++++++++---------------------- src/ExprNode.hh | 20 ++++---- 3 files changed, 72 insertions(+), 77 deletions(-) diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 73d34bac..5ee7bccf 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3355,23 +3355,19 @@ DynamicModel::checkVarMinLag(vector &eqnumber) const int DynamicModel::getVarMaxLag(StaticModel &static_model, vector &eqnumber) const { - vector lhs; + set lhs; for (vector::const_iterator it = eqnumber.begin(); it != eqnumber.end(); it++) + equations[*it]->get_arg1()->collectVARLHSVariable(lhs); + + if (eqnumber.size() != lhs.size()) { - set lhs_set; - equations[*it]->get_arg1()->collectVARLHSVariable(lhs_set); - if (lhs_set.size() != 1) - { - cerr << "ERROR: in Equation " - << ". A VAR may only have one endogenous variable on the LHS. " << endl; - exit(EXIT_FAILURE); - } - lhs.push_back(*(lhs_set.begin())); + cerr << "The LHS variables of the VAR are not unique" << endl; + exit(EXIT_FAILURE); } set lhs_static; - for(vector::const_iterator it = lhs.begin(); + for(set::const_iterator it = lhs.begin(); it != lhs.end(); it++) lhs_static.insert((*it)->toStatic(static_model)); @@ -3390,7 +3386,7 @@ DynamicModel::getVarLhsDiffAndInfo(vector &eqnumber, vector &diff, for (vector::const_iterator it = eqnumber.begin(); it != eqnumber.end(); it++) { - diff.push_back(equations[*it]->get_arg1()->isDiffPresent()); + equations[*it]->get_arg1()->countDiffs() > 0 ? diff.push_back(true) : diff.push_back(false); if (diff.back()) { set > diff_set; diff --git a/src/ExprNode.cc b/src/ExprNode.cc index c38154e0..a0a62ca4 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -294,12 +294,6 @@ ExprNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_ar return false; } -bool -ExprNode::isDiffPresent() const -{ - return false; -} - void ExprNode::getEndosAndMaxLags(map &model_endos_and_lags) const { @@ -313,10 +307,10 @@ NumConstNode::NumConstNode(DataTree &datatree_arg, int id_arg) : datatree.num_const_node_map[id] = this; } -bool -NumConstNode::isDiffPresent() const +int +NumConstNode::countDiffs() const { - return false; + return 0; } void @@ -389,6 +383,8 @@ NumConstNode::compile(ostream &CompileCode, unsigned int &instruction_number, void NumConstNode::collectVARLHSVariable(set &result) const { + cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl; + exit(EXIT_FAILURE); } void @@ -1144,7 +1140,7 @@ VariableNode::collectVARLHSVariable(set &result) const result.insert(const_cast(this)); else { - cerr << "ERROR: A VAR must have one endogenous variable on the LHS." << endl; + cerr << "ERROR: you can only have endogenous variables or unary ops on LHS of VAR" << endl; exit(EXIT_FAILURE); } } @@ -1730,10 +1726,10 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const } } -bool -VariableNode::isDiffPresent() const +int +VariableNode::countDiffs() const { - return false; + return 0; } expr_t @@ -2965,19 +2961,20 @@ UnaryOpNode::VarMaxLag(DataTree &static_datatree, set &static_lhs, int & arg->VarMaxLag(static_datatree, static_lhs, max_lag); else { - for (set::const_iterator it = static_lhs.begin(); - it != static_lhs.end(); it++) - if (*it == this->toStatic(static_datatree)) - { - int max_lag_tmp = arg->maxLag(); - if (max_lag_tmp > max_lag) - max_lag = max_lag_tmp; - return; - } - int max_lag_tmp = 0; - arg->VarMaxLag(static_datatree, static_lhs, max_lag_tmp); - if (max_lag_tmp + 1 > max_lag) - max_lag = max_lag_tmp + 1; + set::const_iterator it = static_lhs.find(this->toStatic(static_datatree)); + if (it != static_lhs.end()) + { + int max_lag_tmp = arg->maxLag() - arg->countDiffs(); + if (max_lag_tmp > max_lag) + max_lag = max_lag_tmp; + } + else + { + int max_lag_tmp = 0; + arg->VarMaxLag(static_datatree, static_lhs, max_lag_tmp); + if (max_lag_tmp + 1 > max_lag) + max_lag = max_lag_tmp + 1; + } } } @@ -3027,12 +3024,12 @@ UnaryOpNode::substituteAdl() const return retval; } -bool -UnaryOpNode::isDiffPresent() const +int +UnaryOpNode::countDiffs() const { if (op_code == oDiff) - return true; - return arg->isDiffPresent(); + return arg->countDiffs() + 1; + return arg->countDiffs(); } bool @@ -4414,8 +4411,8 @@ BinaryOpNode::VarMaxLag(DataTree &static_datatree, set &static_lhs, int void BinaryOpNode::collectVARLHSVariable(set &result) const { - arg1->collectVARLHSVariable(result); - arg2->collectVARLHSVariable(result); + cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl; + exit(EXIT_FAILURE); } void @@ -5011,10 +5008,10 @@ BinaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &no return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree); } -bool -BinaryOpNode::isDiffPresent() const +int +BinaryOpNode::countDiffs() const { - return arg1->isDiffPresent() || arg2->isDiffPresent(); + return arg1->countDiffs() + arg2->countDiffs(); } expr_t @@ -5657,9 +5654,8 @@ TrinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int void TrinaryOpNode::collectVARLHSVariable(set &result) const { - arg1->collectVARLHSVariable(result); - arg2->collectVARLHSVariable(result); - arg3->collectVARLHSVariable(result); + cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl; + exit(EXIT_FAILURE); } void @@ -5923,10 +5919,10 @@ TrinaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &n return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree); } -bool -TrinaryOpNode::isDiffPresent() const +int +TrinaryOpNode::countDiffs() const { - return arg1->isDiffPresent() || arg2->isDiffPresent() || arg3->isDiffPresent(); + return arg1->countDiffs() + arg2->countDiffs() + arg3->countDiffs(); } expr_t @@ -6127,9 +6123,8 @@ AbstractExternalFunctionNode::compileExternalFunctionArguments(ostream &CompileC void AbstractExternalFunctionNode::collectVARLHSVariable(set &result) const { - for (vector::const_iterator it = arguments.begin(); - it != arguments.end(); it++) - (*it)->collectVARLHSVariable(result); + cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl; + exit(EXIT_FAILURE); } void @@ -6362,13 +6357,13 @@ AbstractExternalFunctionNode::substituteUnaryOpNodes(DataTree &static_datatree, return buildSimilarExternalFunctionNode(arguments_subst, datatree); } -bool -AbstractExternalFunctionNode::isDiffPresent() const +int +AbstractExternalFunctionNode::countDiffs() const { - bool result = false; + int ndiffs = 0; for (vector::const_iterator it = arguments.begin(); it != arguments.end(); it++) - result = result || (*it)->isDiffPresent(); - return result; + ndiffs += (*it)->countDiffs(); + return ndiffs; } expr_t @@ -7829,10 +7824,10 @@ VarExpectationNode::eval(const eval_context_t &eval_context) const throw (EvalEx return it->second; } -bool -VarExpectationNode::isDiffPresent() const +int +VarExpectationNode::countDiffs() const { - return false; + return 0; } void @@ -7843,6 +7838,8 @@ VarExpectationNode::computeXrefs(EquationInfo &ei) const void VarExpectationNode::collectVARLHSVariable(set &result) const { + cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl; + exit(EXIT_FAILURE); } void @@ -8281,6 +8278,8 @@ PacExpectationNode::computeXrefs(EquationInfo &ei) const void PacExpectationNode::collectVARLHSVariable(set &result) const { + cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl; + exit(EXIT_FAILURE); } void @@ -8306,10 +8305,10 @@ PacExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_numb exit(EXIT_FAILURE); } -bool -PacExpectationNode::isDiffPresent() const +int +PacExpectationNode::countDiffs() const { - return false; + return 0; } pair diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 221b85aa..428ecc6c 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -465,8 +465,8 @@ class ExprNode //! Returns true if the expression contains one or several exogenous variable virtual bool containsExogenous() const = 0; - //! Returns true if the expression contains a diff operator - virtual bool isDiffPresent(void) const = 0; + //! Returns the number of diffs present + virtual int countDiffs() const = 0; //! Return true if the nodeID is a variable withe a type equal to type_arg, a specific variable id aqual to varfiable_id and a lag equal to lag_arg and false otherwise /*! @@ -595,7 +595,7 @@ public: virtual bool isNumConstNodeEqualTo(double value) const; virtual bool containsEndogenous(void) const; virtual bool containsExogenous() const; - virtual bool isDiffPresent(void) const; + virtual int countDiffs() const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual expr_t replaceTrendVar() const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; @@ -685,7 +685,7 @@ public: virtual bool isNumConstNodeEqualTo(double value) const; virtual bool containsEndogenous(void) const; virtual bool containsExogenous() const; - virtual bool isDiffPresent(void) const; + virtual int countDiffs() const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual expr_t replaceTrendVar() const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; @@ -799,7 +799,7 @@ public: virtual bool isNumConstNodeEqualTo(double value) const; virtual bool containsEndogenous(void) const; virtual bool containsExogenous() const; - virtual bool isDiffPresent(void) const; + virtual int countDiffs() const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual expr_t replaceTrendVar() const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; @@ -928,7 +928,7 @@ public: virtual bool isNumConstNodeEqualTo(double value) const; virtual bool containsEndogenous(void) const; virtual bool containsExogenous() const; - virtual bool isDiffPresent(void) const; + virtual int countDiffs() const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual expr_t replaceTrendVar() const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; @@ -1033,7 +1033,7 @@ public: virtual bool isNumConstNodeEqualTo(double value) const; virtual bool containsEndogenous(void) const; virtual bool containsExogenous() const; - virtual bool isDiffPresent(void) const; + virtual int countDiffs() const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual expr_t replaceTrendVar() const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; @@ -1139,7 +1139,7 @@ public: virtual bool isNumConstNodeEqualTo(double value) const; virtual bool containsEndogenous(void) const; virtual bool containsExogenous() const; - virtual bool isDiffPresent(void) const; + virtual int countDiffs() const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, deriv_node_temp_terms_t &tef_terms, const string &ending) const; virtual expr_t replaceTrendVar() const; @@ -1338,7 +1338,7 @@ public: virtual void collectDynamicVariables(SymbolType type_arg, set > &result) const; virtual bool containsEndogenous(void) const; virtual bool containsExogenous() const; - virtual bool isDiffPresent(void) const; + virtual int countDiffs() const; virtual bool isNumConstNodeEqualTo(double value) const; virtual expr_t differentiateForwardVars(const vector &subset, subst_table_t &subst_table, vector &neweqs) const; virtual expr_t decreaseLeadsLagsPredeterminedVariables() const; @@ -1423,7 +1423,7 @@ public: virtual void collectDynamicVariables(SymbolType type_arg, set > &result) const; virtual bool containsEndogenous(void) const; virtual bool containsExogenous() const; - virtual bool isDiffPresent(void) const; + virtual int countDiffs() const; virtual bool isNumConstNodeEqualTo(double value) const; virtual expr_t differentiateForwardVars(const vector &subset, subst_table_t &subst_table, vector &neweqs) const; virtual expr_t decreaseLeadsLagsPredeterminedVariables() const;