Preprocessor: changes to code for chain rule derivation

* fixed a bug in the handling of VariableNode: we now make a copy of the recursive_variables map, instead of modifying that of the caller
* factorized code shared with standard derivation
* various minor cleanups


git-svn-id: https://www.dynare.org/svn/dynare/trunk@2811 ac1d8469-bf42-47a9-8791-bf33cf982152
time-shift
sebastien 2009-07-06 09:34:21 +00:00
parent 05189497a5
commit 1cde972d91
3 changed files with 68 additions and 267 deletions

View File

@ -664,9 +664,7 @@ end:
//cout << "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"; //cout << "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n";
//cout << "derivaive eq=" << eq << " var=" << var << " k0=" << k << "\n"; //cout << "derivaive eq=" << eq << " var=" << var << " k0=" << k << "\n";
int deriv_id = getDerivID(symbol_table.getID(eEndogenous, var),0); int deriv_id = getDerivID(symbol_table.getID(eEndogenous, var),0);
map<int, NodeID> recursive_variables_save(recursive_variables); NodeID ChaineRule_Derivative = tmp_n->getChainRuleDerivative(deriv_id, recursive_variables);
NodeID ChaineRule_Derivative = tmp_n->getChaineRuleDerivative(deriv_id ,recursive_variables, var, 0);
recursive_variables = recursive_variables_save;
ChaineRule_Derivative->writeOutput(output, oMatlabDynamicModelSparse, temporary_terms); ChaineRule_Derivative->writeOutput(output, oMatlabDynamicModelSparse, temporary_terms);
output << ";"; output << ";";
output << " %2 variable=" << symbol_table.getName(symbol_table.getID(eEndogenous, var)) output << " %2 variable=" << symbol_table.getName(symbol_table.getID(eEndogenous, var))

View File

