Optimization: use std::unordered_map instead of std::map for caching chain rule derivation

Improves performance on very very large models (tens of thousands of equations).
master
Sébastien Villemot 2023-04-05 14:16:22 +02:00
parent b9bfcaad5d
commit e22d9049ee
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
4 changed files with 57 additions and 51 deletions

View File

@ -2320,8 +2320,8 @@ DynamicModel::computeChainRuleJacobian()
}
// Compute the block derivatives
map<expr_t, set<int>> non_null_chain_rule_derivatives;
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
unordered_map<expr_t, map<int, expr_t>> chain_rule_deriv_cache;
for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
{
auto [lag, eq, var] = indices;

View File

@ -56,8 +56,8 @@ ExprNode::getDerivative(int deriv_id)
expr_t
ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
unordered_map<expr_t, map<int, expr_t>> &cache)
{
if (!non_null_chain_rule_derivatives.contains(this))
prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
@ -67,15 +67,16 @@ ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &r
return datatree.Zero;
// If derivative is in the cache, return that value
pair key {this, deriv_id};
if (auto it = cache.find(key);
if (auto it = cache.find(this);
it != cache.end())
return it->second;
if (auto it2 = it->second.find(deriv_id);
it2 != it->second.end())
return it2->second;
auto r = computeChainRuleDerivative(deriv_id, recursive_variables,
non_null_chain_rule_derivatives, cache);
auto [ignore, success] = cache.emplace(key, r);
auto [ignore, success] = cache[this].emplace(deriv_id, r);
assert(success); // The element should not already exist
return r;
}
@ -489,7 +490,7 @@ NumConstNode::prepareForDerivation()
void
NumConstNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
non_null_chain_rule_derivatives.try_emplace(const_cast<NumConstNode *>(this));
}
@ -582,8 +583,8 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
expr_t
NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
[[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
[[maybe_unused]] unordered_map<expr_t, map<int, expr_t>> &cache)
{
return datatree.Zero;
}
@ -911,7 +912,7 @@ VariableNode::prepareForDerivation()
void
VariableNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<VariableNode *>(this)))
return;
@ -1470,8 +1471,8 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
expr_t
VariableNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
unordered_map<expr_t, map<int, expr_t>> &cache)
{
switch (get_type())
{
@ -2202,7 +2203,7 @@ UnaryOpNode::prepareForDerivation()
void
UnaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<UnaryOpNode *>(this)))
return;
@ -3342,8 +3343,8 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
expr_t
UnaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
unordered_map<expr_t, map<int, expr_t>> &cache)
{
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
return composeDerivatives(darg, deriv_id);
@ -4022,7 +4023,7 @@ BinaryOpNode::prepareForDerivation()
void
BinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<BinaryOpNode *>(this)))
return;
@ -5090,8 +5091,8 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
expr_t
BinaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
unordered_map<expr_t, map<int, expr_t>> &cache)
{
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
@ -5890,7 +5891,7 @@ TrinaryOpNode::prepareForDerivation()
void
TrinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<TrinaryOpNode *>(this)))
return;
@ -6375,8 +6376,8 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
expr_t
TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
unordered_map<expr_t, map<int, expr_t>> &cache)
{
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
@ -6723,7 +6724,7 @@ AbstractExternalFunctionNode::prepareForDerivation()
void
AbstractExternalFunctionNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
if (non_null_chain_rule_derivatives.contains(const_cast<AbstractExternalFunctionNode *>(this)))
return;
@ -6758,8 +6759,8 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
expr_t
AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
const map<int, BinaryOpNode *> &recursive_variables,
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
map<pair<expr_t, int>, expr_t> &cache)
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
unordered_map<expr_t, map<int, expr_t>> &cache)
{
assert(datatree.external_functions_table.getNargs(symb_id) > 0);
vector<expr_t> dargs;
@ -8238,7 +8239,7 @@ SubModelNode::prepareForDerivation()
void
SubModelNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
[[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
{
cerr << "SubModelNode::prepareForChainRuleDerivation not implemented." << endl;
exit(EXIT_FAILURE);
@ -8254,8 +8255,8 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
expr_t
SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
[[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
[[maybe_unused]] unordered_map<expr_t, map<int, expr_t>> &cache)
{
cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;
exit(EXIT_FAILURE);

View File

@ -27,6 +27,7 @@
#include <functional>
#include <optional>
#include <utility>
#include <unordered_map>
using namespace std;
@ -250,7 +251,7 @@ private:
/* Internal helper for getChainRuleDerivative(), that does the computation
but assumes that the caching of this is handled elsewhere */
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) = 0;
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) = 0;
protected:
//! Reference to the enclosing DataTree
@ -285,7 +286,7 @@ protected:
recursive_variables. Note that all non-endogenous variables are
automatically considered to have a zero derivative (since they're never
used in a chain rule context) */
virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
//! Cost of computing current node
/*! Nodes included in temporary_terms are considered having a null cost */
@ -353,8 +354,11 @@ public:
variable (since such variables are never used in a chain rule context).
NB 2: non_null_chain_rule_derivatives and cache are specific to a given
value of recursive_variables, and thus should not be reused across
calls that use different values of recursive_variables. */
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache);
calls that use different values of recursive_variables.
NB 3: the use of std::unordered_map instead of std::map for caching
purposes improves performance on very very large models (tens of thousands
of equations) */
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache);
//! Returns precedence of node
/*! Equals 100 for constants, variables, unary ops, and temporary terms */
@ -855,10 +859,10 @@ public:
const int id;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -927,10 +931,10 @@ public:
const int lag;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -996,7 +1000,7 @@ class UnaryOpNode : public ExprNode
{
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
// Returns the node obtained by applying a transformation recursively on the argument (in same datatree)
@ -1018,7 +1022,7 @@ public:
const vector<int> adl_lags;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1108,7 +1112,7 @@ class BinaryOpNode : public ExprNode
{
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -1118,7 +1122,7 @@ public:
const string adlparam;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1264,11 +1268,11 @@ public:
const TrinaryOpcode op_code;
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
int cost(int cost, bool is_matlab) const override;
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@ -1371,7 +1375,7 @@ public:
const vector<expr_t> arguments;
private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
// Computes the maximum of f applied to all arguments (result will always be non-negative)
int maxHelper(const function<int (expr_t)> &f) const;
@ -1391,7 +1395,7 @@ protected:
{
};
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
//! Returns true if the given external function has been written as a temporary term
bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const;
//! Returns the index in the tef_terms map of this external function
@ -1657,10 +1661,10 @@ public:
expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
protected:
void prepareForDerivation() override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
private:
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
};
class VarExpectationNode : public SubModelNode

View File

@ -24,6 +24,7 @@
#include <algorithm>
#include <sstream>
#include <numeric>
#include <unordered_map>
#include "StaticModel.hh"
#include "DynamicModel.hh"
@ -650,8 +651,8 @@ StaticModel::computeChainRuleJacobian()
&& simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
int size = blocks[blk].size;
map<expr_t, set<int>> non_null_chain_rule_derivatives;
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
unordered_map<expr_t, map<int, expr_t>> chain_rule_deriv_cache;
for (int eq = nb_recursives; eq < size; eq++)
{
int eq_orig = getBlockEquationID(blk, eq);
@ -822,8 +823,8 @@ StaticModel::computeRamseyMultipliersDerivatives(int ramsey_orig_endo_nbr, bool
}
// Compute the chain rule derivatives w.r.t. multipliers
map<expr_t, set<int>> non_null_chain_rule_derivatives;
map<pair<expr_t, int>, expr_t> cache;
unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
unordered_map<expr_t, map<int, expr_t>> cache;
for (int eq {0}; eq < ramsey_orig_endo_nbr; eq++)
for (int mult {0}; mult < static_cast<int>(mult_deriv_ids.size()); mult++)
if (expr_t d { equations[eq]->getChainRuleDerivative(mult_deriv_ids[mult], recursive_variables,