From 19ebd12a5dd0d47bab7a573adb47b75ed8c17c0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= Date: Thu, 19 Aug 2010 15:20:54 +0200 Subject: [PATCH] Preprocessor: when removing lags greater than 2 on endogenous or lags on exogenous, don't substitute model local variables that do not need to --- ExprNode.cc | 110 +++++++++++++++++++++++++++++++++++++++++++++++++++- ExprNode.hh | 21 ++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/ExprNode.cc b/ExprNode.cc index 883e3451..e038809a 100644 --- a/ExprNode.cc +++ b/ExprNode.cc @@ -341,6 +341,18 @@ NumConstNode::maxExoLead() const return 0; } +int +NumConstNode::maxEndoLag() const +{ + return 0; +} + +int +NumConstNode::maxExoLag() const +{ + return 0; +} + NodeID NumConstNode::decreaseLeadsLags(int n) const { @@ -871,6 +883,34 @@ VariableNode::maxExoLead() const } } +int +VariableNode::maxEndoLag() const +{ + switch (type) + { + case eEndogenous: + return max(-lag, 0); + case eModelLocalVariable: + return datatree.local_variables_table[symb_id]->maxEndoLag(); + default: + return 0; + } +} + +int +VariableNode::maxExoLag() const +{ + switch (type) + { + case eExogenous: + return max(-lag, 0); + case eModelLocalVariable: + return datatree.local_variables_table[symb_id]->maxExoLag(); + default: + return 0; + } +} + NodeID VariableNode::decreaseLeadsLags(int n) const { @@ -922,6 +962,7 @@ NodeID VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const { VariableNode *substexpr; + NodeID value; subst_table_t::const_iterator it; int cur_lag; switch (type) @@ -958,7 +999,11 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector return substexpr; case eModelLocalVariable: - return datatree.local_variables_table[symb_id]->substituteEndoLagGreaterThanTwo(subst_table, neweqs); + value = datatree.local_variables_table[symb_id]; + if (value->maxEndoLag() <= 1) + return const_cast(this); + else + return value->substituteEndoLagGreaterThanTwo(subst_table, neweqs); default: return const_cast(this); } @@ -990,6 +1035,7 @@ NodeID VariableNode::substituteExoLag(subst_table_t &subst_table, vector &neweqs) const { VariableNode *substexpr; + NodeID value; subst_table_t::const_iterator it; int cur_lag; switch (type) @@ -1026,7 +1072,11 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vectorsubstituteExoLag(subst_table, neweqs); + value = datatree.local_variables_table[symb_id]; + if (value->maxExoLag() == 0) + return const_cast(this); + else + return value->substituteExoLag(subst_table, neweqs); default: return const_cast(this); } @@ -1741,6 +1791,18 @@ UnaryOpNode::maxExoLead() const return arg->maxExoLead(); } +int +UnaryOpNode::maxEndoLag() const +{ + return arg->maxEndoLag(); +} + +int +UnaryOpNode::maxExoLag() const +{ + return arg->maxExoLag(); +} + NodeID UnaryOpNode::decreaseLeadsLags(int n) const { @@ -2767,6 +2829,18 @@ BinaryOpNode::maxExoLead() const return max(arg1->maxExoLead(), arg2->maxExoLead()); } +int +BinaryOpNode::maxEndoLag() const +{ + return max(arg1->maxEndoLag(), arg2->maxEndoLag()); +} + +int +BinaryOpNode::maxExoLag() const +{ + return max(arg1->maxExoLag(), arg2->maxExoLag()); +} + NodeID BinaryOpNode::decreaseLeadsLags(int n) const { @@ -3324,6 +3398,18 @@ TrinaryOpNode::maxExoLead() const return max(arg1->maxExoLead(), max(arg2->maxExoLead(), arg3->maxExoLead())); } +int +TrinaryOpNode::maxEndoLag() const +{ + return max(arg1->maxEndoLag(), max(arg2->maxEndoLag(), arg3->maxEndoLag())); +} + +int +TrinaryOpNode::maxExoLag() const +{ + return max(arg1->maxExoLag(), max(arg2->maxExoLag(), arg3->maxExoLag())); +} + NodeID TrinaryOpNode::decreaseLeadsLags(int n) const { @@ -3638,6 +3724,26 @@ ExternalFunctionNode::maxExoLead() const return val; } +int +ExternalFunctionNode::maxEndoLag() const +{ + int val = 0; + for (vector::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + val = max(val, (*it)->maxEndoLag()); + return val; +} + +int +ExternalFunctionNode::maxExoLag() const +{ + int val = 0; + for (vector::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + val = max(val, (*it)->maxExoLag()); + return val; +} + NodeID ExternalFunctionNode::decreaseLeadsLags(int n) const { diff --git a/ExprNode.hh b/ExprNode.hh index ecc4cbf9..179dd12b 100644 --- a/ExprNode.hh +++ b/ExprNode.hh @@ -261,6 +261,14 @@ public: /*! Always returns a non-negative value */ virtual int maxExoLead() const = 0; + //! Returns the maximum lag of endogenous in this expression + /*! Always returns a non-negative value */ + virtual int maxEndoLag() const = 0; + + //! Returns the maximum lag of exogenous in this expression + /*! Always returns a non-negative value */ + virtual int maxExoLag() const = 0; + //! Returns a new expression where all the leads/lags have been shifted backwards by the same amount /*! Only acts on endogenous, exogenous, exogenous det @@ -387,6 +395,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; @@ -405,6 +415,7 @@ private: //! Id from the symbol table const int symb_id; const SymbolType type; + //! A positive value is a lead, a negative is a lag const int lag; virtual NodeID computeDerivative(int deriv_id); public: @@ -432,6 +443,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; @@ -493,6 +506,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; //! Creates another UnaryOpNode with the same opcode, but with a possibly different datatree and argument @@ -561,6 +576,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; //! Creates another BinaryOpNode with the same opcode, but with a possibly different datatree and arguments @@ -611,6 +628,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; //! Creates another TrinaryOpNode with the same opcode, but with a possibly different datatree and arguments @@ -667,6 +686,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const;