Performance improvement of chain rule derivation, using caching
Useful for mfs > 0 on large models.
master
parent
dbc2851606
commit
23b0c12d8e
|
@ -3124,6 +3124,7 @@ DynamicModel::computeChainRuleJacobian()
|
|||
}
|
||||
|
||||
// Compute the block derivatives
|
||||
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
|
||||
for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
|
||||
{
|
||||
auto [lag, eq, var] = indices;
|
||||
|
@ -3140,10 +3141,10 @@ DynamicModel::computeChainRuleJacobian()
|
|||
d = Zero;
|
||||
break;
|
||||
case BlockDerivativeType::chainRule:
|
||||
d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars);
|
||||
d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
|
||||
break;
|
||||
case BlockDerivativeType::normalizedChainRule:
|
||||
d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars);
|
||||
d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
@ -53,6 +53,22 @@ ExprNode::getDerivative(int deriv_id)
|
|||
}
|
||||
}
|
||||
|
||||
// Caching entry point for chain-rule differentiation of this node.
//
// Looks up the pair (this node, deriv_id) in "cache" and returns the stored
// derivative on a hit. On a miss, it delegates the actual work to
// computeChainRuleDerivative(), records the result in "cache" and returns it.
//
// The cache key does not involve "recursive_variables", so a given cache must
// not be shared between calls made with different "recursive_variables".
expr_t
|
||||
ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
// Cache key: this node together with the derivation ID
pair key {this, deriv_id};
|
||||
// Cache hit: return the previously stored derivative
if (auto it = cache.find(key);
|
||||
it != cache.end())
|
||||
return it->second;
|
||||
|
||||
// Cache miss: compute the derivative; the same cache is passed down to the helper
auto r = computeChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
|
||||
// Store the freshly computed derivative for later lookups
auto [ignore, success] = cache.emplace(key, r);
|
||||
assert(success); // The element should not already exist
|
||||
return r;
|
||||
}
|
||||
|
||||
int
|
||||
ExprNode::precedence([[maybe_unused]] ExprNodeOutputType output_type,
|
||||
[[maybe_unused]] const temporary_terms_t &temporary_terms) const
|
||||
|
@ -546,8 +562,9 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
|
|||
}
|
||||
|
||||
expr_t
|
||||
NumConstNode::getChainRuleDerivative([[maybe_unused]] int deriv_id,
|
||||
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables)
|
||||
NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
|
||||
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
return datatree.Zero;
|
||||
}
|
||||
|
@ -1402,7 +1419,9 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
|
|||
}
|
||||
|
||||
expr_t
|
||||
VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
|
||||
VariableNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
switch (get_type())
|
||||
{
|
||||
|
@ -1421,12 +1440,12 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *
|
|||
// If the equation contains a recursive variable, we can use chain rule derivation
|
||||
else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
|
||||
it != recursive_variables.end())
|
||||
return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
else
|
||||
return datatree.Zero;
|
||||
|
||||
case SymbolType::modelLocalVariable:
|
||||
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
case SymbolType::modFileLocalVariable:
|
||||
cerr << "modFileLocalVariable is not derivable" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
|
@ -1439,7 +1458,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *
|
|||
case SymbolType::externalFunction:
|
||||
case SymbolType::epilogue:
|
||||
case SymbolType::excludedVariable:
|
||||
cerr << "VariableNode::getChainRuleDerivative: Impossible case" << endl;
|
||||
cerr << "VariableNode::computeChainRuleDerivative: Impossible case" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
// Suppress GCC warning
|
||||
|
@ -3224,9 +3243,11 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
|
|||
}
|
||||
|
||||
expr_t
|
||||
UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
|
||||
UnaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
return composeDerivatives(darg, deriv_id);
|
||||
}
|
||||
|
||||
|
@ -4978,10 +4999,12 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
|
|||
}
|
||||
|
||||
expr_t
|
||||
BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
|
||||
BinaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
return composeDerivatives(darg1, darg2);
|
||||
}
|
||||
|
||||
|
@ -6289,11 +6312,13 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
|
|||
}
|
||||
|
||||
expr_t
|
||||
TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
|
||||
TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, cache);
|
||||
return composeDerivatives(darg1, darg2, darg3);
|
||||
}
|
||||
|
||||
|
@ -6717,12 +6742,14 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
|
|||
}
|
||||
|
||||
expr_t
|
||||
AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
|
||||
AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
assert(datatree.external_functions_table.getNargs(symb_id) > 0);
|
||||
vector<expr_t> dargs;
|
||||
for (auto argument : arguments)
|
||||
dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables));
|
||||
dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, cache));
|
||||
return composeDerivatives(dargs);
|
||||
}
|
||||
|
||||
|
@ -8334,10 +8361,11 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
|
|||
}
|
||||
|
||||
expr_t
|
||||
SubModelNode::getChainRuleDerivative([[maybe_unused]] int deriv_id,
|
||||
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables)
|
||||
SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
|
||||
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
|
||||
{
|
||||
cerr << "SubModelNode::getChainRuleDerivative not implemented." << endl;
|
||||
cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
|
|
|
@ -247,6 +247,10 @@ private:
|
|||
/*! You should use getDerivative() to get the benefit of symbolic a priori and of caching */
|
||||
virtual expr_t computeDerivative(int deriv_id) = 0;
|
||||
|
||||
/* Internal helper for getChainRuleDerivative(), that does the computation
|
||||
but assumes that the caching of this is handled elsewhere */
|
||||
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) = 0;
|
||||
|
||||
protected:
|
||||
//! Reference to the enclosing DataTree
|
||||
DataTree &datatree;
|
||||
|
@ -327,12 +331,15 @@ public:
|
|||
For an equal node, returns the derivative of lhs minus rhs */
|
||||
expr_t getDerivative(int deriv_id);
|
||||
|
||||
//! Computes derivatives by applying the chain rule for some variables
|
||||
/*!
|
||||
\param deriv_id The derivation ID with respect to which we are derivating
|
||||
\param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
|
||||
*/
|
||||
virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) = 0;
|
||||
/* Computes derivatives by applying the chain rule for some variables.
|
||||
— “recursive_variables” contains the derivation ID for which chain rules
|
||||
must be applied. Keys are derivation IDs, values are equations of the
|
||||
form x=f(y) where x is the key variable and x doesn't appear in y
|
||||
— “cache” is used to store already-computed derivatives (in a map
|
||||
<expression, deriv_id> → derivative); this cache is specific to a given
|
||||
value of “recursive_variables”, and thus should not be reused across
|
||||
calls that use different values of “recursive_variables”. */
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache);
|
||||
|
||||
//! Returns precedence of node
|
||||
/*! Equals 100 for constants, variables, unary ops, and temporary terms */
|
||||
|
@ -836,6 +843,7 @@ public:
|
|||
const int id;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
protected:
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
|
@ -853,7 +861,6 @@ public:
|
|||
expr_t toStatic(DataTree &static_datatree) const override;
|
||||
void computeXrefs(EquationInfo &ei) const override;
|
||||
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
|
||||
int maxEndoLead() const override;
|
||||
int maxExoLead() const override;
|
||||
int maxEndoLag() const override;
|
||||
|
@ -908,6 +915,7 @@ public:
|
|||
const int lag;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
protected:
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
|
@ -926,7 +934,6 @@ public:
|
|||
void computeXrefs(EquationInfo &ei) const override;
|
||||
SymbolType get_type() const;
|
||||
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
|
||||
int maxEndoLead() const override;
|
||||
int maxExoLead() const override;
|
||||
int maxEndoLag() const override;
|
||||
|
@ -990,6 +997,7 @@ public:
|
|||
const vector<int> adl_lags;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1029,7 +1037,6 @@ public:
|
|||
expr_t toStatic(DataTree &static_datatree) const override;
|
||||
void computeXrefs(EquationInfo &ei) const override;
|
||||
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
|
||||
int maxEndoLead() const override;
|
||||
int maxExoLead() const override;
|
||||
int maxEndoLag() const override;
|
||||
|
@ -1091,6 +1098,7 @@ public:
|
|||
const string adlparam;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1137,7 +1145,6 @@ public:
|
|||
//! Try to normalize an equation with respect to a given dynamic variable.
|
||||
/*! Should only be called on Equal nodes. The variable must appear in the equation. */
|
||||
BinaryOpNode *normalizeEquation(int symb_id, int lag) const;
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
|
||||
int maxEndoLead() const override;
|
||||
int maxExoLead() const override;
|
||||
int maxEndoLag() const override;
|
||||
|
@ -1235,6 +1242,7 @@ protected:
|
|||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1276,7 +1284,6 @@ public:
|
|||
expr_t toStatic(DataTree &static_datatree) const override;
|
||||
void computeXrefs(EquationInfo &ei) const override;
|
||||
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
|
||||
int maxEndoLead() const override;
|
||||
int maxExoLead() const override;
|
||||
int maxEndoLag() const override;
|
||||
|
@ -1331,6 +1338,7 @@ public:
|
|||
const vector<expr_t> arguments;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
|
||||
protected:
|
||||
//! Thrown when trying to access an unknown entry in external_function_node_map
|
||||
|
@ -1389,7 +1397,6 @@ public:
|
|||
expr_t toStatic(DataTree &static_datatree) const override = 0;
|
||||
void computeXrefs(EquationInfo &ei) const override = 0;
|
||||
BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
|
||||
int maxEndoLead() const override;
|
||||
int maxExoLead() const override;
|
||||
int maxEndoLag() const override;
|
||||
|
@ -1568,7 +1575,6 @@ public:
|
|||
expr_t toStatic(DataTree &static_datatree) const override;
|
||||
void prepareForDerivation() override;
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
|
||||
int maxEndoLead() const override;
|
||||
int maxExoLead() const override;
|
||||
int maxEndoLag() const override;
|
||||
|
@ -1615,6 +1621,8 @@ public:
|
|||
expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
|
||||
protected:
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
private:
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
};
|
||||
|
||||
class VarExpectationNode : public SubModelNode
|
||||
|
|
|
@ -952,14 +952,14 @@ StaticModel::computeChainRuleJacobian()
|
|||
&& simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
|
||||
|
||||
int size = blocks[blk].size;
|
||||
|
||||
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
|
||||
for (int eq = nb_recursives; eq < size; eq++)
|
||||
{
|
||||
int eq_orig = getBlockEquationID(blk, eq);
|
||||
for (int var = nb_recursives; var < size; var++)
|
||||
{
|
||||
int var_orig = getBlockVariableID(blk, var);
|
||||
expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars);
|
||||
expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, chain_rule_deriv_cache);
|
||||
if (d1 != Zero)
|
||||
blocks_derivatives[blk][{ eq, var, 0 }] = d1;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue