Performance improvement of chain rule derivation

Commit 23b0c12d8e introduced caching in chain
rule derivation (used by block decomposition), which increased speed for mfs >
0, but actually decreased it for mfs=0.

This patch introduces the pre-computation of derivatives which are known to be
zero using symbolic a priori (similarly to what is done in the non-chain rule
context). The algorithms are now identical between the two contexts (both
symbolic a priori + caching), the difference being that in the chain rule
context, the symbolic a priori and the cache are not stored within the ExprNode
class, since they depend on the list of recursive variables.

This patch brings a significant performant improvement for all values of the
“mfs” option (the improvement is greater for small values of “mfs”).
master
Sébastien Villemot 2023-03-02 17:49:16 +01:00
parent f8edce01ec
commit 7acf278370
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
4 changed files with 218 additions and 25 deletions

View File

@ -2320,6 +2320,7 @@ DynamicModel::computeChainRuleJacobian()
}
// Compute the block derivatives
map<expr_t, set<int>> non_null_chain_rule_derivatives;
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
{
@ -2337,10 +2338,10 @@ DynamicModel::computeChainRuleJacobian()
d = Zero;
break;
case BlockDerivativeType::chainRule:
d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
break;
case BlockDerivativeType::normalizedChainRule:
d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
break;
}

View File

