Performance improvement of chain rule derivation, using caching

Useful for mfs > 0 on large models.
master
Sébastien Villemot 2022-11-08 12:28:48 +01:00
parent dbc2851606
commit 23b0c12d8e
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
4 changed files with 74 additions and 37 deletions

View File

@ -3124,6 +3124,7 @@ DynamicModel::computeChainRuleJacobian()
}
// Compute the block derivatives
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
{
auto [lag, eq, var] = indices;
@ -3140,10 +3141,10 @@ DynamicModel::computeChainRuleJacobian()
d = Zero;
break;
case BlockDerivativeType::chainRule:
d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars);
d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
break;
case BlockDerivativeType::normalizedChainRule:
d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars);
d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
break;
}

View File

@ -53,6 +53,22 @@ ExprNode::getDerivative(int deriv_id)
}
}
expr_t
ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
map<pair<expr_t, int>, expr_t> &cache)
{
pair key {this, deriv_id};
if (auto it = cache.find(key);
it != cache.end())
return it->second;
auto r = computeChainRuleDerivative(deriv_id, recursive_variables, cache);
auto [ignore, success] = cache.emplace(key, r);
assert(success); // The element should not already exist
return r;
}
int
ExprNode::precedence([[maybe_unused]] ExprNodeOutputType output_type,
[[maybe_unused]] const temporary_terms_t &temporary_terms) const
@ -546,8 +562,9 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
}
expr_t
NumConstNode::getChainRuleDerivative([[maybe_unused]] int deriv_id,
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables)
NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
{
return datatree.Zero;
}
@ -1402,7 +1419,9 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
}
expr_t
VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
VariableNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<pair<expr_t, int>, expr_t> &cache)
{
switch (get_type())
{
@ -1421,12 +1440,12 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *
// If there is in the equation a recursive variable we could use a chaine rule derivation
else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
it != recursive_variables.end())
return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables);
return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
else
return datatree.Zero;
case SymbolType::modelLocalVariable:
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, cache);
case SymbolType::modFileLocalVariable:
cerr << "modFileLocalVariable is not derivable" << endl;
exit(EXIT_FAILURE);
@ -1439,7 +1458,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *
case SymbolType::externalFunction:
case SymbolType::epilogue:
case SymbolType::excludedVariable:
cerr << "VariableNode::getChainRuleDerivative: Impossible case" << endl;
cerr << "VariableNode::computeChainRuleDerivative: Impossible case" << endl;
exit(EXIT_FAILURE);
}
// Suppress GCC warning
@ -3224,9 +3243,11 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
}
expr_t
UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
UnaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<pair<expr_t, int>, expr_t> &cache)
{
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, cache);
return composeDerivatives(darg, deriv_id);
}
@ -4978,10 +4999,12 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
}
expr_t
BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
BinaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<pair<expr_t, int>, expr_t> &cache)
{
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
return composeDerivatives(darg1, darg2);
}
@ -6289,11 +6312,13 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
}
expr_t
TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<pair<expr_t, int>, expr_t> &cache)
{
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables);
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, cache);
return composeDerivatives(darg1, darg2, darg3);
}
@ -6717,12 +6742,14 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
}
expr_t
AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<pair<expr_t, int>, expr_t> &cache)
{
assert(datatree.external_functions_table.getNargs(symb_id) > 0);
vector<expr_t> dargs;
for (auto argument : arguments)
dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables));
dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, cache));
return composeDerivatives(dargs);
}
@ -8334,10 +8361,11 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
}
expr_t
SubModelNode::getChainRuleDerivative([[maybe_unused]] int deriv_id,
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables)
SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
{
cerr << "SubModelNode::getChainRuleDerivative not implemented." << endl;
cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;
exit(EXIT_FAILURE);
}

View File

