Make ExprNode::prepareForDerivation() a protected member (was public)

master
Sébastien Villemot 2023-03-02 16:08:32 +01:00
parent fe83933b1d
commit bf8ca27a47
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
1 changed files with 10 additions and 10 deletions

View File

@ -275,6 +275,9 @@ protected:
return is_matlab ? min_cost_matlab : min_cost_c;
};
//! Initializes data member non_null_derivatives
virtual void prepareForDerivation() = 0;
//! Cost of computing current node
/*! Nodes included in temporary_terms are considered having a null cost */
virtual int cost(int cost, bool is_matlab) const;
@ -323,9 +326,6 @@ public:
ExprNode(const ExprNode &) = delete;
ExprNode &operator=(const ExprNode &) = delete;
//! Initializes data member non_null_derivatives
virtual void prepareForDerivation() = 0;
//! Returns derivative w.r. to derivation ID
/*! Uses a symbolic a priori to pre-detect null derivatives, and caches the result for other derivatives (to avoid computing it several times)
For an equal node, returns the derivative of lhs minus rhs */
@ -845,11 +845,11 @@ private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
protected:
void prepareForDerivation() override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg);
void prepareForDerivation() override;
void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
void writeJsonAST(ostream &output) const override;
void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override;
@ -917,11 +917,11 @@ private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
protected:
void prepareForDerivation() override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg);
void prepareForDerivation() override;
void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
void writeJsonAST(ostream &output) const override;
void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override;
@ -984,6 +984,7 @@ public:
class UnaryOpNode : public ExprNode
{
protected:
void prepareForDerivation() override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -1005,7 +1006,6 @@ private:
expr_t composeDerivatives(expr_t darg, int deriv_id);
public:
UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg);
void prepareForDerivation() override;
void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<pair<int, int>, temporary_terms_t> &temp_terms_map,
map<expr_t, pair<int, pair<int, int>>> &reference_count,
@ -1089,6 +1089,7 @@ public:
class BinaryOpNode : public ExprNode
{
protected:
void prepareForDerivation() override;
void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
public:
@ -1107,7 +1108,6 @@ private:
public:
BinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder);
void prepareForDerivation() override;
int precedenceJson(const temporary_terms_t &temporary_terms) const override;
int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override;
void computeTemporaryTerms(const pair<int, int> &derivOrder,
@ -1239,6 +1239,7 @@ public:
const expr_t arg1, arg2, arg3;
const TrinaryOpcode op_code;
protected:
void prepareForDerivation() override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
private:
expr_t computeDerivative(int deriv_id) override;
@ -1251,7 +1252,6 @@ private:
public:
TrinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
TrinaryOpcode op_code_arg, const expr_t arg2_arg, const expr_t arg3_arg);
void prepareForDerivation() override;
int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override;
void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<pair<int, int>, temporary_terms_t> &temp_terms_map,
@ -1347,6 +1347,7 @@ protected:
class UnknownFunctionNameAndArgs
{
};
void prepareForDerivation() override;
//! Returns true if the given external function has been written as a temporary term
bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const;
//! Returns the index in the tef_terms map of this external function
@ -1368,7 +1369,6 @@ protected:
public:
AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg,
vector<expr_t> arguments_arg);
void prepareForDerivation() override;
void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<pair<int, int>, temporary_terms_t> &temp_terms_map,
map<expr_t, pair<int, pair<int, int>>> &reference_count,
@ -1575,7 +1575,6 @@ public:
void computeBlockTemporaryTerms(int blk, int eq, vector<vector<temporary_terms_t>> &blocks_temporary_terms,
map<expr_t, tuple<int, int, int>> &reference_count) const override;
expr_t toStatic(DataTree &static_datatree) const override;
void prepareForDerivation() override;
expr_t computeDerivative(int deriv_id) override;
int maxEndoLead() const override;
int maxExoLead() const override;
@ -1622,6 +1621,7 @@ public:
expr_t removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const override;
expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
protected:
void prepareForDerivation() override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
private:
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;