@ -60,32 +60,6 @@ ExprNode::getDerivative(int deriv_id)
} }
} }
NodeID
ExprNode::getChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
{
// Return zero if derivative is necessarily null (using symbolic a priori)
/*set<int>::const_iterator it = non_null_derivatives.find(deriv_id);
if (it == non_null_derivatives.end())
{
cout << "0\n";
return datatree.Zero;
}
*/
// If derivative is stored in cache, use the cached value, otherwise compute it (and cache it)
/*map<int, NodeID>::const_iterator it2 = derivatives.find(deriv_id);
if (it2 != derivatives.end())
return it2->second;
else*/
{
NodeID d = computeChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
//derivatives[deriv_id] = d;
return d;
}
}
int int
ExprNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const ExprNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
{ {
@ -213,20 +187,11 @@ NumConstNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) con
} }
NodeID NodeID
NumConstNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_) NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
{ {
return datatree.Zero; return datatree.Zero;
} }
/*
pair<bool, NodeID>
NumConstNode::computeDerivativeRespectToFeedbackVariable(int equ, int var, int varr, int lag_, int max_lag, vector<int> &recursive_variables, vector<int> &feeback_variables) const
{
return(make_pair(false, datatree.Zero));
}
*/
NodeID NodeID
NumConstNode::toStatic(DataTree &static_datatree) const NumConstNode::toStatic(DataTree &static_datatree) const
{ {
@ -606,7 +571,7 @@ pair<InputIterator, int> find_r ( InputIterator first, InputIterator last, const
NodeID NodeID
VariableNode::computeChaineRuleDerivative(int deriv_id_arg, map<int, NodeID> &recursive_variables, int var, int lag_) VariableNode::getChainRuleDerivative(int deriv_id_arg, const map<int, NodeID> &recursive_variables)
{ {
switch (type) switch (type)
{ {
@ -614,28 +579,23 @@ VariableNode::computeChaineRuleDerivative(int deriv_id_arg, map<int, NodeID> &re
case eExogenous: case eExogenous:
case eExogenousDet: case eExogenousDet:
case eParameter: case eParameter:
//cout << "deriv_id=" << deriv_id << " deriv_id_arg=" << deriv_id_arg << " symb_id=" << symb_id << " type=" << type << " lag=" << lag << " var=" << var << " lag_ = " << lag_ << "\n";
if (deriv_id == deriv_id_arg) if (deriv_id == deriv_id_arg)
return datatree.One; return datatree.One;
else else
{ {
//if there is in the equation a recursive variable we could use a chaine rule derivation //if there is in the equation a recursive variable we could use a chaine rule derivation
if(lag == lag_) map<int, NodeID>::const_iterator it = recursive_variables.find(deriv_id);
if (it != recursive_variables.end())
{ {
map<int, NodeID>::const_iterator it = recursive_variables.find(deriv_id); map<int, NodeID> recursive_vars2(recursive_variables);
if (it != recursive_variables.end()) recursive_vars2.erase(it->first);
{ return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id_arg, recursive_vars2));
recursive_variables.erase(it->first);
return datatree.AddUMinus(it->second->getChaineRuleDerivative(deriv_id_arg, recursive_variables, var, lag_));
}
else
return datatree.Zero;
} }
else else
return datatree.Zero; return datatree.Zero;
} }
case eModelLocalVariable: case eModelLocalVariable:
return datatree.local_variables_table[symb_id]->getChaineRuleDerivative(deriv_id_arg, recursive_variables, var, lag_); return datatree.local_variables_table[symb_id]->getChainRuleDerivative(deriv_id_arg, recursive_variables);
case eModFileLocalVariable: case eModFileLocalVariable:
cerr << "ModFileLocalVariable is not derivable" << endl; cerr << "ModFileLocalVariable is not derivable" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -669,10 +629,8 @@ UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const
} }
NodeID NodeID
UnaryOpNode::computeDerivative(int deriv_id) UnaryOpNode::composeDerivatives(NodeID darg)
{ {
NodeID darg = arg->getDerivative(deriv_id);
NodeID t11, t12, t13; NodeID t11, t12, t13;
switch (op_code) switch (op_code)
@ -738,6 +696,13 @@ UnaryOpNode::computeDerivative(int deriv_id)
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
NodeID
UnaryOpNode::computeDerivative(int deriv_id)
{
NodeID darg = arg->getDerivative(deriv_id);
return composeDerivatives(darg);
}
int int
UnaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) const UnaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) const
{ {
@ -1117,76 +1082,12 @@ UnaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) cons
NodeID NodeID
UnaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_) UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
{ {
NodeID darg = arg->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_); NodeID darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
return composeDerivatives(darg);
NodeID t11, t12, t13;
switch (op_code)
{
case oUminus:
return datatree.AddUMinus(darg);
case oExp:
return datatree.AddTimes(darg, this);
case oLog:
return datatree.AddDivide(darg, arg);
case oLog10:
t11 = datatree.AddExp(datatree.One);
t12 = datatree.AddLog10(t11);
t13 = datatree.AddDivide(darg, arg);
return datatree.AddTimes(t12, t13);
case oCos:
t11 = datatree.AddSin(arg);
t12 = datatree.AddUMinus(t11);
return datatree.AddTimes(darg, t12);
case oSin:
t11 = datatree.AddCos(arg);
return datatree.AddTimes(darg, t11);
case oTan:
t11 = datatree.AddTimes(this, this);
t12 = datatree.AddPlus(t11, datatree.One);
return datatree.AddTimes(darg, t12);
case oAcos:
t11 = datatree.AddSin(this);
t12 = datatree.AddDivide(darg, t11);
return datatree.AddUMinus(t12);
case oAsin:
t11 = datatree.AddCos(this);
return datatree.AddDivide(darg, t11);
case oAtan:
t11 = datatree.AddTimes(arg, arg);
t12 = datatree.AddPlus(datatree.One, t11);
return datatree.AddDivide(darg, t12);
case oCosh:
t11 = datatree.AddSinh(arg);
return datatree.AddTimes(darg, t11);
case oSinh:
t11 = datatree.AddCosh(arg);
return datatree.AddTimes(darg, t11);
case oTanh:
t11 = datatree.AddTimes(this, this);
t12 = datatree.AddMinus(datatree.One, t11);
return datatree.AddTimes(darg, t12);
case oAcosh:
t11 = datatree.AddSinh(this);
return datatree.AddDivide(darg, t11);
case oAsinh:
t11 = datatree.AddCosh(this);
return datatree.AddDivide(darg, t11);
case oAtanh:
t11 = datatree.AddTimes(arg, arg);
t12 = datatree.AddMinus(datatree.One, t11);
return datatree.AddTimes(darg, t12);
case oSqrt:
t11 = datatree.AddPlus(this, this);
return datatree.AddDivide(darg, t11);
}
// Suppress GCC warning
exit(EXIT_FAILURE);
} }
NodeID NodeID
UnaryOpNode::toStatic(DataTree &static_datatree) const UnaryOpNode::toStatic(DataTree &static_datatree) const
{ {
@ -1252,11 +1153,8 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
} }
NodeID NodeID
BinaryOpNode::computeDerivative(int deriv_id) BinaryOpNode::composeDerivatives(NodeID darg1, NodeID darg2)
{ {
NodeID darg1 = arg1->getDerivative(deriv_id);
NodeID darg2 = arg2->getDerivative(deriv_id);
NodeID t11, t12, t13, t14, t15; NodeID t11, t12, t13, t14, t15;
switch (op_code) switch (op_code)
@ -1328,6 +1226,14 @@ BinaryOpNode::computeDerivative(int deriv_id)
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
NodeID
BinaryOpNode::computeDerivative(int deriv_id)
{
NodeID darg1 = arg1->getDerivative(deriv_id);
NodeID darg2 = arg2->getDerivative(deriv_id);
return composeDerivatives(darg1, darg2);
}
int int
BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
{ {
@ -1880,80 +1786,11 @@ BinaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) con
NodeID NodeID
BinaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_) BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
{ {
NodeID darg1 = arg1->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_); NodeID darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
NodeID darg2 = arg2->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_); NodeID darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
return composeDerivatives(darg1, darg2);
NodeID t11, t12, t13, t14, t15;
switch (op_code)
{
case oPlus:
return datatree.AddPlus(darg1, darg2);
case oMinus:
return datatree.AddMinus(darg1, darg2);
case oTimes:
t11 = datatree.AddTimes(darg1, arg2);
t12 = datatree.AddTimes(darg2, arg1);
return datatree.AddPlus(t11, t12);
case oDivide:
if (darg2!=datatree.Zero)
{
t11 = datatree.AddTimes(darg1, arg2);
t12 = datatree.AddTimes(darg2, arg1);
t13 = datatree.AddMinus(t11, t12);
t14 = datatree.AddTimes(arg2, arg2);
return datatree.AddDivide(t13, t14);
}
else
return datatree.AddDivide(darg1, arg2);
case oLess:
case oGreater:
case oLessEqual:
case oGreaterEqual:
case oEqualEqual:
case oDifferent:
return datatree.Zero;
case oPower:
if (darg2 == datatree.Zero)
{
if (darg1 == datatree.Zero)
return datatree.Zero;
else
{
t11 = datatree.AddMinus(arg2, datatree.One);
t12 = datatree.AddPower(arg1, t11);
t13 = datatree.AddTimes(arg2, t12);
return datatree.AddTimes(darg1, t13);
}
}
else
{
t11 = datatree.AddLog(arg1);
t12 = datatree.AddTimes(darg2, t11);
t13 = datatree.AddTimes(darg1, arg2);
t14 = datatree.AddDivide(t13, arg1);
t15 = datatree.AddPlus(t12, t14);
return datatree.AddTimes(t15, this);
}
case oMax:
t11 = datatree.AddGreater(arg1,arg2);
t12 = datatree.AddTimes(t11,darg1);
t13 = datatree.AddMinus(datatree.One,t11);
t14 = datatree.AddTimes(t13,darg2);
return datatree.AddPlus(t14,t12);
case oMin:
t11 = datatree.AddGreater(arg2,arg1);
t12 = datatree.AddTimes(t11,darg1);
t13 = datatree.AddMinus(datatree.One,t11);
t14 = datatree.AddTimes(t13,darg2);
return datatree.AddPlus(t14,t12);
case oEqual:
return datatree.AddMinus(darg1, darg2);
}
// Suppress GCC warning
exit(EXIT_FAILURE);
} }
NodeID NodeID
@ -2023,11 +1860,8 @@ TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
} }
NodeID NodeID
TrinaryOpNode::computeDerivative(int deriv_id) TrinaryOpNode::composeDerivatives(NodeID darg1, NodeID darg2, NodeID darg3)
{ {
NodeID darg1 = arg1->getDerivative(deriv_id);
NodeID darg2 = arg2->getDerivative(deriv_id);
NodeID darg3 = arg3->getDerivative(deriv_id);
NodeID t11, t12, t13, t14, t15; NodeID t11, t12, t13, t14, t15;
@ -2073,6 +1907,15 @@ TrinaryOpNode::computeDerivative(int deriv_id)
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
NodeID
TrinaryOpNode::computeDerivative(int deriv_id)
{
NodeID darg1 = arg1->getDerivative(deriv_id);
NodeID darg2 = arg2->getDerivative(deriv_id);
NodeID darg3 = arg3->getDerivative(deriv_id);
return composeDerivatives(darg1, darg2, darg3);
}
int int
TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
{ {
@ -2297,54 +2140,12 @@ TrinaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) co
} }
NodeID NodeID
TrinaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_) TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
{ {
NodeID darg1 = arg1->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_); NodeID darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
NodeID darg2 = arg2->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_); NodeID darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
NodeID darg3 = arg3->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_); NodeID darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables);
return composeDerivatives(darg1, darg2, darg3);
NodeID t11, t12, t13, t14, t15;
switch (op_code)
{
case oNormcdf:
// normal pdf is inlined in the tree
NodeID y;
// sqrt(2*pi)
t14 = datatree.AddSqrt(datatree.AddTimes(datatree.Two, datatree.Pi));
// x - mu
t12 = datatree.AddMinus(arg1,arg2);
// y = (x-mu)/sigma
y = datatree.AddDivide(t12,arg3);
// (x-mu)^2/sigma^2
t12 = datatree.AddTimes(y,y);
// -(x-mu)^2/sigma^2
t13 = datatree.AddUMinus(t12);
// -((x-mu)^2/sigma^2)/2
t12 = datatree.AddDivide(t13, datatree.Two);
// exp(-((x-mu)^2/sigma^2)/2)
t13 = datatree.AddExp(t12);
// derivative of a standardized normal
// t15 = (1/sqrt(2*pi))*exp(-y^2/2)
t15 = datatree.AddDivide(t13,t14);
// derivatives thru x
t11 = datatree.AddDivide(darg1,arg3);
// derivatives thru mu
t12 = datatree.AddDivide(darg2,arg3);
// intermediary sum
t14 = datatree.AddMinus(t11,t12);
// derivatives thru sigma
t11 = datatree.AddDivide(y,arg3);
t12 = datatree.AddTimes(t11,darg3);
//intermediary sum
t11 = datatree.AddMinus(t14,t12);
// total derivative:
// (darg1/sigma - darg2/sigma - darg3*(x-mu)/sigma^2) * t15
// where t15 is the derivative of a standardized normal
return datatree.AddTimes(t11, t15);
}
// Suppress GCC warning
exit(EXIT_FAILURE);
} }
NodeID NodeID
@ -2380,9 +2181,9 @@ UnknownFunctionNode::computeDerivative(int deriv_id)
} }
NodeID NodeID
UnknownFunctionNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_) UnknownFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
{ {
cerr << "UnknownFunctionNode::computeDerivative: operation impossible!" << endl; cerr << "UnknownFunctionNode::getChainRuleDerivative: operation impossible!" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }

View File

@ -109,10 +109,6 @@ private:
//! Computes derivative w.r. to a derivation ID (but doesn't store it in derivatives map) //! Computes derivative w.r. to a derivation ID (but doesn't store it in derivatives map)
/*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */ /*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
virtual NodeID computeDerivative(int deriv_id) = 0; virtual NodeID computeDerivative(int deriv_id) = 0;
//! Computes derivative w.r. to a derivation ID and use chaine rule derivatives (but doesn't store it in derivatives map)
/*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_) = 0;
protected: protected:
//! Reference to the enclosing DataTree //! Reference to the enclosing DataTree
@ -140,9 +136,12 @@ public:
For an equal node, returns the derivative of lhs minus rhs */ For an equal node, returns the derivative of lhs minus rhs */
NodeID getDerivative(int deriv_id); NodeID getDerivative(int deriv_id);
//! Returns derivative w.r. to derivation ID and use if it possible chaine rule derivatives //! Computes derivatives by applying the chain rule for some variables
NodeID getChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_); /*!
\param deriv_id The derivation ID with respect to which we are derivating
\param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
*/
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables) = 0;
//! Returns precedence of node //! Returns precedence of node
/*! Equals 100 for constants, variables, unary ops, and temporary terms */ /*! Equals 100 for constants, variables, unary ops, and temporary terms */
@ -215,7 +214,6 @@ private:
//! Id from numerical constants table //! Id from numerical constants table
const int id; const int id;
virtual NodeID computeDerivative(int deriv_id); virtual NodeID computeDerivative(int deriv_id);
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
public: public:
NumConstNode(DataTree &datatree_arg, int id_arg); NumConstNode(DataTree &datatree_arg, int id_arg);
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
@ -226,6 +224,7 @@ public:
virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const; virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
virtual NodeID toStatic(DataTree &static_datatree) const; virtual NodeID toStatic(DataTree &static_datatree) const;
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const; virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
}; };
//! Symbol or variable node //! Symbol or variable node
@ -239,7 +238,6 @@ private:
//! Derivation ID //! Derivation ID
const int deriv_id; const int deriv_id;
virtual NodeID computeDerivative(int deriv_id_arg); virtual NodeID computeDerivative(int deriv_id_arg);
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
public: public:
VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, int deriv_id_arg); VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, int deriv_id_arg);
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms = temporary_terms_type()) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms = temporary_terms_type()) const;
@ -258,6 +256,7 @@ public:
virtual NodeID toStatic(DataTree &static_datatree) const; virtual NodeID toStatic(DataTree &static_datatree) const;
int get_symb_id() const { return symb_id; }; int get_symb_id() const { return symb_id; };
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const; virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
}; };
//! Unary operator node //! Unary operator node
@ -267,9 +266,9 @@ private:
const NodeID arg; const NodeID arg;
const UnaryOpcode op_code; const UnaryOpcode op_code;
virtual NodeID computeDerivative(int deriv_id); virtual NodeID computeDerivative(int deriv_id);
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const; virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
//! Returns the derivative of this node if darg is the derivative of the argument
NodeID composeDerivatives(NodeID darg);
public: public:
UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg); UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg);
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
@ -293,6 +292,7 @@ public:
UnaryOpcode get_op_code() const { return(op_code); }; UnaryOpcode get_op_code() const { return(op_code); };
virtual NodeID toStatic(DataTree &static_datatree) const; virtual NodeID toStatic(DataTree &static_datatree) const;
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const; virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
}; };
//! Binary operator node //! Binary operator node
@ -302,9 +302,9 @@ private:
const NodeID arg1, arg2; const NodeID arg1, arg2;
const BinaryOpcode op_code; const BinaryOpcode op_code;
virtual NodeID computeDerivative(int deriv_id); virtual NodeID computeDerivative(int deriv_id);
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const; virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
//! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
NodeID composeDerivatives(NodeID darg1, NodeID darg2);
public: public:
BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
BinaryOpcode op_code_arg, const NodeID arg2_arg); BinaryOpcode op_code_arg, const NodeID arg2_arg);
@ -332,6 +332,7 @@ public:
BinaryOpcode get_op_code() const { return(op_code); }; BinaryOpcode get_op_code() const { return(op_code); };
virtual NodeID toStatic(DataTree &static_datatree) const; virtual NodeID toStatic(DataTree &static_datatree) const;
pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const; pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
}; };
//! Trinary operator node //! Trinary operator node
@ -342,9 +343,9 @@ private:
const NodeID arg1, arg2, arg3; const NodeID arg1, arg2, arg3;
const TrinaryOpcode op_code; const TrinaryOpcode op_code;
virtual NodeID computeDerivative(int deriv_id); virtual NodeID computeDerivative(int deriv_id);
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const; virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
//! Returns the derivative of this node if darg1, darg2 and darg3 are the derivatives of the arguments
NodeID composeDerivatives(NodeID darg1, NodeID darg2, NodeID darg3);
public: public:
TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
TrinaryOpcode op_code_arg, const NodeID arg2_arg, const NodeID arg3_arg); TrinaryOpcode op_code_arg, const NodeID arg2_arg, const NodeID arg3_arg);
@ -366,6 +367,7 @@ public:
virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const; virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
virtual NodeID toStatic(DataTree &static_datatree) const; virtual NodeID toStatic(DataTree &static_datatree) const;
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const; virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
}; };
//! Unknown function node //! Unknown function node
@ -375,7 +377,6 @@ private:
const int symb_id; const int symb_id;
const vector<NodeID> arguments; const vector<NodeID> arguments;
virtual NodeID computeDerivative(int deriv_id); virtual NodeID computeDerivative(int deriv_id);
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
public: public:
UnknownFunctionNode(DataTree &datatree_arg, int symb_id_arg, UnknownFunctionNode(DataTree &datatree_arg, int symb_id_arg,
const vector<NodeID> &arguments_arg); const vector<NodeID> &arguments_arg);
@ -395,6 +396,7 @@ public:
virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const; virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
virtual NodeID toStatic(DataTree &static_datatree) const; virtual NodeID toStatic(DataTree &static_datatree) const;
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const; virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
}; };
//! For one lead/lag of one block, stores mapping of information between original model and block-decomposed model //! For one lead/lag of one block, stores mapping of information between original model and block-decomposed model