@ -247,6 +247,10 @@ private:
/*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
virtual expr_t computeDerivative(int deriv_id) = 0;
/* Internal helper for getChainRuleDerivative(), that does the computation
but assumes that the caching of this is handled elsewhere */
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) = 0;
protected:
//! Reference to the enclosing DataTree
DataTree &datatree;
@ -327,12 +331,15 @@ public:
For an equal node, returns the derivative of lhs minus rhs */
expr_t getDerivative(int deriv_id);
//! Computes derivatives by applying the chain rule for some variables
/*!
\param deriv_id The derivation ID with respect to which we are derivating
\param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
*/
virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) = 0;
/* Computes derivatives by applying the chain rule for some variables.
recursive_variables contains the derivation ID for which chain rules
must be applied. Keys are derivation IDs, values are equations of the
form x=f(y) where x is the key variable and x doesn't appear in y
cache is used to store already-computed derivatives (in a map
<expression, deriv_id> derivative); this cache is specific to a given
value of recursive_variables, and thus should not be reused accross
calls that use different values of recursive_variables. */
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache);
//! Returns precedence of node
/*! Equals 100 for constants, variables, unary ops, and temporary terms */
@ -836,6 +843,7 @@ public:
const int id;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
protected:
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
@ -853,7 +861,6 @@ public:
expr_t toStatic(DataTree &static_datatree) const override;
void computeXrefs(EquationInfo &ei) const override;
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
int maxEndoLead() const override;
int maxExoLead() const override;
int maxEndoLag() const override;
@ -908,6 +915,7 @@ public:
const int lag;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
protected:
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
@ -926,7 +934,6 @@ public:
void computeXrefs(EquationInfo &ei) const override;
SymbolType get_type() const;
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
int maxEndoLead() const override;
int maxExoLead() const override;
int maxEndoLag() const override;
@ -990,6 +997,7 @@ public:
const vector<int> adl_lags;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1029,7 +1037,6 @@ public:
expr_t toStatic(DataTree &static_datatree) const override;
void computeXrefs(EquationInfo &ei) const override;
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
int maxEndoLead() const override;
int maxExoLead() const override;
int maxEndoLag() const override;
@ -1091,6 +1098,7 @@ public:
const string adlparam;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1137,7 +1145,6 @@ public:
//! Try to normalize an equation with respect to a given dynamic variable.
/*! Should only be called on Equal nodes. The variable must appear in the equation. */
BinaryOpNode *normalizeEquation(int symb_id, int lag) const;
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
int maxEndoLead() const override;
int maxExoLead() const override;
int maxEndoLag() const override;
@ -1235,6 +1242,7 @@ protected:
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1276,7 +1284,6 @@ public:
expr_t toStatic(DataTree &static_datatree) const override;
void computeXrefs(EquationInfo &ei) const override;
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
int maxEndoLead() const override;
int maxExoLead() const override;
int maxEndoLag() const override;
@ -1331,6 +1338,7 @@ public:
const vector<expr_t> arguments;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
protected:
//! Thrown when trying to access an unknown entry in external_function_node_map
@ -1389,7 +1397,6 @@ public:
expr_t toStatic(DataTree &static_datatree) const override = 0;
void computeXrefs(EquationInfo &ei) const override = 0;
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
int maxEndoLead() const override;
int maxExoLead() const override;
int maxEndoLag() const override;
@ -1568,7 +1575,6 @@ public:
expr_t toStatic(DataTree &static_datatree) const override;
void prepareForDerivation() override;
expr_t computeDerivative(int deriv_id) override;
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
int maxEndoLead() const override;
int maxExoLead() const override;
int maxEndoLag() const override;
@ -1615,6 +1621,8 @@ public:
expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
protected:
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
private:
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
};
class VarExpectationNode : public SubModelNode

View File

@ -952,14 +952,14 @@ StaticModel::computeChainRuleJacobian()
&& simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
int size = blocks[blk].size;
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
for (int eq = nb_recursives; eq < size; eq++)
{
int eq_orig = getBlockEquationID(blk, eq);
for (int var = nb_recursives; var < size; var++)
{
int var_orig = getBlockVariableID(blk, var);
expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars);
expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, chain_rule_deriv_cache);
if (d1 != Zero)
blocks_derivatives[blk][{ eq, var, 0 }] = d1;
}