From 171cd6556661f0379740691eb5482b13ca64a178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= Date: Fri, 20 May 2022 11:43:02 +0200 Subject: [PATCH] Fix bytecode compilation of external function nodes --- src/DynamicModel.cc | 44 ++++++++++++++++++++------------------- src/DynamicModel.hh | 4 ++-- src/ExprNode.cc | 8 -------- src/ExprNode.hh | 2 +- src/ModelTree.cc | 11 +++++----- src/ModelTree.hh | 4 ++-- src/StaticModel.cc | 50 +++++++++++++++++++++++---------------------- src/StaticModel.hh | 4 ++-- 8 files changed, 61 insertions(+), 66 deletions(-) diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 84244248..4bcd0c50 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -168,11 +168,11 @@ DynamicModel::operator=(const DynamicModel &m) } void -DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const +DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { if (auto it = derivatives[1].find({ eq, getDerivID(symbol_table.getID(SymbolType::endogenous, symb_id), lag) }); it != derivatives[1].end()) - it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false, tef_terms); else { FLDZ_ fldz; @@ -181,11 +181,11 @@ DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_n } void -DynamicModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const +DynamicModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { if (auto it = blocks_derivatives[blk].find({ eq, var, lag }); it != blocks_derivatives[blk].end()) - it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false, tef_terms); else { FLDZ_ fldz; @@ -886,9 +886,11 @@ DynamicModel::writeDynamicBytecode(const string &basename) const fbeginblock.write(code_file, instruction_number); temporary_terms_t temporary_terms_union; - compileTemporaryTerms(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs); + deriv_node_temp_terms_t tef_terms; - compileModelEquations(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs); + compileTemporaryTerms(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs, tef_terms); + + compileModelEquations(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs, tef_terms); FENDEQU_ fendequ; fendequ.write(code_file, instruction_number); @@ -915,7 +917,7 @@ DynamicModel::writeDynamicBytecode(const string &basename) const if (!my_derivatives[eq].size()) my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, lag, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false, tef_terms); FSTPU_ fstpu(count_u); fstpu.write(code_file, instruction_number); @@ -977,7 +979,7 @@ DynamicModel::writeDynamicBytecode(const string &basename) const prev_lag = lag; count_col_endo++; } - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, count_col_endo-1); fstpg3.write(code_file, instruction_number); } @@ -996,7 +998,7 @@ DynamicModel::writeDynamicBytecode(const string &basename) const prev_lag = lag; count_col_exo++; } - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, count_col_exo-1); fstpg3.write(code_file, instruction_number); } @@ -1152,16 +1154,16 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); } else if (equ_type == EquationType::evaluateRenormalized) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); } break; case BlockSimulationType::solveBackwardComplete: @@ -1182,8 +1184,8 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FBINARY_ fbinary{static_cast(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -1221,7 +1223,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0, temporary_terms_union, blocks_temporary_terms_idxs); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); { FSTPG_ fstpg(0); fstpg.write(code_file, instruction_number); @@ -1260,7 +1262,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const Uf[eqr].Ufl->lag = lag; FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, lag, temporary_terms_union, blocks_temporary_terms_idxs); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, lag, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); FSTPU_ fstpu(count_u); fstpu.write(code_file, instruction_number); count_u++; @@ -1337,7 +1339,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const int varr = getBlockVariableID(block, var); FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - compileDerivative(code_file, instruction_number, eqr, varr, lag, temporary_terms_union, blocks_temporary_terms_idxs); + compileDerivative(code_file, instruction_number, eqr, varr, lag, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_endo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1348,7 +1350,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(ExpressionType::FirstExoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_exo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1359,7 +1361,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(ExpressionType::FirstExodetDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_exo_det[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1370,7 +1372,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(ExpressionType::FirstOtherEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_other_endo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index d808a410..e9a9e2f7 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -185,9 +185,9 @@ private: map> &reference_count) const override; //! Write derivative code of an equation w.r. to a variable - void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Write chain rule derivative code of an equation w.r. to a variable - void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Get the type corresponding to a derivation ID SymbolType getTypeByDerivID(int deriv_id) const noexcept(false) override; diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 6ccec3d6..cd86dfd6 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -169,14 +169,6 @@ ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const tem writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, {}); } -void -ExprNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, const temporary_terms_t &temporary_terms, - const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic) const -{ - compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, {}); -} - void ExprNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 0ad7f51a..27fd40ec 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -410,7 +410,7 @@ public: virtual double eval(const eval_context_t &eval_context) const noexcept(false) = 0; virtual void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const = 0; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic) const; + //! Creates a static version of this node /*! This method duplicates the current node by creating a similar node from which all leads/lags have been stripped, diff --git a/src/ModelTree.cc b/src/ModelTree.cc index d82a32b7..66ba7f9c 100644 --- a/src/ModelTree.cc +++ b/src/ModelTree.cc @@ -1260,10 +1260,9 @@ ModelTree::testNestedParenthesis(const string &str) const } void -ModelTree::compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const +ModelTree::compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs, deriv_node_temp_terms_t &tef_terms) const { // To store the functions that have already been written in the form TEF* = ext_fun(); - deriv_node_temp_terms_t tef_terms; for (auto [tt, idx] : temporary_terms_idxs) { if (dynamic_cast(tt)) @@ -1398,7 +1397,7 @@ ModelTree::writeModelEquations(ostream &output, ExprNodeOutputType output_type, } void -ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const +ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { for (int eq = 0; eq < static_cast(equations.size()); eq++) { @@ -1419,8 +1418,8 @@ ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_n if (vrhs != 0) // The right hand side of the equation is not empty ==> residual=lhs-rhs; { - lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); - rhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FBINARY_ fbinary{static_cast(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -1430,7 +1429,7 @@ ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_n } else // The right hand side of the equation is empty ==> residual=lhs; { - lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FSTPR_ fstpr(eq); fstpr.write(code_file, instruction_number); } diff --git a/src/ModelTree.hh b/src/ModelTree.hh index 291a839d..5b9bd966 100644 --- a/src/ModelTree.hh +++ b/src/ModelTree.hh @@ -237,7 +237,7 @@ protected: void writeTemporaryTerms(const temporary_terms_t &tt, temporary_terms_t &temp_term_union, const temporary_terms_idxs_t &tt_idxs, ostream &output, ExprNodeOutputType output_type, deriv_node_temp_terms_t &tef_terms) const; void writeJsonTemporaryTerms(const temporary_terms_t &tt, temporary_terms_t &temp_term_union, ostream &output, deriv_node_temp_terms_t &tef_terms, const string &concat) const; //! Compiles temporary terms - void compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs, deriv_node_temp_terms_t &tef_terms) const; //! Adds information for (non-block) bytecode simulation in a separate .bin file void writeBytecodeBinFile(const string &filename, int &u_count_int, bool &file_open, bool is_two_boundaries) const; //! Fixes output when there are more than 32 nested parens, Issue #1201 @@ -260,7 +260,7 @@ protected: Optionally put the external function variable calls into TEF terms */ void writeJsonModelLocalVariables(ostream &output, bool write_tef_terms, deriv_node_temp_terms_t &tef_terms) const; //! Compiles model equations - void compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Writes LaTeX model file void writeLatexModelFile(const string &mod_basename, const string &latex_basename, ExprNodeOutputType output_type, bool write_equation_tags) const; diff --git a/src/StaticModel.cc b/src/StaticModel.cc index 3b6320de..e5e9ab7d 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -103,11 +103,11 @@ StaticModel::StaticModel(const DynamicModel &m) : } void -StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const +StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { if (auto it = derivatives[1].find({ eq, getDerivID(symbol_table.getID(SymbolType::endogenous, symb_id), 0) }); it != derivatives[1].end()) - it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false, tef_terms); else { FLDZ_ fldz; @@ -116,11 +116,11 @@ StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_nu } void -StaticModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const +StaticModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { if (auto it = blocks_derivatives[blk].find({ eq, var, lag }); it != blocks_derivatives[blk].end()) - it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false, tef_terms); else { FLDZ_ fldz; @@ -420,9 +420,11 @@ StaticModel::writeStaticBytecode(const string &basename) const fbeginblock.write(code_file, instruction_number); temporary_terms_t temporary_terms_union; - compileTemporaryTerms(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs); + deriv_node_temp_terms_t tef_terms; - compileModelEquations(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs); + compileTemporaryTerms(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs, tef_terms); + + compileModelEquations(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs, tef_terms); FENDEQU_ fendequ; fendequ.write(code_file, instruction_number); @@ -449,7 +451,7 @@ StaticModel::writeStaticBytecode(const string &basename) const my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false, tef_terms); FSTPSU_ fstpsu(count_u); fstpsu.write(code_file, instruction_number); @@ -511,7 +513,7 @@ StaticModel::writeStaticBytecode(const string &basename) const my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false, tef_terms); FSTPG2_ fstpg2(eq, var); fstpg2.write(code_file, instruction_number); } @@ -660,16 +662,16 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); } else if (equ_type == EquationType::evaluateRenormalized) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); } break; case BlockSimulationType::solveBackwardComplete: @@ -688,8 +690,8 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); FBINARY_ fbinary{static_cast(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -716,7 +718,7 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, 0, 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); { FSTPG_ fstpg(0); fstpg.write(code_file, instruction_number); @@ -748,7 +750,7 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const Uf[eqr].Ufl->var = varr; FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, eqr, varr); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); FSTPSU_ fstpsu(count_u); fstpsu.write(code_file, instruction_number); count_u++; @@ -836,16 +838,16 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); } else if (equ_type == EquationType::evaluateRenormalized) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); } break; case BlockSimulationType::solveBackwardComplete: @@ -864,8 +866,8 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); FBINARY_ fbinary{static_cast(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -890,7 +892,7 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, 0, 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); { FSTPG2_ fstpg2(0, 0); fstpg2.write(code_file, instruction_number); @@ -909,7 +911,7 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, eqr, varr, 0); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); FSTPG2_ fstpg2(eq, var); fstpg2.write(code_file, instruction_number); diff --git a/src/StaticModel.hh b/src/StaticModel.hh index c105cec4..b582516a 100644 --- a/src/StaticModel.hh +++ b/src/StaticModel.hh @@ -78,9 +78,9 @@ private: void evaluateJacobian(const eval_context_t &eval_context, jacob_map_t *j_m, bool dynamic); //! Write derivative code of an equation w.r. to a variable - void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Write chain rule derivative code of an equation w.r. to a variable - void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Get the type corresponding to a derivation ID SymbolType getTypeByDerivID(int deriv_id) const noexcept(false) override;