Optimization: use std::unordered_map instead of std::map for caching chain rule derivation
Improves performance on very very large models (tens of thousands of equations).master
parent
b9bfcaad5d
commit
e22d9049ee
|
@ -2320,8 +2320,8 @@ DynamicModel::computeChainRuleJacobian()
|
|||
}
|
||||
|
||||
// Compute the block derivatives
|
||||
map<expr_t, set<int>> non_null_chain_rule_derivatives;
|
||||
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
|
||||
unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
|
||||
unordered_map<expr_t, map<int, expr_t>> chain_rule_deriv_cache;
|
||||
for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
|
||||
{
|
||||
auto [lag, eq, var] = indices;
|
||||
|
|
|
@ -56,8 +56,8 @@ ExprNode::getDerivative(int deriv_id)
|
|||
|
||||
expr_t
|
||||
ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
unordered_map<expr_t, map<int, expr_t>> &cache)
|
||||
{
|
||||
if (!non_null_chain_rule_derivatives.contains(this))
|
||||
prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
|
||||
|
@ -67,15 +67,16 @@ ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &r
|
|||
return datatree.Zero;
|
||||
|
||||
// If derivative is in the cache, return that value
|
||||
pair key {this, deriv_id};
|
||||
if (auto it = cache.find(key);
|
||||
if (auto it = cache.find(this);
|
||||
it != cache.end())
|
||||
return it->second;
|
||||
if (auto it2 = it->second.find(deriv_id);
|
||||
it2 != it->second.end())
|
||||
return it2->second;
|
||||
|
||||
auto r = computeChainRuleDerivative(deriv_id, recursive_variables,
|
||||
non_null_chain_rule_derivatives, cache);
|
||||
|
||||
auto [ignore, success] = cache.emplace(key, r);
|
||||
auto [ignore, success] = cache[this].emplace(deriv_id, r);
|
||||
assert(success); // The element should not already exist
|
||||
return r;
|
||||
}
|
||||
|
@ -489,7 +490,7 @@ NumConstNode::prepareForDerivation()
|
|||
|
||||
void
|
||||
NumConstNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
non_null_chain_rule_derivatives.try_emplace(const_cast<NumConstNode *>(this));
|
||||
}
|
||||
|
@ -582,8 +583,8 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
|
|||
expr_t
|
||||
NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
|
||||
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
|
||||
[[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
[[maybe_unused]] unordered_map<expr_t, map<int, expr_t>> &cache)
|
||||
{
|
||||
return datatree.Zero;
|
||||
}
|
||||
|
@ -911,7 +912,7 @@ VariableNode::prepareForDerivation()
|
|||
|
||||
void
|
||||
VariableNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<VariableNode *>(this)))
|
||||
return;
|
||||
|
@ -1470,8 +1471,8 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
|
|||
expr_t
|
||||
VariableNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
unordered_map<expr_t, map<int, expr_t>> &cache)
|
||||
{
|
||||
switch (get_type())
|
||||
{
|
||||
|
@ -2202,7 +2203,7 @@ UnaryOpNode::prepareForDerivation()
|
|||
|
||||
void
|
||||
UnaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<UnaryOpNode *>(this)))
|
||||
return;
|
||||
|
@ -3342,8 +3343,8 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
|
|||
expr_t
|
||||
UnaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
unordered_map<expr_t, map<int, expr_t>> &cache)
|
||||
{
|
||||
expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
return composeDerivatives(darg, deriv_id);
|
||||
|
@ -4022,7 +4023,7 @@ BinaryOpNode::prepareForDerivation()
|
|||
|
||||
void
|
||||
BinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<BinaryOpNode *>(this)))
|
||||
return;
|
||||
|
@ -5090,8 +5091,8 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
|
|||
expr_t
|
||||
BinaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
unordered_map<expr_t, map<int, expr_t>> &cache)
|
||||
{
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
|
@ -5890,7 +5891,7 @@ TrinaryOpNode::prepareForDerivation()
|
|||
|
||||
void
|
||||
TrinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<TrinaryOpNode *>(this)))
|
||||
return;
|
||||
|
@ -6375,8 +6376,8 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
|
|||
expr_t
|
||||
TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
unordered_map<expr_t, map<int, expr_t>> &cache)
|
||||
{
|
||||
expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
|
||||
|
@ -6723,7 +6724,7 @@ AbstractExternalFunctionNode::prepareForDerivation()
|
|||
|
||||
void
|
||||
AbstractExternalFunctionNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
if (non_null_chain_rule_derivatives.contains(const_cast<AbstractExternalFunctionNode *>(this)))
|
||||
return;
|
||||
|
@ -6758,8 +6759,8 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
|
|||
expr_t
|
||||
AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
|
||||
const map<int, BinaryOpNode *> &recursive_variables,
|
||||
map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
map<pair<expr_t, int>, expr_t> &cache)
|
||||
unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
unordered_map<expr_t, map<int, expr_t>> &cache)
|
||||
{
|
||||
assert(datatree.external_functions_table.getNargs(symb_id) > 0);
|
||||
vector<expr_t> dargs;
|
||||
|
@ -8238,7 +8239,7 @@ SubModelNode::prepareForDerivation()
|
|||
|
||||
void
|
||||
SubModelNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
[[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
|
||||
{
|
||||
cerr << "SubModelNode::prepareForChainRuleDerivation not implemented." << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
|
@ -8254,8 +8255,8 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
|
|||
expr_t
|
||||
SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
|
||||
[[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
|
||||
[[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
[[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
|
||||
[[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
|
||||
[[maybe_unused]] unordered_map<expr_t, map<int, expr_t>> &cache)
|
||||
{
|
||||
cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include <functional>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -250,7 +251,7 @@ private:
|
|||
|
||||
/* Internal helper for getChainRuleDerivative(), that does the computation
|
||||
but assumes that the caching of this is handled elsewhere */
|
||||
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) = 0;
|
||||
virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) = 0;
|
||||
|
||||
protected:
|
||||
//! Reference to the enclosing DataTree
|
||||
|
@ -285,7 +286,7 @@ protected:
|
|||
“recursive_variables”. Note that all non-endogenous variables are
|
||||
automatically considered to have a zero derivative (since they’re never
|
||||
used in a chain rule context) */
|
||||
virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
|
||||
virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
|
||||
|
||||
//! Cost of computing current node
|
||||
/*! Nodes included in temporary_terms are considered having a null cost */
|
||||
|
@ -353,8 +354,11 @@ public:
|
|||
variable (since such variables are never used in a chain rule context).
|
||||
NB 2: “non_null_chain_rule_derivatives” and “cache” are specific to a given
|
||||
value of “recursive_variables”, and thus should not be reused accross
|
||||
calls that use different values of “recursive_variables”. */
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache);
|
||||
calls that use different values of “recursive_variables”.
|
||||
NB 3: the use of std::unordered_map instead of std::map for caching
|
||||
purposes improves performance on very very large models (tens of thousands
|
||||
of equations) */
|
||||
expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache);
|
||||
|
||||
//! Returns precedence of node
|
||||
/*! Equals 100 for constants, variables, unary ops, and temporary terms */
|
||||
|
@ -855,10 +859,10 @@ public:
|
|||
const int id;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
public:
|
||||
|
@ -927,10 +931,10 @@ public:
|
|||
const int lag;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
public:
|
||||
|
@ -996,7 +1000,7 @@ class UnaryOpNode : public ExprNode
|
|||
{
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
// Returns the node obtained by applying a transformation recursively on the argument (in same datatree)
|
||||
|
@ -1018,7 +1022,7 @@ public:
|
|||
const vector<int> adl_lags;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1108,7 +1112,7 @@ class BinaryOpNode : public ExprNode
|
|||
{
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
public:
|
||||
|
@ -1118,7 +1122,7 @@ public:
|
|||
const string adlparam;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1264,11 +1268,11 @@ public:
|
|||
const TrinaryOpcode op_code;
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
|
||||
int cost(int cost, bool is_matlab) const override;
|
||||
int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
|
||||
int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
|
||||
|
@ -1371,7 +1375,7 @@ public:
|
|||
const vector<expr_t> arguments;
|
||||
private:
|
||||
expr_t computeDerivative(int deriv_id) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
|
||||
virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
|
||||
// Computes the maximum of f applied to all arguments (result will always be non-negative)
|
||||
int maxHelper(const function<int (expr_t)> &f) const;
|
||||
|
@ -1391,7 +1395,7 @@ protected:
|
|||
{
|
||||
};
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
//! Returns true if the given external function has been written as a temporary term
|
||||
bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const;
|
||||
//! Returns the index in the tef_terms map of this external function
|
||||
|
@ -1657,10 +1661,10 @@ public:
|
|||
expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
|
||||
protected:
|
||||
void prepareForDerivation() override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
|
||||
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
|
||||
private:
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
|
||||
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
|
||||
};
|
||||
|
||||
class VarExpectationNode : public SubModelNode
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "StaticModel.hh"
|
||||
#include "DynamicModel.hh"
|
||||
|
@ -650,8 +651,8 @@ StaticModel::computeChainRuleJacobian()
|
|||
&& simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
|
||||
|
||||
int size = blocks[blk].size;
|
||||
map<expr_t, set<int>> non_null_chain_rule_derivatives;
|
||||
map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
|
||||
unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
|
||||
unordered_map<expr_t, map<int, expr_t>> chain_rule_deriv_cache;
|
||||
for (int eq = nb_recursives; eq < size; eq++)
|
||||
{
|
||||
int eq_orig = getBlockEquationID(blk, eq);
|
||||
|
@ -822,8 +823,8 @@ StaticModel::computeRamseyMultipliersDerivatives(int ramsey_orig_endo_nbr, bool
|
|||
}
|
||||
|
||||
// Compute the chain rule derivatives w.r.t. multipliers
|
||||
map<expr_t, set<int>> non_null_chain_rule_derivatives;
|
||||
map<pair<expr_t, int>, expr_t> cache;
|
||||
unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
|
||||
unordered_map<expr_t, map<int, expr_t>> cache;
|
||||
for (int eq {0}; eq < ramsey_orig_endo_nbr; eq++)
|
||||
for (int mult {0}; mult < static_cast<int>(mult_deriv_ids.size()); mult++)
|
||||
if (expr_t d { equations[eq]->getChainRuleDerivative(mult_deriv_ids[mult], recursive_variables,
|
||||
|
|
Loading…
Reference in New Issue