From 02ae0af3e51c9d14cdc9e0eb9d7ee92c8bc71391 Mon Sep 17 00:00:00 2001 From: Houtan Bastani Date: Tue, 29 Jan 2019 17:29:24 +0100 Subject: [PATCH] change map type for readability --- src/ExprNode.cc | 40 ++++++++++++++++++++-------------------- src/ExprNode.hh | 37 +++++++++++++++++++------------------ src/ModelTree.cc | 4 ++-- src/ModelTree.hh | 2 +- 4 files changed, 42 insertions(+), 41 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 556f0efd..870be77f 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -714,13 +714,13 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, c } void -NumConstNode::findConstantEquations(map &table) const +NumConstNode::findConstantEquations(map &table) const { return; } expr_t -NumConstNode::replaceVarsInEquation(map &table) const +NumConstNode::replaceVarsInEquation(map &table) const { return const_cast(this); } @@ -2024,17 +2024,17 @@ VariableNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, c } void -VariableNode::findConstantEquations(map &table) const +VariableNode::findConstantEquations(map &table) const { return; } expr_t -VariableNode::replaceVarsInEquation(map &table) const +VariableNode::replaceVarsInEquation(map &table) const { for (auto & it : table) - if (dynamic_cast(it.first)->symb_id == symb_id) - return dynamic_cast(it.second); + if (it.first->symb_id == symb_id) + return it.second; return const_cast(this); } @@ -3857,13 +3857,13 @@ UnaryOpNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, co } void -UnaryOpNode::findConstantEquations(map &table) const +UnaryOpNode::findConstantEquations(map &table) const { arg->findConstantEquations(table); } expr_t -UnaryOpNode::replaceVarsInEquation(map &table) const +UnaryOpNode::replaceVarsInEquation(map &table) const { expr_t argsubst = arg->replaceVarsInEquation(table); return buildSimilarUnaryOpNode(argsubst, datatree); @@ -5830,13 +5830,13 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2, } void -BinaryOpNode::findConstantEquations(map &table) const +BinaryOpNode::findConstantEquations(map &table) const { if (op_code == BinaryOpcode::equal) if (dynamic_cast(arg1) != nullptr && dynamic_cast(arg2) != nullptr) - table[arg1] = arg2; + table[dynamic_cast(arg1)] = dynamic_cast(arg2); else if (dynamic_cast(arg2) != nullptr && dynamic_cast(arg1) != nullptr) - table[arg2] = arg1; + table[dynamic_cast(arg2)] = dynamic_cast(arg1); else { arg1->findConstantEquations(table); @@ -5845,7 +5845,7 @@ BinaryOpNode::findConstantEquations(map &table) const } expr_t -BinaryOpNode::replaceVarsInEquation(map &table) const +BinaryOpNode::replaceVarsInEquation(map &table) const { if (op_code == BinaryOpcode::equal) for (auto & it : table) @@ -6858,7 +6858,7 @@ TrinaryOpNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, } void -TrinaryOpNode::findConstantEquations(map &table) const +TrinaryOpNode::findConstantEquations(map &table) const { arg1->findConstantEquations(table); arg2->findConstantEquations(table); @@ -6866,7 +6866,7 @@ TrinaryOpNode::findConstantEquations(map &table) const } expr_t -TrinaryOpNode::replaceVarsInEquation(map &table) const +TrinaryOpNode::replaceVarsInEquation(map &table) const { expr_t arg1subst = arg1->replaceVarsInEquation(table); expr_t arg2subst = arg2->replaceVarsInEquation(table); @@ -7511,14 +7511,14 @@ AbstractExternalFunctionNode::fillErrorCorrectionRow(int eqn, const vector } void -AbstractExternalFunctionNode::findConstantEquations(map &table) const +AbstractExternalFunctionNode::findConstantEquations(map &table) const { for (auto argument : arguments) argument->findConstantEquations(table); } expr_t -AbstractExternalFunctionNode::replaceVarsInEquation(map &table) const +AbstractExternalFunctionNode::replaceVarsInEquation(map &table) const { vector arguments_subst; for (auto argument : arguments) @@ -9040,13 +9040,13 @@ VarExpectationNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_ } void -VarExpectationNode::findConstantEquations(map &table) const +VarExpectationNode::findConstantEquations(map &table) const { return; } expr_t -VarExpectationNode::replaceVarsInEquation(map &table) const +VarExpectationNode::replaceVarsInEquation(map &table) const { return const_cast(this); } @@ -9560,13 +9560,13 @@ PacExpectationNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_ } void -PacExpectationNode::findConstantEquations(map &table) const +PacExpectationNode::findConstantEquations(map &table) const { return; } expr_t -PacExpectationNode::replaceVarsInEquation(map &table) const +PacExpectationNode::replaceVarsInEquation(map &table) const { return const_cast(this); } diff --git a/src/ExprNode.hh b/src/ExprNode.hh index df57b822..ff84420c 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -34,6 +34,7 @@ using namespace std; #include "SymbolList.hh" class DataTree; +class NumConstNode; class VariableNode; class UnaryOpNode; class BinaryOpNode; @@ -613,10 +614,10 @@ class ExprNode map, expr_t> &EC) const = 0; //! Finds equations where a variable is equal to a constant - virtual void findConstantEquations(map &table) const = 0; + virtual void findConstantEquations(map &table) const = 0; //! Replaces variables found in findConstantEquations() with their constant values - virtual expr_t replaceVarsInEquation(map &table) const = 0; + virtual expr_t replaceVarsInEquation(map &table) const = 0; //! Returns true if PacExpectationNode encountered virtual bool containsPacExpectation(const string &pac_model_name = "") const = 0; @@ -732,8 +733,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &EC) const override; - void findConstantEquations(map &table) const override; - expr_t replaceVarsInEquation(map &table) const override; + void findConstantEquations(map &table) const override; + expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair, vector>> &ec_params_and_vars, set>> ¶ms_and_vars) const override; @@ -823,8 +824,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &EC) const override; - void findConstantEquations(map &table) const override; - expr_t replaceVarsInEquation(map &table) const override; + void findConstantEquations(map &table) const override; + expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair, vector>> &ec_params_and_vars, set>> ¶ms_and_vars) const override; @@ -942,8 +943,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &EC) const override; - void findConstantEquations(map &table) const override; - expr_t replaceVarsInEquation(map &table) const override; + void findConstantEquations(map &table) const override; + expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair, vector>> &ec_params_and_vars, set>> ¶ms_and_vars) const override; @@ -1079,8 +1080,8 @@ public: int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &AR) const; void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &EC) const override; - void findConstantEquations(map &table) const override; - expr_t replaceVarsInEquation(map &table) const override; + void findConstantEquations(map &table) const override; + expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair, vector>> &ec_params_and_vars, set>> ¶ms_and_vars) const override; @@ -1195,8 +1196,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &EC) const override; - void findConstantEquations(map &table) const override; - expr_t replaceVarsInEquation(map &table) const override; + void findConstantEquations(map &table) const override; + expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair, vector>> &ec_params_and_vars, set>> ¶ms_and_vars) const override; @@ -1323,8 +1324,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &EC) const override; - void findConstantEquations(map &table) const override; - expr_t replaceVarsInEquation(map &table) const override; + void findConstantEquations(map &table) const override; + expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair, vector>> &ec_params_and_vars, set>> ¶ms_and_vars) const override; @@ -1539,8 +1540,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &EC) const override; - void findConstantEquations(map &table) const override; - expr_t replaceVarsInEquation(map &table) const override; + void findConstantEquations(map &table) const override; + expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair, vector>> &ec_params_and_vars, set>> ¶ms_and_vars) const override; @@ -1641,8 +1642,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &EC) const override; - void findConstantEquations(map &table) const override; - expr_t replaceVarsInEquation(map &table) const override; + void findConstantEquations(map &table) const override; + expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair, vector>> &ec_params_and_vars, set>> ¶ms_and_vars) const override; diff --git a/src/ModelTree.cc b/src/ModelTree.cc index 29486fa3..136b39bb 100644 --- a/src/ModelTree.cc +++ b/src/ModelTree.cc @@ -1925,7 +1925,7 @@ void ModelTree::simplifyEquations() { size_t last_subst_table_size = 0; - map subst_table; + map subst_table; findConstantEquations(subst_table); while (subst_table.size() != last_subst_table_size) { @@ -1938,7 +1938,7 @@ ModelTree::simplifyEquations() } void -ModelTree::findConstantEquations(map &subst_table) const +ModelTree::findConstantEquations(map &subst_table) const { for (auto & equation : equations) equation->findConstantEquations(subst_table); diff --git a/src/ModelTree.hh b/src/ModelTree.hh index 95af4faa..f607fb3e 100644 --- a/src/ModelTree.hh +++ b/src/ModelTree.hh @@ -352,7 +352,7 @@ public: //! Simplify model equations: if a variable is equal to a constant, replace that variable elsewhere in the model void simplifyEquations(); //! Find equations where variable is equal to a constant - void findConstantEquations(map &subst_table) const; + void findConstantEquations(map &subst_table) const; void jacobianHelper(ostream &output, int eq_nb, int col_nb, ExprNodeOutputType output_type) const; //! Helper for writing the sparse Hessian or third derivatives in MATLAB and C