Preprocessor: when removing lags greater than 2 on endogenous or lags on exogenous, don't substitute model local variables that do not need to

issue#70
Sébastien Villemot 2010-08-19 15:20:54 +02:00
parent e44c41334e
commit 19ebd12a5d
2 changed files with 129 additions and 2 deletions

View File

@ -341,6 +341,18 @@ NumConstNode::maxExoLead() const
return 0; return 0;
} }
int
NumConstNode::maxEndoLag() const
{
return 0;
}
int
NumConstNode::maxExoLag() const
{
return 0;
}
NodeID NodeID
NumConstNode::decreaseLeadsLags(int n) const NumConstNode::decreaseLeadsLags(int n) const
{ {
@ -871,6 +883,34 @@ VariableNode::maxExoLead() const
} }
} }
int
VariableNode::maxEndoLag() const
{
switch (type)
{
case eEndogenous:
return max(-lag, 0);
case eModelLocalVariable:
return datatree.local_variables_table[symb_id]->maxEndoLag();
default:
return 0;
}
}
int
VariableNode::maxExoLag() const
{
switch (type)
{
case eExogenous:
return max(-lag, 0);
case eModelLocalVariable:
return datatree.local_variables_table[symb_id]->maxExoLag();
default:
return 0;
}
}
NodeID NodeID
VariableNode::decreaseLeadsLags(int n) const VariableNode::decreaseLeadsLags(int n) const
{ {
@ -922,6 +962,7 @@ NodeID
VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{ {
VariableNode *substexpr; VariableNode *substexpr;
NodeID value;
subst_table_t::const_iterator it; subst_table_t::const_iterator it;
int cur_lag; int cur_lag;
switch (type) switch (type)
@ -958,7 +999,11 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector
return substexpr; return substexpr;
case eModelLocalVariable: case eModelLocalVariable:
return datatree.local_variables_table[symb_id]->substituteEndoLagGreaterThanTwo(subst_table, neweqs); value = datatree.local_variables_table[symb_id];
if (value->maxEndoLag() <= 1)
return const_cast<VariableNode *>(this);
else
return value->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
default: default:
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
@ -990,6 +1035,7 @@ NodeID
VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{ {
VariableNode *substexpr; VariableNode *substexpr;
NodeID value;
subst_table_t::const_iterator it; subst_table_t::const_iterator it;
int cur_lag; int cur_lag;
switch (type) switch (type)
@ -1026,7 +1072,11 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
return substexpr; return substexpr;
case eModelLocalVariable: case eModelLocalVariable:
return datatree.local_variables_table[symb_id]->substituteExoLag(subst_table, neweqs); value = datatree.local_variables_table[symb_id];
if (value->maxExoLag() == 0)
return const_cast<VariableNode *>(this);
else
return value->substituteExoLag(subst_table, neweqs);
default: default:
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
@ -1741,6 +1791,18 @@ UnaryOpNode::maxExoLead() const
return arg->maxExoLead(); return arg->maxExoLead();
} }
int
UnaryOpNode::maxEndoLag() const
{
return arg->maxEndoLag();
}
int
UnaryOpNode::maxExoLag() const
{
return arg->maxExoLag();
}
NodeID NodeID
UnaryOpNode::decreaseLeadsLags(int n) const UnaryOpNode::decreaseLeadsLags(int n) const
{ {
@ -2767,6 +2829,18 @@ BinaryOpNode::maxExoLead() const
return max(arg1->maxExoLead(), arg2->maxExoLead()); return max(arg1->maxExoLead(), arg2->maxExoLead());
} }
int
BinaryOpNode::maxEndoLag() const
{
return max(arg1->maxEndoLag(), arg2->maxEndoLag());
}
int
BinaryOpNode::maxExoLag() const
{
return max(arg1->maxExoLag(), arg2->maxExoLag());
}
NodeID NodeID
BinaryOpNode::decreaseLeadsLags(int n) const BinaryOpNode::decreaseLeadsLags(int n) const
{ {
@ -3324,6 +3398,18 @@ TrinaryOpNode::maxExoLead() const
return max(arg1->maxExoLead(), max(arg2->maxExoLead(), arg3->maxExoLead())); return max(arg1->maxExoLead(), max(arg2->maxExoLead(), arg3->maxExoLead()));
} }
int
TrinaryOpNode::maxEndoLag() const
{
return max(arg1->maxEndoLag(), max(arg2->maxEndoLag(), arg3->maxEndoLag()));
}
int
TrinaryOpNode::maxExoLag() const
{
return max(arg1->maxExoLag(), max(arg2->maxExoLag(), arg3->maxExoLag()));
}
NodeID NodeID
TrinaryOpNode::decreaseLeadsLags(int n) const TrinaryOpNode::decreaseLeadsLags(int n) const
{ {
@ -3638,6 +3724,26 @@ ExternalFunctionNode::maxExoLead() const
return val; return val;
} }
int
ExternalFunctionNode::maxEndoLag() const
{
int val = 0;
for (vector<NodeID>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
val = max(val, (*it)->maxEndoLag());
return val;
}
int
ExternalFunctionNode::maxExoLag() const
{
int val = 0;
for (vector<NodeID>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
val = max(val, (*it)->maxExoLag());
return val;
}
NodeID NodeID
ExternalFunctionNode::decreaseLeadsLags(int n) const ExternalFunctionNode::decreaseLeadsLags(int n) const
{ {

View File

@ -261,6 +261,14 @@ public:
/*! Always returns a non-negative value */ /*! Always returns a non-negative value */
virtual int maxExoLead() const = 0; virtual int maxExoLead() const = 0;
//! Returns the maximum lag of endogenous in this expression
/*! Always returns a non-negative value */
virtual int maxEndoLag() const = 0;
//! Returns the maximum lag of exogenous in this expression
/*! Always returns a non-negative value */
virtual int maxExoLag() const = 0;
//! Returns a new expression where all the leads/lags have been shifted backwards by the same amount //! Returns a new expression where all the leads/lags have been shifted backwards by the same amount
/*! /*!
Only acts on endogenous, exogenous, exogenous det Only acts on endogenous, exogenous, exogenous det
@ -387,6 +395,8 @@ public:
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
virtual int maxEndoLead() const; virtual int maxEndoLead() const;
virtual int maxExoLead() const; virtual int maxExoLead() const;
virtual int maxEndoLag() const;
virtual int maxExoLag() const;
virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID decreaseLeadsLags(int n) const;
virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
@ -405,6 +415,7 @@ private:
//! Id from the symbol table //! Id from the symbol table
const int symb_id; const int symb_id;
const SymbolType type; const SymbolType type;
//! A positive value is a lead, a negative is a lag
const int lag; const int lag;
virtual NodeID computeDerivative(int deriv_id); virtual NodeID computeDerivative(int deriv_id);
public: public:
@ -432,6 +443,8 @@ public:
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
virtual int maxEndoLead() const; virtual int maxEndoLead() const;
virtual int maxExoLead() const; virtual int maxExoLead() const;
virtual int maxEndoLag() const;
virtual int maxExoLag() const;
virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID decreaseLeadsLags(int n) const;
virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
@ -493,6 +506,8 @@ public:
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
virtual int maxEndoLead() const; virtual int maxEndoLead() const;
virtual int maxExoLead() const; virtual int maxExoLead() const;
virtual int maxEndoLag() const;
virtual int maxExoLag() const;
virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID decreaseLeadsLags(int n) const;
virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
//! Creates another UnaryOpNode with the same opcode, but with a possibly different datatree and argument //! Creates another UnaryOpNode with the same opcode, but with a possibly different datatree and argument
@ -561,6 +576,8 @@ public:
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
virtual int maxEndoLead() const; virtual int maxEndoLead() const;
virtual int maxExoLead() const; virtual int maxExoLead() const;
virtual int maxEndoLag() const;
virtual int maxExoLag() const;
virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID decreaseLeadsLags(int n) const;
virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
//! Creates another BinaryOpNode with the same opcode, but with a possibly different datatree and arguments //! Creates another BinaryOpNode with the same opcode, but with a possibly different datatree and arguments
@ -611,6 +628,8 @@ public:
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
virtual int maxEndoLead() const; virtual int maxEndoLead() const;
virtual int maxExoLead() const; virtual int maxExoLead() const;
virtual int maxEndoLag() const;
virtual int maxExoLag() const;
virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID decreaseLeadsLags(int n) const;
virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
//! Creates another TrinaryOpNode with the same opcode, but with a possibly different datatree and arguments //! Creates another TrinaryOpNode with the same opcode, but with a possibly different datatree and arguments
@ -667,6 +686,8 @@ public:
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
virtual int maxEndoLead() const; virtual int maxEndoLead() const;
virtual int maxExoLead() const; virtual int maxExoLead() const;
virtual int maxEndoLag() const;
virtual int maxExoLag() const;
virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID decreaseLeadsLags(int n) const;
virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;