@ -56,14 +56,24 @@ ExprNode::getDerivative(int deriv_id)
expr_t
ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
{
if (!non_null_chain_rule_derivatives.contains(this))
prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
// Return zero if derivative is necessarily null (using symbolic a priori)
if (!non_null_chain_rule_derivatives.at(this).contains(deriv_id))
return datatree.Zero;
// If derivative is in the cache, return that value
pair key {this, deriv_id};
if (auto it = cache.find(key);
it != cache.end())
return it->second;
auto r = computeChainRuleDerivative(deriv_id, recursive_variables, cache);
auto r = computeChainRuleDerivative(deriv_id, recursive_variables,
non_null_chain_rule_derivatives, cache);
auto [ignore, success] = cache.emplace(key, r);
assert(success); // The element should not already exist
@ -477,6 +487,13 @@ NumConstNode::prepareForDerivation()
// All derivatives are null, so non_null_derivatives is left empty
}
void
NumConstNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
non_null_chain_rule_derivatives.try_emplace(const_cast<NumConstNode *>(this));
}
expr_t
NumConstNode::computeDerivative([[maybe_unused]] int deriv_id)
{
@ -565,6 +582,7 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
expr_t
NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
{
return datatree.Zero;
@ -897,6 +915,56 @@ VariableNode::prepareForDerivation()
}
}
void
VariableNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<VariableNode *>(this)))
return;
switch (get_type())
{
case SymbolType::endogenous:
{
set<int> &nnd { non_null_chain_rule_derivatives[const_cast<VariableNode *>(this)] };
int my_deriv_id {datatree.getDerivID(symb_id, lag)};
if (auto it = recursive_variables.find(my_deriv_id);
it != recursive_variables.end())
{
it->second->arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
nnd = non_null_chain_rule_derivatives.at(it->second->arg2);
}
nnd.insert(my_deriv_id);
}
break;
case SymbolType::exogenous:
case SymbolType::exogenousDet:
case SymbolType::parameter:
case SymbolType::trend:
case SymbolType::logTrend:
case SymbolType::modFileLocalVariable:
case SymbolType::statementDeclaredVariable:
case SymbolType::unusedEndogenous:
// Those variables are never derived using chain rule
non_null_chain_rule_derivatives.try_emplace(const_cast<VariableNode *>(this));
break;
case SymbolType::modelLocalVariable:
{
expr_t def { datatree.getLocalVariable(symb_id) };
// Non null derivatives are those of the value of the model local variable
def->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
non_null_chain_rule_derivatives.emplace(const_cast<VariableNode *>(this),
non_null_chain_rule_derivatives.at(def));
}
break;
case SymbolType::externalFunction:
case SymbolType::epilogue:
case SymbolType::excludedVariable:
cerr << "VariableNode::prepareForChainRuleDerivation: impossible case" << endl;
exit(EXIT_FAILURE);
}
}
expr_t
VariableNode::computeDerivative(int deriv_id)
{
@ -1422,6 +1490,7 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
expr_t
VariableNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
{
switch (get_type())
@ -1442,12 +1511,12 @@ VariableNode::computeChainRuleDerivative(int deriv_id,
// If there is in the equation a recursive variable we could use a chaine rule derivation
else if (auto it = recursive_variables.find(my_deriv_id);
it != recursive_variables.end())
return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
else
return datatree.Zero;
case SymbolType::modelLocalVariable:
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, cache);
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
case SymbolType::modFileLocalVariable:
cerr << "modFileLocalVariable is not derivable" << endl;
exit(EXIT_FAILURE);
@ -2151,6 +2220,28 @@ UnaryOpNode::prepareForDerivation()
}
}
void
UnaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<UnaryOpNode *>(this)))
return;
/* Non-null derivatives are those of the argument (except for STEADY_STATE in
a dynamic context, in which case the potentially non-null derivatives are
all the parameters) */
set<int> &nnd { non_null_chain_rule_derivatives[const_cast<UnaryOpNode *>(this)] };
if ((op_code == UnaryOpcode::steadyState || op_code == UnaryOpcode::steadyStateParamDeriv
|| op_code == UnaryOpcode::steadyStateParam2ndDeriv)
&& datatree.isDynamic())
datatree.addAllParamDerivId(nnd);
else
{
arg->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
nnd = non_null_chain_rule_derivatives.at(arg);
}
}
expr_t
UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id)
{
@ -3271,9 +3362,10 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
expr_t
UnaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
{
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
return composeDerivatives(darg, deriv_id);
}
@ -3986,6 +4078,24 @@ BinaryOpNode::prepareForDerivation()
inserter(non_null_derivatives, non_null_derivatives.begin()));
}
void
BinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<BinaryOpNode *>(this)))
return;
arg1->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
set<int> &nnd { non_null_chain_rule_derivatives[const_cast<BinaryOpNode *>(this)] };
set_union(non_null_chain_rule_derivatives.at(arg1).begin(),
non_null_chain_rule_derivatives.at(arg1).end(),
non_null_chain_rule_derivatives.at(arg2).begin(),
non_null_chain_rule_derivatives.at(arg2).end(),
inserter(nnd, nnd.begin()));
}
expr_t
BinaryOpNode::getNonZeroPartofEquation() const
{
@ -5038,10 +5148,11 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
expr_t
BinaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
{
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
return composeDerivatives(darg1, darg2);
}
@ -5888,6 +5999,30 @@ TrinaryOpNode::prepareForDerivation()
inserter(non_null_derivatives, non_null_derivatives.begin()));
}
void
TrinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<TrinaryOpNode *>(this)))
return;
arg1->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
arg3->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
set<int> &nnd { non_null_chain_rule_derivatives[const_cast<TrinaryOpNode *>(this)] };
set<int> nnd_tmp;
set_union(non_null_chain_rule_derivatives.at(arg1).begin(),
non_null_chain_rule_derivatives.at(arg1).end(),
non_null_chain_rule_derivatives.at(arg2).begin(),
non_null_chain_rule_derivatives.at(arg2).end(),
inserter(nnd_tmp, nnd_tmp.begin()));
set_union(nnd_tmp.begin(), nnd_tmp.end(),
non_null_chain_rule_derivatives.at(arg3).begin(),
non_null_chain_rule_derivatives.at(arg3).end(),
inserter(nnd, nnd.begin()));
}
expr_t
TrinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3)
{
@ -6351,11 +6486,12 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
expr_t
TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
{
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
return composeDerivatives(darg1, darg2, darg3);
}
@ -6772,6 +6908,30 @@ AbstractExternalFunctionNode::prepareForDerivation()
preparedForDerivation = true;
}
void
AbstractExternalFunctionNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<AbstractExternalFunctionNode *>(this)))
return;
for (auto argument : arguments)
argument->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
non_null_chain_rule_derivatives.emplace(const_cast<AbstractExternalFunctionNode *>(this),
non_null_chain_rule_derivatives.at(arguments.at(0)));
set<int> &nnd { non_null_chain_rule_derivatives.at(const_cast<AbstractExternalFunctionNode *>(this)) };
for (int i {1}; i < static_cast<int>(arguments.size()); i++)
{
set<int> nnd_tmp;
set_union(nnd.begin(), nnd.end(),
non_null_chain_rule_derivatives.at(arguments.at(i)).begin(),
non_null_chain_rule_derivatives.at(arguments.at(i)).end(),
inserter(nnd_tmp, nnd_tmp.begin()));
nnd = move(nnd_tmp);
}
}
expr_t
AbstractExternalFunctionNode::computeDerivative(int deriv_id)
{
@ -6785,12 +6945,13 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
expr_t
AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
{
assert(datatree.external_functions_table.getNargs(symb_id) > 0);
vector<expr_t> dargs;
for (auto argument : arguments)
dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, cache));
dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache));
return composeDerivatives(dargs);
}
@ -8364,6 +8525,14 @@ SubModelNode::prepareForDerivation()
exit(EXIT_FAILURE);
}
void
SubModelNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
cerr << "SubModelNode::prepareForChainRuleDerivation not implemented." << endl;
exit(EXIT_FAILURE);
}
expr_t
SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
{
@ -8374,6 +8543,7 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
expr_t
SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
{
cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;

View File

@ -249,7 +249,7 @@ private:
/* Internal helper for getChainRuleDerivative(), that does the computation
but assumes that the caching of this is handled elsewhere */
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) = 0;
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) = 0;
protected:
//! Reference to the enclosing DataTree
@ -278,6 +278,14 @@ protected:
//! Initializes data member non_null_derivatives
virtual void prepareForDerivation() = 0;
/* Computes the derivatives which are potentially non-null, using symbolic a
priori, similarly to prepareForDerivation(), but in a chain rule
derivation context. See getChainRuleDerivation() for the meaning of
recursive_variables. Note that all non-endogenous variables are
automatically considered to have a zero derivative (since theyre never
used in a chain rule context) */
virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
//! Cost of computing current node
/*! Nodes included in temporary_terms are considered having a null cost */
virtual int cost(int cost, bool is_matlab) const;
@ -335,11 +343,17 @@ public:
recursive_variables contains the derivation ID for which chain rules
must be applied. Keys are derivation IDs, values are equations of the
form x=f(y) where x is the key variable and x doesn't appear in y
non_null_chain_rule_derivatives is used to store the indices of
variables that are potentially non-null (using symbolic a priori),
similarly to ExprNode::non_null_derivatives.
cache is used to store already-computed derivatives (in a map
<expression, deriv_id> derivative); this cache is specific to a given
value of recursive_variables, and thus should not be reused accross
calls that use different values of recursive_variables. */
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache);
<expression, deriv_id> derivative)
NB: always returns zero when deriv_id corresponds to a non-endogenous
variable (since such variables are never used in a chain rule context).
NB 2: non_null_chain_rule_derivatives and cache are specific to a given
value of recursive_variables, and thus should not be reused accross
calls that use different values of recursive_variables. */
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache);
//! Returns precedence of node
/*! Equals 100 for constants, variables, unary ops, and temporary terms */
@ -843,9 +857,10 @@ public:
const int id;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -915,9 +930,10 @@ public:
const int lag;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -985,6 +1001,7 @@ class UnaryOpNode : public ExprNode
{
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -998,7 +1015,7 @@ public:
const vector<int> adl_lags;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1090,6 +1107,7 @@ class BinaryOpNode : public ExprNode
{
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -1099,7 +1117,7 @@ public:
const string adlparam;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1240,10 +1258,11 @@ public:
const TrinaryOpcode op_code;
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1338,7 +1357,7 @@ public:
const vector<expr_t> arguments;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
// Computes the maximum of f applied to all arguments (result will always be non-negative)
int maxHelper(const function<int (expr_t)> &f) const;
@ -1348,6 +1367,7 @@ protected:
{
};
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
//! Returns true if the given external function has been written as a temporary term
bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const;
//! Returns the index in the tef_terms map of this external function
@ -1622,9 +1642,10 @@ public:
expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
private:
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
};
class VarExpectationNode : public SubModelNode

View File

@ -666,6 +666,7 @@ StaticModel::computeChainRuleJacobian()
&& simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
int size = blocks[blk].size;
map<expr_t, set<int>> non_null_chain_rule_derivatives;
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
for (int eq = nb_recursives; eq < size; eq++)
{
@ -673,7 +674,7 @@ StaticModel::computeChainRuleJacobian()
for (int var = nb_recursives; var < size; var++)
{
int var_orig = getBlockVariableID(blk, var);
expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, chain_rule_deriv_cache);
expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
if (d1 != Zero)
blocks_derivatives[blk][{ eq, var, 0 }] = d1;
}