diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 4db74ef2..18ef15e9 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -185,11 +185,11 @@ DynamicModel::operator=(const DynamicModel &m) } void -DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag) const +DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const { if (auto it = derivatives[1].find({ eq, getDerivID(symbol_table.getID(SymbolType::endogenous, symb_id), lag) }); it != derivatives[1].end()) - it->second->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false); else { FLDZ_ fldz; @@ -198,11 +198,11 @@ DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_n } void -DynamicModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag) const +DynamicModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const { if (auto it = blocks_derivatives[blk].find({ eq, var, lag }); it != blocks_derivatives[blk].end()) - it->second->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false); else { FLDZ_ fldz; @@ -773,9 +773,10 @@ DynamicModel::writeModelEquationsCode(const string &basename) const other_endo); fbeginblock.write(code_file, instruction_number); - compileTemporaryTerms(code_file, instruction_number, true, false); + temporary_terms_t temporary_terms_union; + compileTemporaryTerms(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs); - compileModelEquations(code_file, instruction_number, true, false); + compileModelEquations(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs); FENDEQU_ fendequ; fendequ.write(code_file, instruction_number); @@ -802,7 +803,7 @@ DynamicModel::writeModelEquationsCode(const string &basename) const if (!my_derivatives[eq].size()) my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, lag, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); FSTPU_ fstpu(count_u); fstpu.write(code_file, instruction_number); @@ -864,7 +865,7 @@ DynamicModel::writeModelEquationsCode(const string &basename) const prev_lag = lag; count_col_endo++; } - d1->compile(code_file, instruction_number, false, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); FSTPG3_ fstpg3(eq, var, lag, count_col_endo-1); fstpg3.write(code_file, instruction_number); } @@ -883,7 +884,7 @@ DynamicModel::writeModelEquationsCode(const string &basename) const prev_lag = lag; count_col_exo++; } - d1->compile(code_file, instruction_number, false, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); FSTPG3_ fstpg3(eq, var, lag, count_col_exo-1); fstpg3.write(code_file, instruction_number); } @@ -993,8 +994,9 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ vector(blocks_other_endo[block].begin(), blocks_other_endo[block].end())); fbeginblock.write(code_file, instruction_number); + temporary_terms_t temporary_terms_union; if (linear_decomposition) - compileTemporaryTerms(code_file, instruction_number, true, false); + compileTemporaryTerms(code_file, instruction_number, true, false, temporary_terms_union, blocks_temporary_terms_idxs); //The Temporary terms deriv_node_temp_terms_t tef_terms; @@ -1002,13 +1004,14 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ for (auto it : blocks_temporary_terms[block]) { if (dynamic_cast(it)) - it->compileExternalFunctionOutput(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false, tef_terms); + it->compileExternalFunctionOutput(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FNUMEXPR_ fnumexpr(TemporaryTerm, static_cast(blocks_temporary_terms_idxs.at(it))); fnumexpr.write(code_file, instruction_number); - it->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false, tef_terms); + it->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FSTPT_ fstpt(static_cast(blocks_temporary_terms_idxs.at(it))); fstpt.write(code_file, instruction_number); + temporary_terms_union.insert(it); #ifdef DEBUGC cout << "FSTPT " << v << endl; instruction_number++; @@ -1039,16 +1042,16 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); - lhs->compile(code_file, instruction_number, true, blocks_temporary_terms_idxs, true, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false); } else if (equ_type == EquationType::evaluate_s) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); - lhs->compile(code_file, instruction_number, true, blocks_temporary_terms_idxs, true, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false); } break; case BlockSimulationType::solveBackwardComplete: @@ -1069,8 +1072,8 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); FBINARY_ fbinary{static_cast(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -1098,7 +1101,7 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ FNUMEXPR_ fnumexpr(FirstEndoDerivative, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0, temporary_terms_union, blocks_temporary_terms_idxs); { FSTPG_ fstpg(0); fstpg.write(code_file, instruction_number); @@ -1137,7 +1140,7 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ Uf[eqr].Ufl->lag = lag; FNUMEXPR_ fnumexpr(FirstEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, lag); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, lag, temporary_terms_union, blocks_temporary_terms_idxs); FSTPU_ fstpu(count_u); fstpu.write(code_file, instruction_number); count_u++; @@ -1206,7 +1209,7 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ int varr = getBlockVariableID(block, var); FNUMEXPR_ fnumexpr(FirstEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - compileDerivative(code_file, instruction_number, eqr, varr, lag); + compileDerivative(code_file, instruction_number, eqr, varr, lag, temporary_terms_union, blocks_temporary_terms_idxs); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_endo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1217,7 +1220,7 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(FirstExoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_exo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1228,7 +1231,7 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(FirstExodetDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_exo_det[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1239,7 +1242,7 @@ DynamicModel::writeModelEquationsCode_Block(const string &basename, bool linear_ int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(FirstOtherEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_other_endo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index 75e81543..44004f0e 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -166,9 +166,9 @@ private: map> &reference_count) const override; //! Write derivative code of an equation w.r. to a variable - void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag) const; + void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; //! Write chain rule derivative code of an equation w.r. to a variable - void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag) const; + void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; //! Get the type corresponding to a derivation ID SymbolType getTypeByDerivID(int deriv_id) const noexcept(false) override; diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 28da5c6e..7e70c532 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -187,10 +187,10 @@ ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const tem void ExprNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic) const { - compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, steady_dynamic, {}); + compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, {}); } void @@ -213,7 +213,7 @@ ExprNode::writeJsonExternalFunctionOutput(vector &efout, void ExprNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const { @@ -459,7 +459,7 @@ NumConstNode::eval(const eval_context_t &eval_context) const noexcept(false) void NumConstNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { @@ -1232,13 +1232,13 @@ VariableNode::eval(const eval_context_t &eval_context) const noexcept(false) void VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { auto type = get_type(); if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable) - datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); else { int tsid = datatree.symbol_table.getTypeSpecificID(symb_id); @@ -2866,12 +2866,12 @@ UnaryOpNode::writeJsonExternalFunctionOutput(vector &efout, void UnaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const { - arg->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, - dynamic, steady_dynamic, tef_terms); + arg->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, + temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); } double @@ -2952,30 +2952,30 @@ UnaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false) void UnaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { - if (auto it = temporary_terms_idxs.find(const_cast(this)); - it != temporary_terms_idxs.end()) + if (auto this2 = const_cast(this); + temporary_terms.find(this2) != temporary_terms.end()) { if (dynamic) { - FLDT_ fldt(it->second); + FLDT_ fldt(temporary_terms_idxs.at(this2)); fldt.write(CompileCode, instruction_number); } else { - FLDST_ fldst(it->second); + FLDST_ fldst(temporary_terms_idxs.at(this2)); fldst.write(CompileCode, instruction_number); } return; } if (op_code == UnaryOpcode::steadyState) - arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, true, tef_terms); + arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, true, tef_terms); else { - arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FUNARY_ funary{static_cast(op_code)}; funary.write(CompileCode, instruction_number); } @@ -4231,22 +4231,22 @@ BinaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false) void BinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { // If current node is a temporary term - if (auto it = temporary_terms_idxs.find(const_cast(this)); - it != temporary_terms_idxs.end()) + if (auto this2 = const_cast(this); + temporary_terms.find(this2) != temporary_terms.end()) { if (dynamic) { - FLDT_ fldt(it->second); + FLDT_ fldt(temporary_terms_idxs.at(this2)); fldt.write(CompileCode, instruction_number); } else { - FLDST_ fldst(it->second); + FLDST_ fldst(temporary_terms_idxs.at(this2)); fldst.write(CompileCode, instruction_number); } return; @@ -4256,8 +4256,8 @@ BinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number, FLDC_ fldc(powerDerivOrder); fldc.write(CompileCode, instruction_number); } - arg1->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); - arg2->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg1->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg2->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FBINARY_ fbinary{static_cast(op_code)}; fbinary.write(CompileCode, instruction_number); } @@ -4653,14 +4653,14 @@ BinaryOpNode::writeJsonExternalFunctionOutput(vector &efout, void BinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const { - arg1->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, - dynamic, steady_dynamic, tef_terms); - arg2->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, - dynamic, steady_dynamic, tef_terms); + arg1->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, + temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg2->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, + temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); } int @@ -5953,29 +5953,29 @@ TrinaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false) void TrinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { // If current node is a temporary term - if (auto it = temporary_terms_idxs.find(const_cast(this)); - it != temporary_terms_idxs.end()) + if (auto this2 = const_cast(this); + temporary_terms.find(this2) != temporary_terms.end()) { if (dynamic) { - FLDT_ fldt(it->second); + FLDT_ fldt(temporary_terms_idxs.at(this2)); fldt.write(CompileCode, instruction_number); } else { - FLDST_ fldst(it->second); + FLDST_ fldst(temporary_terms_idxs.at(this2)); fldst.write(CompileCode, instruction_number); } return; } - arg1->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); - arg2->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); - arg3->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg1->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg2->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg3->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FTRINARY_ ftrinary{static_cast(op_code)}; ftrinary.write(CompileCode, instruction_number); } @@ -6128,16 +6128,16 @@ TrinaryOpNode::writeJsonExternalFunctionOutput(vector &efout, void TrinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const { - arg1->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, - dynamic, steady_dynamic, tef_terms); - arg2->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, - dynamic, steady_dynamic, tef_terms); - arg3->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, - dynamic, steady_dynamic, tef_terms); + arg1->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, + temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg2->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, + temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + arg3->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, + temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); } void @@ -6636,13 +6636,13 @@ AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const mapcompile(CompileCode, instruction_number, lhs_rhs, temporary_terms_idxs, - dynamic, steady_dynamic, tef_terms); + argument->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, + temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); return (arguments.size()); } @@ -7226,21 +7226,21 @@ ExternalFunctionNode::composeDerivatives(const vector &dargs) void ExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { - if (auto it = temporary_terms_idxs.find(const_cast(this)); - it != temporary_terms_idxs.end()) + if (auto this2 = const_cast(this); + temporary_terms.find(this2) != temporary_terms.end()) { if (dynamic) { - FLDT_ fldt(it->second); + FLDT_ fldt(temporary_terms_idxs.at(this2)); fldt.write(CompileCode, instruction_number); } else { - FLDST_ fldst(it->second); + FLDST_ fldst(temporary_terms_idxs.at(this2)); fldst.write(CompileCode, instruction_number); } return; @@ -7260,7 +7260,7 @@ ExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_nu void ExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const { @@ -7268,7 +7268,7 @@ ExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsign assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided); for (auto argument : arguments) - argument->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, + argument->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); if (!alreadyWrittenAsTefTerm(symb_id, tef_terms)) @@ -7286,7 +7286,7 @@ ExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsign nb_output_arguments = 2; else nb_output_arguments = 1; - unsigned int nb_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, + unsigned int nb_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FCALL_ fcall(nb_output_arguments, nb_input_arguments, datatree.symbol_table.getName(symb_id), indx); @@ -7616,21 +7616,21 @@ FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType void FirstDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { - if (auto it = temporary_terms_idxs.find(const_cast(this)); - it != temporary_terms_idxs.end()) + if (auto this2 = const_cast(this); + temporary_terms.find(this2) != temporary_terms.end()) { if (dynamic) { - FLDT_ fldt(it->second); + FLDT_ fldt(temporary_terms_idxs.at(this2)); fldt.write(CompileCode, instruction_number); } else { - FLDST_ fldst(it->second); + FLDST_ fldst(temporary_terms_idxs.at(this2)); fldst.write(CompileCode, instruction_number); } return; @@ -7797,7 +7797,7 @@ FirstDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector & void FirstDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const { @@ -7807,7 +7807,7 @@ FirstDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCo if (first_deriv_symb_id == symb_id || alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms)) return; - unsigned int nb_add_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, + unsigned int nb_add_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet) { @@ -8183,7 +8183,7 @@ SecondDerivExternalFunctionNode::computeXrefs(EquationInfo &ei) const void SecondDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { @@ -8193,7 +8193,7 @@ SecondDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &ins void SecondDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const { @@ -8429,7 +8429,7 @@ VarExpectationNode::collectDynamicVariables(SymbolType type_arg, set &result) const override; void collectDynamicVariables(SymbolType type_arg, set> &result) const override; double eval(const eval_context_t &eval_context) const noexcept(false) override; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; + void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; @@ -827,7 +827,7 @@ public: void collectVARLHSVariable(set &result) const override; void collectDynamicVariables(SymbolType type_arg, set> &result) const override; double eval(const eval_context_t &eval_context) const noexcept(false) override; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; + void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; SymbolType get_type() const; @@ -925,14 +925,14 @@ public: deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; void compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const override; void collectVARLHSVariable(set &result) const override; void collectDynamicVariables(SymbolType type_arg, set> &result) const override; static double eval_opcode(UnaryOpcode op_code, double v) noexcept(false); double eval(const eval_context_t &eval_context) const noexcept(false) override; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; + void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; @@ -1032,14 +1032,14 @@ public: deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; void compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const override; void collectVARLHSVariable(set &result) const override; void collectDynamicVariables(SymbolType type_arg, set> &result) const override; static double eval_opcode(double v1, BinaryOpcode op_code, double v2, int derivOrder) noexcept(false); double eval(const eval_context_t &eval_context) const noexcept(false) override; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; + void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; @@ -1163,14 +1163,14 @@ public: deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; void compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const override; void collectVARLHSVariable(set &result) const override; void collectDynamicVariables(SymbolType type_arg, set> &result) const override; static double eval_opcode(double v1, TrinaryOpcode op_code, double v2, double v3) noexcept(false); double eval(const eval_context_t &eval_context) const noexcept(false) override; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; + void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; @@ -1275,18 +1275,18 @@ public: deriv_node_temp_terms_t &tef_terms, bool isdynamic = true) const override = 0; void compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const override = 0; void collectVARLHSVariable(set &result) const override; void collectDynamicVariables(SymbolType type_arg, set> &result) const override; double eval(const eval_context_t &eval_context) const noexcept(false) override; unsigned int compileExternalFunctionArguments(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override = 0; + void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override = 0; expr_t toStatic(DataTree &static_datatree) const override = 0; void computeXrefs(EquationInfo &ei) const override = 0; void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; @@ -1365,10 +1365,10 @@ public: deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; void compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const override; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; + void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; expr_t buildSimilarExternalFunctionNode(vector &alt_args, DataTree &alt_datatree) const override; @@ -1392,7 +1392,7 @@ public: void writeJsonAST(ostream &output) const override; void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; void compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, @@ -1404,7 +1404,7 @@ public: deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; void compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; @@ -1432,7 +1432,7 @@ public: void writeJsonAST(ostream &output) const override; void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; void compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, @@ -1444,7 +1444,7 @@ public: deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; void compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; @@ -1502,7 +1502,7 @@ public: void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; void compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; void collectVARLHSVariable(set &result) const override; @@ -1579,7 +1579,7 @@ public: void computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set &contain_var, expr_t rhs) const override; void compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, + bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override; void collectVARLHSVariable(set &result) const override; diff --git a/src/ModelTree.cc b/src/ModelTree.cc index 2bd95e79..0aba63f5 100644 --- a/src/ModelTree.cc +++ b/src/ModelTree.cc @@ -1314,18 +1314,18 @@ ModelTree::testNestedParenthesis(const string &str) const } void -ModelTree::compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic) const +ModelTree::compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const { // To store the functions that have already been written in the form TEF* = ext_fun(); deriv_node_temp_terms_t tef_terms; for (auto [tt, idx] : temporary_terms_idxs) { if (dynamic_cast(tt)) - tt->compileExternalFunctionOutput(code_file, instruction_number, false, blocks_temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + tt->compileExternalFunctionOutput(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FNUMEXPR_ fnumexpr(TemporaryTerm, idx); fnumexpr.write(code_file, instruction_number); - tt->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + tt->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); if (dynamic) { FSTPT_ fstpt(idx); @@ -1455,7 +1455,7 @@ ModelTree::writeModelEquations(ostream &output, ExprNodeOutputType output_type, } void -ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic) const +ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const { for (int eq = 0; eq < static_cast(equations.size()); eq++) { @@ -1476,8 +1476,8 @@ ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_n if (vrhs != 0) // The right hand side of the equation is not empty ==> residual=lhs-rhs; { - lhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, dynamic, steady_dynamic); - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, dynamic, steady_dynamic); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); FBINARY_ fbinary{static_cast(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -1487,7 +1487,7 @@ ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_n } else // The right hand side of the equation is empty ==> residual=lhs; { - lhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, dynamic, steady_dynamic); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); FSTPR_ fstpr(eq); fstpr.write(code_file, instruction_number); } diff --git a/src/ModelTree.hh b/src/ModelTree.hh index 47a3a122..0cd0db8b 100644 --- a/src/ModelTree.hh +++ b/src/ModelTree.hh @@ -231,7 +231,7 @@ protected: void writeTemporaryTerms(const temporary_terms_t &tt, temporary_terms_t &temp_term_union, const temporary_terms_idxs_t &tt_idxs, ostream &output, ExprNodeOutputType output_type, deriv_node_temp_terms_t &tef_terms) const; void writeJsonTemporaryTerms(const temporary_terms_t &tt, temporary_terms_t &temp_term_union, ostream &output, deriv_node_temp_terms_t &tef_terms, const string &concat) const; //! Compiles temporary terms - void compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic) const; + void compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const; //! Adds informations for simulation in a binary file void Write_Inf_To_Bin_File(const string &filename, int &u_count_int, bool &file_open, bool is_two_boundaries, int block_mfs) const; //! Fixes output when there are more than 32 nested parens, Issue #1201 @@ -252,7 +252,7 @@ protected: void writeJsonModelEquations(ostream &output, bool residuals) const; void writeJsonModelLocalVariables(ostream &output, deriv_node_temp_terms_t &tef_terms) const; //! Compiles model equations - void compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic) const; + void compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const; //! Writes LaTeX model file void writeLatexModelFile(const string &mod_basename, const string &latex_basename, ExprNodeOutputType output_type, bool write_equation_tags) const; diff --git a/src/StaticModel.cc b/src/StaticModel.cc index db96d48c..1f5ef4fa 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -102,11 +102,11 @@ StaticModel::StaticModel(const DynamicModel &m) : } void -StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id) const +StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const { if (auto it = derivatives[1].find({ eq, getDerivID(symbol_table.getID(SymbolType::endogenous, symb_id), 0) }); it != derivatives[1].end()) - it->second->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false); else { FLDZ_ fldz; @@ -115,11 +115,11 @@ StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_nu } void -StaticModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag) const +StaticModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const { if (auto it = blocks_derivatives[blk].find({ eq, var, lag }); it != blocks_derivatives[blk].end()) - it->second->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false); else { FLDZ_ fldz; @@ -327,9 +327,10 @@ StaticModel::writeModelEquationsCode(const string &basename) const symbol_table.endo_nbr()); fbeginblock.write(code_file, instruction_number); - compileTemporaryTerms(code_file, instruction_number, false, false); + temporary_terms_t temporary_terms_union; + compileTemporaryTerms(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs); - compileModelEquations(code_file, instruction_number, false, false); + compileModelEquations(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs); FENDEQU_ fendequ; fendequ.write(code_file, instruction_number); @@ -356,7 +357,7 @@ StaticModel::writeModelEquationsCode(const string &basename) const my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_idxs, false, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false); FSTPSU_ fstpsu(count_u); fstpsu.write(code_file, instruction_number); @@ -418,7 +419,7 @@ StaticModel::writeModelEquationsCode(const string &basename) const my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_idxs, false, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false); FSTPG2_ fstpg2(eq, var); fstpg2.write(code_file, instruction_number); } @@ -478,6 +479,8 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const FDIMST_ fdimst(blocks_temporary_terms_idxs.size()); fdimst.write(code_file, instruction_number); + temporary_terms_t temporary_terms_union; + for (int block = 0; block < static_cast(blocks.size()); block++) { feedback_variables.clear(); @@ -525,16 +528,20 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const //The Temporary terms deriv_node_temp_terms_t tef_terms; + /* We are force to use a copy of tt_union here, since temp. terms are + written a second time below. This is probably unwanted… */ + temporary_terms_t tt2 = temporary_terms_union; for (auto it : blocks_temporary_terms[block]) { if (dynamic_cast(it)) - it->compileExternalFunctionOutput(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false, tef_terms); + it->compileExternalFunctionOutput(code_file, instruction_number, false, tt2, blocks_temporary_terms_idxs, false, false, tef_terms); FNUMEXPR_ fnumexpr(TemporaryTerm, static_cast(blocks_temporary_terms_idxs.at(it))); fnumexpr.write(code_file, instruction_number); - it->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false, tef_terms); + it->compile(code_file, instruction_number, false, tt2, blocks_temporary_terms_idxs, false, false, tef_terms); FSTPST_ fstpst(static_cast(blocks_temporary_terms_idxs.at(it))); fstpst.write(code_file, instruction_number); + tt2.insert(it); } for (i = 0; i < block_size; i++) @@ -557,16 +564,16 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, tt2, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, true, tt2, blocks_temporary_terms_idxs, false, false); } else if (equ_type == EquationType::evaluate_s) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, tt2, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, true, tt2, blocks_temporary_terms_idxs, false, false); } break; case BlockSimulationType::solveBackwardComplete: @@ -585,8 +592,8 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, false, tt2, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, tt2, blocks_temporary_terms_idxs, false, false); FBINARY_ fbinary{static_cast(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -610,7 +617,7 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const FNUMEXPR_ fnumexpr(FirstEndoDerivative, 0, 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0)); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), tt2, blocks_temporary_terms_idxs); { FSTPG_ fstpg(0); fstpg.write(code_file, instruction_number); @@ -642,7 +649,7 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const Uf[eqr].Ufl->var = varr; FNUMEXPR_ fnumexpr(FirstEndoDerivative, eqr, varr); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, tt2, blocks_temporary_terms_idxs); FSTPSU_ fstpsu(count_u); fstpsu.write(code_file, instruction_number); count_u++; @@ -709,13 +716,14 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const for (auto it : blocks_temporary_terms[block]) { if (dynamic_cast(it)) - it->compileExternalFunctionOutput(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false, tef_terms); + it->compileExternalFunctionOutput(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); FNUMEXPR_ fnumexpr(TemporaryTerm, static_cast(blocks_temporary_terms_idxs.at(it))); fnumexpr.write(code_file, instruction_number); - it->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false, tef_terms); + it->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); FSTPST_ fstpst(static_cast(blocks_temporary_terms_idxs.at(it))); fstpst.write(code_file, instruction_number); + temporary_terms_union.insert(it); } for (i = 0; i < block_size; i++) @@ -738,16 +746,16 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); } else if (equ_type == EquationType::evaluate_s) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); } break; case BlockSimulationType::solveBackwardComplete: @@ -766,8 +774,8 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); - rhs->compile(code_file, instruction_number, false, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); FBINARY_ fbinary{static_cast(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -788,7 +796,7 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const FNUMEXPR_ fnumexpr(FirstEndoDerivative, 0, 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0)); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs); { FSTPG2_ fstpg2(0, 0); fstpg2.write(code_file, instruction_number); @@ -807,7 +815,7 @@ StaticModel::writeModelEquationsCode_Block(const string &basename) const FNUMEXPR_ fnumexpr(FirstEndoDerivative, eqr, varr, 0); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs); FSTPG2_ fstpg2(eq, var); fstpg2.write(code_file, instruction_number); diff --git a/src/StaticModel.hh b/src/StaticModel.hh index 1c191feb..c05fef59 100644 --- a/src/StaticModel.hh +++ b/src/StaticModel.hh @@ -65,9 +65,9 @@ private: void evaluateJacobian(const eval_context_t &eval_context, jacob_map_t *j_m, bool dynamic); //! Write derivative code of an equation w.r. to a variable - void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id) const; + void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; //! Write chain rule derivative code of an equation w.r. to a variable - void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag) const; + void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; //! Get the type corresponding to a derivation ID SymbolType getTypeByDerivID(int deriv_id) const noexcept(false) override;