Performance improvement of chain rule derivation
Commit 23b0c12d8e
introduced caching in chain
rule derivation (used by block decomposition), which increased speed for mfs >
0, but actually decreased it for mfs=0.
This patch introduces the pre-computation of derivatives which are known to be
zero using symbolic a priori (similarly to what is done in the non-chain rule
context). The algorithms are now identical between the two contexts (both
symbolic a priori + caching), the difference being that in the chain rule
context, the symbolic a priori and the cache are not stored within the ExprNode
class, since they depend on the list of recursive variables.
This patch brings a significant performant improvement for all values of the
“mfs” option (the improvement is greater for small values of “mfs”).
master
parent
f8edce01ec
commit
7acf278370
|
@ -2320,6 +2320,7 @@ DynamicModel::computeChainRuleJacobian()
|
|||
}
|
||||
|
||||
// Compute the block derivatives
|
||||
map<expr_t, set<int>> non_null_chain_rule_derivatives;
|
||||
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
|
||||
for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
|
||||
{
|
||||
|
@ -2337,10 +2338,10 @@ DynamicModel::computeChainRuleJacobian()
|
|||
d = Zero;
|
||||
break;
|
||||
case BlockDerivativeType::chainRule:
|
||||
d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
|
||||
d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
|
||||
break;
|
||||
case BlockDerivativeType::normalizedChainRule:
|
||||
d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
|
||||
d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
190
src/ExprNode.cc
190
src/ExprNode.cc
|
@ -56,14 +56,24 @@ ExprNode::getDerivative(int deriv_id)
|
|||
|
||||
expr_t
|
||||
ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
if (!non_null_chain_rule_derivatives.contains(this))
|
||||
prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
|
||||
// Return zero if derivative is necessarily null (using symbolic a priori)
|
||||
if (!non_null_chain_rule_derivatives.at(this).contains(deriv_id))
|
||||
return datatree.Zero;
|
||||
|
||||
// If derivative is in the cache, return that value
|
||||
pair key {this, deriv_id};
|
||||
if (auto it = cache.find(key);
|
||||
it != cache.end())
|
||||
return it->second;
|
||||
|
||||
auto r = computeChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
auto r = computeChainRuleDerivative(deriv_id, recursive_variables,
|
||||
non_null_chain_rule_derivatives, cache);
|
||||
|
||||
auto [ignore, success] = cache.emplace(key, r);
|
||||
assert(success); // The element should not already exist
|
||||
|
@ -477,6 +487,13 @@ NumConstNode::prepareForDerivation()
|
|||
// All derivatives are null, so non_null_derivatives is left empty
|
||||
}
|
||||
|
||||
void
|
||||
NumConstNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
non_null_chain_rule_derivatives.try_emplace(const_cast<NumConstNode *>(this));
|
||||
}
|
||||
|
||||
expr_t
|
||||
NumConstNode::computeDerivative([[maybe_unused]] int deriv_id)
|
||||
{
|
||||
|
@ -565,6 +582,7 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
|
|||
expr_t
|
||||
NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
|
||||
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
return datatree.Zero;
|
||||
|
@ -897,6 +915,56 @@ VariableNode::prepareForDerivation()
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
VariableNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<VariableNode *>(this)))
|
||||
return;
|
||||
|
||||
switch (get_type())
|
||||
{
|
||||
case SymbolType::endogenous:
|
||||
{
|
||||
set<int> &nnd { non_null_chain_rule_derivatives[const_cast<VariableNode *>(this)] };
|
||||
int my_deriv_id {datatree.getDerivID(symb_id, lag)};
|
||||
if (auto it = recursive_variables.find(my_deriv_id);
|
||||
it != recursive_variables.end())
|
||||
{
|
||||
it->second->arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
nnd = non_null_chain_rule_derivatives.at(it->second->arg2);
|
||||
}
|
||||
nnd.insert(my_deriv_id);
|
||||
}
|
||||
break;
|
||||
case SymbolType::exogenous:
|
||||
case SymbolType::exogenousDet:
|
||||
case SymbolType::parameter:
|
||||
case SymbolType::trend:
|
||||
case SymbolType::logTrend:
|
||||
case SymbolType::modFileLocalVariable:
|
||||
case SymbolType::statementDeclaredVariable:
|
||||
case SymbolType::unusedEndogenous:
|
||||
// Those variables are never derived using chain rule
|
||||
non_null_chain_rule_derivatives.try_emplace(const_cast<VariableNode *>(this));
|
||||
break;
|
||||
case SymbolType::modelLocalVariable:
|
||||
{
|
||||
expr_t def { datatree.getLocalVariable(symb_id) };
|
||||
// Non null derivatives are those of the value of the model local variable
|
||||
def->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
non_null_chain_rule_derivatives.emplace(const_cast<VariableNode *>(this),
|
||||
non_null_chain_rule_derivatives.at(def));
|
||||
}
|
||||
break;
|
||||
case SymbolType::externalFunction:
|
||||
case SymbolType::epilogue:
|
||||
case SymbolType::excludedVariable:
|
||||
cerr << "VariableNode::prepareForChainRuleDerivation: impossible case" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
expr_t
|
||||
VariableNode::computeDerivative(int deriv_id)
|
||||
{
|
||||
|
@ -1422,6 +1490,7 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
|
|||
expr_t
|
||||
VariableNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
switch (get_type())
|
||||
|
@ -1442,12 +1511,12 @@ VariableNode::computeChainRuleDerivative(int deriv_id,
|
|||
// If there is in the equation a recursive variable we could use a chaine rule derivation
|
||||
else if (auto it = recursive_variables.find(my_deriv_id);
|
||||
it != recursive_variables.end())
|
||||
return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
else
|
||||
return datatree.Zero;
|
||||
|
||||
case SymbolType::modelLocalVariable:
|
||||
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
case SymbolType::modFileLocalVariable:
|
||||
cerr << "modFileLocalVariable is not derivable" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
|
@ -2151,6 +2220,28 @@ UnaryOpNode::prepareForDerivation()
|
|||
}
|
||||
}
|
||||
|
||||
void
|
||||
UnaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<UnaryOpNode *>(this)))
|
||||
return;
|
||||
|
||||
/* Non-null derivatives are those of the argument (except for STEADY_STATE in
|
||||
a dynamic context, in which case the potentially non-null derivatives are
|
||||
all the parameters) */
|
||||
set<int> &nnd { non_null_chain_rule_derivatives[const_cast<UnaryOpNode *>(this)] };
|
||||
if ((op_code == UnaryOpcode::steadyState || op_code == UnaryOpcode::steadyStateParamDeriv
|
||||
|| op_code == UnaryOpcode::steadyStateParam2ndDeriv)
|
||||
&& datatree.isDynamic())
|
||||
datatree.addAllParamDerivId(nnd);
|
||||
else
|
||||
{
|
||||
arg->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
nnd = non_null_chain_rule_derivatives.at(arg);
|
||||
}
|
||||
}
|
||||
|
||||
expr_t
|
||||
UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id)
|
||||
{
|
||||
|
@ -3271,9 +3362,10 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
|
|||
expr_t
|
||||
UnaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
return composeDerivatives(darg, deriv_id);
|
||||
}
|
||||
|
||||
|
@ -3986,6 +4078,24 @@ BinaryOpNode::prepareForDerivation()
|
|||
inserter(non_null_derivatives, non_null_derivatives.begin()));
|
||||
}
|
||||
|
||||
void
|
||||
BinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<BinaryOpNode *>(this)))
|
||||
return;
|
||||
|
||||
arg1->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
|
||||
set<int> &nnd { non_null_chain_rule_derivatives[const_cast<BinaryOpNode *>(this)] };
|
||||
set_union(non_null_chain_rule_derivatives.at(arg1).begin(),
|
||||
non_null_chain_rule_derivatives.at(arg1).end(),
|
||||
non_null_chain_rule_derivatives.at(arg2).begin(),
|
||||
non_null_chain_rule_derivatives.at(arg2).end(),
|
||||
inserter(nnd, nnd.begin()));
|
||||
}
|
||||
|
||||
expr_t
|
||||
BinaryOpNode::getNonZeroPartofEquation() const
|
||||
{
|
||||
|
@ -5038,10 +5148,11 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
|
|||
expr_t
|
||||
BinaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
return composeDerivatives(darg1, darg2);
|
||||
}
|
||||
|
||||
|
@ -5888,6 +5999,30 @@ TrinaryOpNode::prepareForDerivation()
|
|||
inserter(non_null_derivatives, non_null_derivatives.begin()));
|
||||
}
|
||||
|
||||
void
|
||||
TrinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<TrinaryOpNode *>(this)))
|
||||
return;
|
||||
|
||||
arg1->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
arg3->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
|
||||
set<int> &nnd { non_null_chain_rule_derivatives[const_cast<TrinaryOpNode *>(this)] };
|
||||
set<int> nnd_tmp;
|
||||
set_union(non_null_chain_rule_derivatives.at(arg1).begin(),
|
||||
non_null_chain_rule_derivatives.at(arg1).end(),
|
||||
non_null_chain_rule_derivatives.at(arg2).begin(),
|
||||
non_null_chain_rule_derivatives.at(arg2).end(),
|
||||
inserter(nnd_tmp, nnd_tmp.begin()));
|
||||
set_union(nnd_tmp.begin(), nnd_tmp.end(),
|
||||
non_null_chain_rule_derivatives.at(arg3).begin(),
|
||||
non_null_chain_rule_derivatives.at(arg3).end(),
|
||||
inserter(nnd, nnd.begin()));
|
||||
}
|
||||
|
||||
expr_t
|
||||
TrinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3)
|
||||
{
|
||||
|
@ -6351,11 +6486,12 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
|
|||
expr_t
|
||||
TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
return composeDerivatives(darg1, darg2, darg3);
|
||||
}
|
||||
|
||||
|
@ -6772,6 +6908,30 @@ AbstractExternalFunctionNode::prepareForDerivation()
|
|||
preparedForDerivation = true;
|
||||
}
|
||||
|
||||
void
|
||||
AbstractExternalFunctionNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<AbstractExternalFunctionNode *>(this)))
|
||||
return;
|
||||
|
||||
for (auto argument : arguments)
|
||||
argument->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
|
||||
non_null_chain_rule_derivatives.emplace(const_cast<AbstractExternalFunctionNode *>(this),
|
||||
non_null_chain_rule_derivatives.at(arguments.at(0)));
|
||||
set<int> &nnd { non_null_chain_rule_derivatives.at(const_cast<AbstractExternalFunctionNode *>(this)) };
|
||||
for (int i {1}; i < static_cast<int>(arguments.size()); i++)
|
||||
{
|
||||
set<int> nnd_tmp;
|
||||
set_union(nnd.begin(), nnd.end(),
|
||||
non_null_chain_rule_derivatives.at(arguments.at(i)).begin(),
|
||||
non_null_chain_rule_derivatives.at(arguments.at(i)).end(),
|
||||
inserter(nnd_tmp, nnd_tmp.begin()));
|
||||
nnd = move(nnd_tmp);
|
||||
}
|
||||
}
|
||||
|
||||
expr_t
|
||||
AbstractExternalFunctionNode::computeDerivative(int deriv_id)
|
||||
{
|
||||
|
@ -6785,12 +6945,13 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
|
|||
expr_t
|
||||
AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
assert(datatree.external_functions_table.getNargs(symb_id) > 0);
|
||||
vector<expr_t> dargs;
|
||||
for (auto argument : arguments)
|
||||
dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, cache));
|
||||
dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache));
|
||||
return composeDerivatives(dargs);
|
||||
}
|
||||
|
||||
|
@ -8364,6 +8525,14 @@ SubModelNode::prepareForDerivation()
|
|||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
void
|
||||
SubModelNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
cerr << "SubModelNode::prepareForChainRuleDerivation not implemented." << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
expr_t
|
||||
SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
|
||||
{
|
||||
|
@ -8374,6 +8543,7 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
|
|||
expr_t
|
||||
SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
|
||||
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;
|
||||
|
|
|
@ -249,7 +249,7 @@ private:
|
|||
|
||||
/* Internal helper for getChainRuleDerivative(), that does the computation
|
||||
but assumes that the caching of this is handled elsewhere */
|
||||
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) = 0;
|
||||
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) = 0;
|
||||
|
||||
protected:
|
||||
//! Reference to the enclosing DataTree
|
||||
|
@ -278,6 +278,14 @@ protected:
|
|||
//! Initializes data member non_null_derivatives
|
||||
virtual void prepareForDerivation() = 0;
|
||||
|
||||
/* Computes the derivatives which are potentially non-null, using symbolic a
|
||||
priori, similarly to prepareForDerivation(), but in a chain rule
|
||||
derivation context. See getChainRuleDerivation() for the meaning of
|
||||
“recursive_variables”. Note that all non-endogenous variables are
|
||||
automatically considered to have a zero derivative (since they’re never
|
||||
used in a chain rule context) */
|
||||
virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
|
||||
|
||||
//! Cost of computing current node
|
||||
/*! Nodes included in temporary_terms are considered having a null cost */
|
||||
virtual int cost(int cost, bool is_matlab) const;
|
||||
|
@ -335,11 +343,17 @@ public:
|
|||
— “recursive_variables” contains the derivation ID for which chain rules
|
||||
must be applied. Keys are derivation IDs, values are equations of the
|
||||
form x=f(y) where x is the key variable and x doesn't appear in y
|
||||
— “non_null_chain_rule_derivatives” is used to store the indices of
|
||||
variables that are potentially non-null (using symbolic a priori),
|
||||
similarly to ExprNode::non_null_derivatives.
|
||||
— “cache” is used to store already-computed derivatives (in a map
|
||||
<expression, deriv_id> → derivative); this cache is specific to a given
|
||||
value of “recursive_variables”, and thus should not be reused accross
|
||||
calls that use different values of “recursive_variables”. */
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache);
|
||||
<expression, deriv_id> → derivative)
|
||||
NB: always returns zero when “deriv_id” corresponds to a non-endogenous
|
||||
variable (since such variables are never used in a chain rule context).
|
||||
NB 2: “non_null_chain_rule_derivatives” and “cache” are specific to a given
|
||||
value of “recursive_variables”, and thus should not be reused accross
|
||||
calls that use different values of “recursive_variables”. */
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache);
|
||||
|
||||
//! Returns precedence of node
|
||||
/*! Equals 100 for constants, variables, unary ops, and temporary terms */
|
||||
|
@ -843,9 +857,10 @@ public:
|
|||
const int id;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
public:
|
||||
|
@ -915,9 +930,10 @@ public:
|
|||
const int lag;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
public:
|
||||
|
@ -985,6 +1001,7 @@ class UnaryOpNode : public ExprNode
|
|||
{
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
public:
|
||||
|
@ -998,7 +1015,7 @@ public:
|
|||
const vector<int> adl_lags;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1090,6 +1107,7 @@ class BinaryOpNode : public ExprNode
|
|||
{
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
public:
|
||||
|
@ -1099,7 +1117,7 @@ public:
|
|||
const string adlparam;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1240,10 +1258,11 @@ public:
|
|||
const TrinaryOpcode op_code;
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1338,7 +1357,7 @@ public:
|
|||
const vector<expr_t> arguments;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
|
||||
// Computes the maximum of f applied to all arguments (result will always be non-negative)
|
||||
int maxHelper(const function<int (expr_t)> &f) const;
|
||||
|
@ -1348,6 +1367,7 @@ protected:
|
|||
{
|
||||
};
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
//! Returns true if the given external function has been written as a temporary term
|
||||
bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const;
|
||||
//! Returns the index in the tef_terms map of this external function
|
||||
|
@ -1622,9 +1642,10 @@ public:
|
|||
expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
private:
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
};
|
||||
|
||||
class VarExpectationNode : public SubModelNode
|
||||
|
|
|
@ -666,6 +666,7 @@ StaticModel::computeChainRuleJacobian()
|
|||
&& simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
|
||||
|
||||
int size = blocks[blk].size;
|
||||
map<expr_t, set<int>> non_null_chain_rule_derivatives;
|
||||
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
|
||||
for (int eq = nb_recursives; eq < size; eq++)
|
||||
{
|
||||
|
@ -673,7 +674,7 @@ StaticModel::computeChainRuleJacobian()
|
|||
for (int var = nb_recursives; var < size; var++)
|
||||
{
|
||||
int var_orig = getBlockVariableID(blk, var);
|
||||
expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, chain_rule_deriv_cache);
|
||||
expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
|
||||
if (d1 != Zero)
|
||||
blocks_derivatives[blk][{ eq, var, 0 }] = d1;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue