Preprocessor: changes to code for chain rule derivation
* fixed a bug in the handling of VariableNode: we now make a copy of the recursive_variables map, instead of modifying that of the caller * factorized code shared with standard derivation * various minor cleanups git-svn-id: https://www.dynare.org/svn/dynare/trunk@2811 ac1d8469-bf42-47a9-8791-bf33cf982152time-shift
parent
05189497a5
commit
1cde972d91
|
@ -664,9 +664,7 @@ end:
|
|||
//cout << "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n";
|
||||
//cout << "derivaive eq=" << eq << " var=" << var << " k0=" << k << "\n";
|
||||
int deriv_id = getDerivID(symbol_table.getID(eEndogenous, var),0);
|
||||
map<int, NodeID> recursive_variables_save(recursive_variables);
|
||||
NodeID ChaineRule_Derivative = tmp_n->getChaineRuleDerivative(deriv_id ,recursive_variables, var, 0);
|
||||
recursive_variables = recursive_variables_save;
|
||||
NodeID ChaineRule_Derivative = tmp_n->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
ChaineRule_Derivative->writeOutput(output, oMatlabDynamicModelSparse, temporary_terms);
|
||||
output << ";";
|
||||
output << " %2 variable=" << symbol_table.getName(symbol_table.getID(eEndogenous, var))
|
||||
|
|
|
@ -60,32 +60,6 @@ ExprNode::getDerivative(int deriv_id)
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
NodeID
|
||||
ExprNode::getChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
|
||||
{
|
||||
// Return zero if derivative is necessarily null (using symbolic a priori)
|
||||
/*set<int>::const_iterator it = non_null_derivatives.find(deriv_id);
|
||||
if (it == non_null_derivatives.end())
|
||||
{
|
||||
cout << "0\n";
|
||||
return datatree.Zero;
|
||||
}
|
||||
*/
|
||||
|
||||
// If derivative is stored in cache, use the cached value, otherwise compute it (and cache it)
|
||||
/*map<int, NodeID>::const_iterator it2 = derivatives.find(deriv_id);
|
||||
if (it2 != derivatives.end())
|
||||
return it2->second;
|
||||
else*/
|
||||
{
|
||||
NodeID d = computeChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
|
||||
//derivatives[deriv_id] = d;
|
||||
return d;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
ExprNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
|
||||
{
|
||||
|
@ -213,20 +187,11 @@ NumConstNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) con
|
|||
}
|
||||
|
||||
NodeID
|
||||
NumConstNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
|
||||
NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
|
||||
{
|
||||
return datatree.Zero;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
pair<bool, NodeID>
|
||||
NumConstNode::computeDerivativeRespectToFeedbackVariable(int equ, int var, int varr, int lag_, int max_lag, vector<int> &recursive_variables, vector<int> &feeback_variables) const
|
||||
{
|
||||
return(make_pair(false, datatree.Zero));
|
||||
}
|
||||
*/
|
||||
|
||||
NodeID
|
||||
NumConstNode::toStatic(DataTree &static_datatree) const
|
||||
{
|
||||
|
@ -606,7 +571,7 @@ pair<InputIterator, int> find_r ( InputIterator first, InputIterator last, const
|
|||
|
||||
|
||||
NodeID
|
||||
VariableNode::computeChaineRuleDerivative(int deriv_id_arg, map<int, NodeID> &recursive_variables, int var, int lag_)
|
||||
VariableNode::getChainRuleDerivative(int deriv_id_arg, const map<int, NodeID> &recursive_variables)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
|
@ -614,28 +579,23 @@ VariableNode::computeChaineRuleDerivative(int deriv_id_arg, map<int, NodeID> &re
|
|||
case eExogenous:
|
||||
case eExogenousDet:
|
||||
case eParameter:
|
||||
//cout << "deriv_id=" << deriv_id << " deriv_id_arg=" << deriv_id_arg << " symb_id=" << symb_id << " type=" << type << " lag=" << lag << " var=" << var << " lag_ = " << lag_ << "\n";
|
||||
if (deriv_id == deriv_id_arg)
|
||||
return datatree.One;
|
||||
else
|
||||
{
|
||||
//if there is in the equation a recursive variable we could use a chaine rule derivation
|
||||
if(lag == lag_)
|
||||
{
|
||||
map<int, NodeID>::const_iterator it = recursive_variables.find(deriv_id);
|
||||
if (it != recursive_variables.end())
|
||||
{
|
||||
recursive_variables.erase(it->first);
|
||||
return datatree.AddUMinus(it->second->getChaineRuleDerivative(deriv_id_arg, recursive_variables, var, lag_));
|
||||
}
|
||||
else
|
||||
return datatree.Zero;
|
||||
map<int, NodeID> recursive_vars2(recursive_variables);
|
||||
recursive_vars2.erase(it->first);
|
||||
return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id_arg, recursive_vars2));
|
||||
}
|
||||
else
|
||||
return datatree.Zero;
|
||||
}
|
||||
case eModelLocalVariable:
|
||||
return datatree.local_variables_table[symb_id]->getChaineRuleDerivative(deriv_id_arg, recursive_variables, var, lag_);
|
||||
return datatree.local_variables_table[symb_id]->getChainRuleDerivative(deriv_id_arg, recursive_variables);
|
||||
case eModFileLocalVariable:
|
||||
cerr << "ModFileLocalVariable is not derivable" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
|
@ -669,10 +629,8 @@ UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const
|
|||
}
|
||||
|
||||
NodeID
|
||||
UnaryOpNode::computeDerivative(int deriv_id)
|
||||
UnaryOpNode::composeDerivatives(NodeID darg)
|
||||
{
|
||||
NodeID darg = arg->getDerivative(deriv_id);
|
||||
|
||||
NodeID t11, t12, t13;
|
||||
|
||||
switch (op_code)
|
||||
|
@ -738,6 +696,13 @@ UnaryOpNode::computeDerivative(int deriv_id)
|
|||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
NodeID
|
||||
UnaryOpNode::computeDerivative(int deriv_id)
|
||||
{
|
||||
NodeID darg = arg->getDerivative(deriv_id);
|
||||
return composeDerivatives(darg);
|
||||
}
|
||||
|
||||
int
|
||||
UnaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) const
|
||||
{
|
||||
|
@ -1117,76 +1082,12 @@ UnaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) cons
|
|||
|
||||
|
||||
NodeID
|
||||
UnaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
|
||||
UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
|
||||
{
|
||||
NodeID darg = arg->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
|
||||
|
||||
NodeID t11, t12, t13;
|
||||
|
||||
switch (op_code)
|
||||
{
|
||||
case oUminus:
|
||||
return datatree.AddUMinus(darg);
|
||||
case oExp:
|
||||
return datatree.AddTimes(darg, this);
|
||||
case oLog:
|
||||
return datatree.AddDivide(darg, arg);
|
||||
case oLog10:
|
||||
t11 = datatree.AddExp(datatree.One);
|
||||
t12 = datatree.AddLog10(t11);
|
||||
t13 = datatree.AddDivide(darg, arg);
|
||||
return datatree.AddTimes(t12, t13);
|
||||
case oCos:
|
||||
t11 = datatree.AddSin(arg);
|
||||
t12 = datatree.AddUMinus(t11);
|
||||
return datatree.AddTimes(darg, t12);
|
||||
case oSin:
|
||||
t11 = datatree.AddCos(arg);
|
||||
return datatree.AddTimes(darg, t11);
|
||||
case oTan:
|
||||
t11 = datatree.AddTimes(this, this);
|
||||
t12 = datatree.AddPlus(t11, datatree.One);
|
||||
return datatree.AddTimes(darg, t12);
|
||||
case oAcos:
|
||||
t11 = datatree.AddSin(this);
|
||||
t12 = datatree.AddDivide(darg, t11);
|
||||
return datatree.AddUMinus(t12);
|
||||
case oAsin:
|
||||
t11 = datatree.AddCos(this);
|
||||
return datatree.AddDivide(darg, t11);
|
||||
case oAtan:
|
||||
t11 = datatree.AddTimes(arg, arg);
|
||||
t12 = datatree.AddPlus(datatree.One, t11);
|
||||
return datatree.AddDivide(darg, t12);
|
||||
case oCosh:
|
||||
t11 = datatree.AddSinh(arg);
|
||||
return datatree.AddTimes(darg, t11);
|
||||
case oSinh:
|
||||
t11 = datatree.AddCosh(arg);
|
||||
return datatree.AddTimes(darg, t11);
|
||||
case oTanh:
|
||||
t11 = datatree.AddTimes(this, this);
|
||||
t12 = datatree.AddMinus(datatree.One, t11);
|
||||
return datatree.AddTimes(darg, t12);
|
||||
case oAcosh:
|
||||
t11 = datatree.AddSinh(this);
|
||||
return datatree.AddDivide(darg, t11);
|
||||
case oAsinh:
|
||||
t11 = datatree.AddCosh(this);
|
||||
return datatree.AddDivide(darg, t11);
|
||||
case oAtanh:
|
||||
t11 = datatree.AddTimes(arg, arg);
|
||||
t12 = datatree.AddMinus(datatree.One, t11);
|
||||
return datatree.AddTimes(darg, t12);
|
||||
case oSqrt:
|
||||
t11 = datatree.AddPlus(this, this);
|
||||
return datatree.AddDivide(darg, t11);
|
||||
}
|
||||
// Suppress GCC warning
|
||||
exit(EXIT_FAILURE);
|
||||
NodeID darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
return composeDerivatives(darg);
|
||||
}
|
||||
|
||||
|
||||
NodeID
|
||||
UnaryOpNode::toStatic(DataTree &static_datatree) const
|
||||
{
|
||||
|
@ -1252,11 +1153,8 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
|
|||
}
|
||||
|
||||
NodeID
|
||||
BinaryOpNode::computeDerivative(int deriv_id)
|
||||
BinaryOpNode::composeDerivatives(NodeID darg1, NodeID darg2)
|
||||
{
|
||||
NodeID darg1 = arg1->getDerivative(deriv_id);
|
||||
NodeID darg2 = arg2->getDerivative(deriv_id);
|
||||
|
||||
NodeID t11, t12, t13, t14, t15;
|
||||
|
||||
switch (op_code)
|
||||
|
@ -1328,6 +1226,14 @@ BinaryOpNode::computeDerivative(int deriv_id)
|
|||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
NodeID
|
||||
BinaryOpNode::computeDerivative(int deriv_id)
|
||||
{
|
||||
NodeID darg1 = arg1->getDerivative(deriv_id);
|
||||
NodeID darg2 = arg2->getDerivative(deriv_id);
|
||||
return composeDerivatives(darg1, darg2);
|
||||
}
|
||||
|
||||
int
|
||||
BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
|
||||
{
|
||||
|
@ -1880,80 +1786,11 @@ BinaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) con
|
|||
|
||||
|
||||
NodeID
|
||||
BinaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
|
||||
BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
|
||||
{
|
||||
NodeID darg1 = arg1->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
|
||||
NodeID darg2 = arg2->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
|
||||
|
||||
NodeID t11, t12, t13, t14, t15;
|
||||
|
||||
switch (op_code)
|
||||
{
|
||||
case oPlus:
|
||||
return datatree.AddPlus(darg1, darg2);
|
||||
case oMinus:
|
||||
return datatree.AddMinus(darg1, darg2);
|
||||
case oTimes:
|
||||
t11 = datatree.AddTimes(darg1, arg2);
|
||||
t12 = datatree.AddTimes(darg2, arg1);
|
||||
return datatree.AddPlus(t11, t12);
|
||||
case oDivide:
|
||||
if (darg2!=datatree.Zero)
|
||||
{
|
||||
t11 = datatree.AddTimes(darg1, arg2);
|
||||
t12 = datatree.AddTimes(darg2, arg1);
|
||||
t13 = datatree.AddMinus(t11, t12);
|
||||
t14 = datatree.AddTimes(arg2, arg2);
|
||||
return datatree.AddDivide(t13, t14);
|
||||
}
|
||||
else
|
||||
return datatree.AddDivide(darg1, arg2);
|
||||
case oLess:
|
||||
case oGreater:
|
||||
case oLessEqual:
|
||||
case oGreaterEqual:
|
||||
case oEqualEqual:
|
||||
case oDifferent:
|
||||
return datatree.Zero;
|
||||
case oPower:
|
||||
if (darg2 == datatree.Zero)
|
||||
{
|
||||
if (darg1 == datatree.Zero)
|
||||
return datatree.Zero;
|
||||
else
|
||||
{
|
||||
t11 = datatree.AddMinus(arg2, datatree.One);
|
||||
t12 = datatree.AddPower(arg1, t11);
|
||||
t13 = datatree.AddTimes(arg2, t12);
|
||||
return datatree.AddTimes(darg1, t13);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
t11 = datatree.AddLog(arg1);
|
||||
t12 = datatree.AddTimes(darg2, t11);
|
||||
t13 = datatree.AddTimes(darg1, arg2);
|
||||
t14 = datatree.AddDivide(t13, arg1);
|
||||
t15 = datatree.AddPlus(t12, t14);
|
||||
return datatree.AddTimes(t15, this);
|
||||
}
|
||||
case oMax:
|
||||
t11 = datatree.AddGreater(arg1,arg2);
|
||||
t12 = datatree.AddTimes(t11,darg1);
|
||||
t13 = datatree.AddMinus(datatree.One,t11);
|
||||
t14 = datatree.AddTimes(t13,darg2);
|
||||
return datatree.AddPlus(t14,t12);
|
||||
case oMin:
|
||||
t11 = datatree.AddGreater(arg2,arg1);
|
||||
t12 = datatree.AddTimes(t11,darg1);
|
||||
t13 = datatree.AddMinus(datatree.One,t11);
|
||||
t14 = datatree.AddTimes(t13,darg2);
|
||||
return datatree.AddPlus(t14,t12);
|
||||
case oEqual:
|
||||
return datatree.AddMinus(darg1, darg2);
|
||||
}
|
||||
// Suppress GCC warning
|
||||
exit(EXIT_FAILURE);
|
||||
NodeID darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
NodeID darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
return composeDerivatives(darg1, darg2);
|
||||
}
|
||||
|
||||
NodeID
|
||||
|
@ -2023,11 +1860,8 @@ TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
|
|||
}
|
||||
|
||||
NodeID
|
||||
TrinaryOpNode::computeDerivative(int deriv_id)
|
||||
TrinaryOpNode::composeDerivatives(NodeID darg1, NodeID darg2, NodeID darg3)
|
||||
{
|
||||
NodeID darg1 = arg1->getDerivative(deriv_id);
|
||||
NodeID darg2 = arg2->getDerivative(deriv_id);
|
||||
NodeID darg3 = arg3->getDerivative(deriv_id);
|
||||
|
||||
NodeID t11, t12, t13, t14, t15;
|
||||
|
||||
|
@ -2073,6 +1907,15 @@ TrinaryOpNode::computeDerivative(int deriv_id)
|
|||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
NodeID
|
||||
TrinaryOpNode::computeDerivative(int deriv_id)
|
||||
{
|
||||
NodeID darg1 = arg1->getDerivative(deriv_id);
|
||||
NodeID darg2 = arg2->getDerivative(deriv_id);
|
||||
NodeID darg3 = arg3->getDerivative(deriv_id);
|
||||
return composeDerivatives(darg1, darg2, darg3);
|
||||
}
|
||||
|
||||
int
|
||||
TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
|
||||
{
|
||||
|
@ -2297,54 +2140,12 @@ TrinaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) co
|
|||
}
|
||||
|
||||
NodeID
|
||||
TrinaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
|
||||
TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
|
||||
{
|
||||
NodeID darg1 = arg1->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
|
||||
NodeID darg2 = arg2->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
|
||||
NodeID darg3 = arg3->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
|
||||
|
||||
NodeID t11, t12, t13, t14, t15;
|
||||
|
||||
switch (op_code)
|
||||
{
|
||||
case oNormcdf:
|
||||
// normal pdf is inlined in the tree
|
||||
NodeID y;
|
||||
// sqrt(2*pi)
|
||||
t14 = datatree.AddSqrt(datatree.AddTimes(datatree.Two, datatree.Pi));
|
||||
// x - mu
|
||||
t12 = datatree.AddMinus(arg1,arg2);
|
||||
// y = (x-mu)/sigma
|
||||
y = datatree.AddDivide(t12,arg3);
|
||||
// (x-mu)^2/sigma^2
|
||||
t12 = datatree.AddTimes(y,y);
|
||||
// -(x-mu)^2/sigma^2
|
||||
t13 = datatree.AddUMinus(t12);
|
||||
// -((x-mu)^2/sigma^2)/2
|
||||
t12 = datatree.AddDivide(t13, datatree.Two);
|
||||
// exp(-((x-mu)^2/sigma^2)/2)
|
||||
t13 = datatree.AddExp(t12);
|
||||
// derivative of a standardized normal
|
||||
// t15 = (1/sqrt(2*pi))*exp(-y^2/2)
|
||||
t15 = datatree.AddDivide(t13,t14);
|
||||
// derivatives thru x
|
||||
t11 = datatree.AddDivide(darg1,arg3);
|
||||
// derivatives thru mu
|
||||
t12 = datatree.AddDivide(darg2,arg3);
|
||||
// intermediary sum
|
||||
t14 = datatree.AddMinus(t11,t12);
|
||||
// derivatives thru sigma
|
||||
t11 = datatree.AddDivide(y,arg3);
|
||||
t12 = datatree.AddTimes(t11,darg3);
|
||||
//intermediary sum
|
||||
t11 = datatree.AddMinus(t14,t12);
|
||||
// total derivative:
|
||||
// (darg1/sigma - darg2/sigma - darg3*(x-mu)/sigma^2) * t15
|
||||
// where t15 is the derivative of a standardized normal
|
||||
return datatree.AddTimes(t11, t15);
|
||||
}
|
||||
// Suppress GCC warning
|
||||
exit(EXIT_FAILURE);
|
||||
NodeID darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
NodeID darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
NodeID darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
return composeDerivatives(darg1, darg2, darg3);
|
||||
}
|
||||
|
||||
NodeID
|
||||
|
@ -2380,9 +2181,9 @@ UnknownFunctionNode::computeDerivative(int deriv_id)
|
|||
}
|
||||
|
||||
NodeID
|
||||
UnknownFunctionNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
|
||||
UnknownFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
|
||||
{
|
||||
cerr << "UnknownFunctionNode::computeDerivative: operation impossible!" << endl;
|
||||
cerr << "UnknownFunctionNode::getChainRuleDerivative: operation impossible!" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
|
|
|
@ -109,10 +109,6 @@ private:
|
|||
//! Computes derivative w.r. to a derivation ID (but doesn't store it in derivatives map)
|
||||
/*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
|
||||
virtual NodeID computeDerivative(int deriv_id) = 0;
|
||||
//! Computes derivative w.r. to a derivation ID and use chaine rule derivatives (but doesn't store it in derivatives map)
|
||||
/*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
|
||||
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_) = 0;
|
||||
|
||||
|
||||
protected:
|
||||
//! Reference to the enclosing DataTree
|
||||
|
@ -140,9 +136,12 @@ public:
|
|||
For an equal node, returns the derivative of lhs minus rhs */
|
||||
NodeID getDerivative(int deriv_id);
|
||||
|
||||
//! Returns derivative w.r. to derivation ID and use if it possible chaine rule derivatives
|
||||
NodeID getChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
|
||||
|
||||
//! Computes derivatives by applying the chain rule for some variables
|
||||
/*!
|
||||
\param deriv_id The derivation ID with respect to which we are derivating
|
||||
\param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
|
||||
*/
|
||||
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables) = 0;
|
||||
|
||||
//! Returns precedence of node
|
||||
/*! Equals 100 for constants, variables, unary ops, and temporary terms */
|
||||
|
@ -215,7 +214,6 @@ private:
|
|||
//! Id from numerical constants table
|
||||
const int id;
|
||||
virtual NodeID computeDerivative(int deriv_id);
|
||||
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
|
||||
public:
|
||||
NumConstNode(DataTree &datatree_arg, int id_arg);
|
||||
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
|
||||
|
@ -226,6 +224,7 @@ public:
|
|||
virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
|
||||
virtual NodeID toStatic(DataTree &static_datatree) const;
|
||||
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
|
||||
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
|
||||
};
|
||||
|
||||
//! Symbol or variable node
|
||||
|
@ -239,7 +238,6 @@ private:
|
|||
//! Derivation ID
|
||||
const int deriv_id;
|
||||
virtual NodeID computeDerivative(int deriv_id_arg);
|
||||
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
|
||||
public:
|
||||
VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, int deriv_id_arg);
|
||||
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms = temporary_terms_type()) const;
|
||||
|
@ -258,6 +256,7 @@ public:
|
|||
virtual NodeID toStatic(DataTree &static_datatree) const;
|
||||
int get_symb_id() const { return symb_id; };
|
||||
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
|
||||
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
|
||||
};
|
||||
|
||||
//! Unary operator node
|
||||
|
@ -267,9 +266,9 @@ private:
|
|||
const NodeID arg;
|
||||
const UnaryOpcode op_code;
|
||||
virtual NodeID computeDerivative(int deriv_id);
|
||||
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
|
||||
|
||||
virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
|
||||
//! Returns the derivative of this node if darg is the derivative of the argument
|
||||
NodeID composeDerivatives(NodeID darg);
|
||||
public:
|
||||
UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg);
|
||||
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
|
||||
|
@ -293,6 +292,7 @@ public:
|
|||
UnaryOpcode get_op_code() const { return(op_code); };
|
||||
virtual NodeID toStatic(DataTree &static_datatree) const;
|
||||
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
|
||||
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
|
||||
};
|
||||
|
||||
//! Binary operator node
|
||||
|
@ -302,9 +302,9 @@ private:
|
|||
const NodeID arg1, arg2;
|
||||
const BinaryOpcode op_code;
|
||||
virtual NodeID computeDerivative(int deriv_id);
|
||||
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
|
||||
|
||||
virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
|
||||
//! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
|
||||
NodeID composeDerivatives(NodeID darg1, NodeID darg2);
|
||||
public:
|
||||
BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
|
||||
BinaryOpcode op_code_arg, const NodeID arg2_arg);
|
||||
|
@ -332,6 +332,7 @@ public:
|
|||
BinaryOpcode get_op_code() const { return(op_code); };
|
||||
virtual NodeID toStatic(DataTree &static_datatree) const;
|
||||
pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
|
||||
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
|
||||
};
|
||||
|
||||
//! Trinary operator node
|
||||
|
@ -342,9 +343,9 @@ private:
|
|||
const NodeID arg1, arg2, arg3;
|
||||
const TrinaryOpcode op_code;
|
||||
virtual NodeID computeDerivative(int deriv_id);
|
||||
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
|
||||
|
||||
virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
|
||||
//! Returns the derivative of this node if darg1, darg2 and darg3 are the derivatives of the arguments
|
||||
NodeID composeDerivatives(NodeID darg1, NodeID darg2, NodeID darg3);
|
||||
public:
|
||||
TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
|
||||
TrinaryOpcode op_code_arg, const NodeID arg2_arg, const NodeID arg3_arg);
|
||||
|
@ -366,6 +367,7 @@ public:
|
|||
virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
|
||||
virtual NodeID toStatic(DataTree &static_datatree) const;
|
||||
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
|
||||
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
|
||||
};
|
||||
|
||||
//! Unknown function node
|
||||
|
@ -375,7 +377,6 @@ private:
|
|||
const int symb_id;
|
||||
const vector<NodeID> arguments;
|
||||
virtual NodeID computeDerivative(int deriv_id);
|
||||
virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
|
||||
public:
|
||||
UnknownFunctionNode(DataTree &datatree_arg, int symb_id_arg,
|
||||
const vector<NodeID> &arguments_arg);
|
||||
|
@ -395,6 +396,7 @@ public:
|
|||
virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
|
||||
virtual NodeID toStatic(DataTree &static_datatree) const;
|
||||
virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
|
||||
virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
|
||||
};
|
||||
|
||||
//! For one lead/lag of one block, stores mapping of information between original model and block-decomposed model
|
||||
|
|
Loading…
Reference in New Issue