From d789e6a4c5f0f2a627a3f3b5a840c38ed3f13478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= Date: Tue, 18 Apr 2023 15:49:56 +0200 Subject: [PATCH] Bytecode: move storage for variable and parameter values out of the Evaluate class --- mex/sources/bytecode/Evaluate.cc | 81 +++++++--------------------- mex/sources/bytecode/Evaluate.hh | 21 ++------ mex/sources/bytecode/Interpreter.cc | 35 ++++++++++++ mex/sources/bytecode/Interpreter.hh | 1 + mex/sources/bytecode/SparseMatrix.cc | 22 +++++++- mex/sources/bytecode/SparseMatrix.hh | 16 ++++++ 6 files changed, 95 insertions(+), 81 deletions(-) diff --git a/mex/sources/bytecode/Evaluate.cc b/mex/sources/bytecode/Evaluate.cc index 54ad3f613..c81dc4fda 100644 --- a/mex/sources/bytecode/Evaluate.cc +++ b/mex/sources/bytecode/Evaluate.cc @@ -29,12 +29,8 @@ #include "CommonEnums.hh" #include "ErrorHandling.hh" -Evaluate::Evaluate(int y_size_arg, int y_kmin_arg, int y_kmax_arg, bool steady_state_arg, int periods_arg, BasicSymbolTable &symbol_table_arg) : +Evaluate::Evaluate(bool steady_state_arg, BasicSymbolTable &symbol_table_arg) : symbol_table {symbol_table_arg}, - y_size {y_size_arg}, - y_kmin {y_kmin_arg}, - y_kmax {y_kmax_arg}, - periods {periods_arg}, steady_state {steady_state_arg} { } @@ -989,7 +985,7 @@ Evaluate::print_expression(const Evaluate::it_code_type &expr_begin, const optio } void -Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative) +Evaluate::evaluateBlock(int it_, double *y, const double *ya, int y_size, double *x, int nb_row_x, double *params, const double *steady_y, double *u, int Per_u_, double *T, int T_nrows, map &TEF, map, double> &TEFD, map, double> &TEFDD, double *r, double *g1, double *jacob, double *jacob_exo, double *jacob_exo_det, bool evaluate, bool no_derivatives) { auto it_code { currentBlockBeginning() }; int var{0}, lag{0}; @@ -1004,24 +1000,11 @@ Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative) bool go_on = true; double ll; double rr; - double *jacob = nullptr, *jacob_exo = nullptr, *jacob_exo_det = nullptr; EQN_block = block_num; stack Stack; ExternalFunctionCallType call_type{ExternalFunctionCallType::levelWithoutDerivative}; it_code_type it_code_expr; -#ifdef DEBUG - mexPrintf("compute_block_time\n"); -#endif - if (evaluate) - { - jacob = mxGetPr(jacobian_block[block_num]); - if (!steady_state) - { - jacob_exo = mxGetPr(jacobian_exo_block[block_num]); - jacob_exo_det = mxGetPr(jacobian_det_exo_block[block_num]); - } - } #ifdef MATLAB_MEX_FILE if (utIsInterruptPending()) throw UserException{}; @@ -1221,10 +1204,10 @@ Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative) //load a temporary variable in the processor var = static_cast(*it_code)->get_pos(); #ifdef DEBUG - mexPrintf("FLDT T[it_=%d var=%d, y_kmin=%d, y_kmax=%d == %d]=>%f\n", it_, var, y_kmin, y_kmax, var*(periods+y_kmin+y_kmax)+it_, T[var*(periods+y_kmin+y_kmax)+it_]); - tmp_out << " T[" << it_ << ", " << var << "](" << T[var*(periods+y_kmin+y_kmax)+it_] << ")"; + mexPrintf("FLDT T[it_=%d var=%d, y_kmin=%d, y_kmax=%d == %d]=>%f\n", it_, var, y_kmin, y_kmax, var*T_nrows+it_, T[var*T_nrows+it_]); + tmp_out << " T[" << it_ << ", " << var << "](" << T[var*T_nrows+it_] << ")"; #endif - Stack.push(T[var*(periods+y_kmin+y_kmax)+it_]); + Stack.push(T[var*T_nrows+it_]); break; case Tags::FLDST: //load a temporary variable in the processor @@ -1369,10 +1352,10 @@ Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative) mexPrintf("FSTPT\n"); #endif var = static_cast(*it_code)->get_pos(); - T[var*(periods+y_kmin+y_kmax)+it_] = Stack.top(); + T[var*T_nrows+it_] = Stack.top(); #ifdef DEBUG tmp_out << "=>"; - mexPrintf(" T[%d, %d](%f)=%s\n", it_, var, T[var*(periods+y_kmin+y_kmax)+it_], tmp_out.str().c_str()); + mexPrintf(" T[%d, %d](%f)=%s\n", it_, var, T[var*T_nrows+it_], tmp_out.str().c_str()); tmp_out.str(""); #endif @@ -2209,7 +2192,7 @@ Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative) mexPrintf("Impossible case in Bytecode\n"); break; case Tags::FENDEQU: - if (no_derivative) + if (no_derivatives) go_on = false; break; case Tags::FJMPIFEVAL: @@ -2289,46 +2272,22 @@ Evaluate::printCurrentBlock() } } -void -Evaluate::initializeTemporaryTerms(bool global_temporary_terms) +int +Evaluate::getNumberOfTemporaryTerms() const { BytecodeInstruction *instr {instructions_list.front()}; - if (instr->op_code == Tags::FDIMT) + if (steady_state) { - int ntt {reinterpret_cast(instr)->get_size()}; -#ifdef DEBUG - mexPrintf("FDIMT size=%d\n", ntt); -#endif - if (T) - mxFree(T); - T = static_cast(mxMalloc(ntt*(periods+y_kmin+y_kmax)*sizeof(double))); - test_mxMalloc(T, __LINE__, __FILE__, __func__, ntt*(periods+y_kmin+y_kmax)*sizeof(double)); - } - else if (instr->op_code == Tags::FDIMST) - { - int ntt {reinterpret_cast(instr)->get_size()}; -#ifdef DEBUG - mexPrintf("FDIMST size=%d\n", ntt); -#endif - if (T) - mxFree(T); - if (global_temporary_terms) - { - if (!GlobalTemporaryTerms) - { - mexPrintf("GlobalTemporaryTerms is nullptr\n"); - mexEvalString("drawnow;"); - } - if (ntt != static_cast(mxGetNumberOfElements(GlobalTemporaryTerms))) - GlobalTemporaryTerms = mxCreateDoubleMatrix(ntt, 1, mxREAL); - T = mxGetPr(GlobalTemporaryTerms); - } + if (instr->op_code == Tags::FDIMST) + return reinterpret_cast(instr)->get_size(); else - { - T = static_cast(mxMalloc(ntt*sizeof(double))); - test_mxMalloc(T, __LINE__, __FILE__, __func__, ntt*sizeof(double)); - } + throw FatalException {"Evaluate::getNumberOfTemporaryTerms: static .cod file does not begin with FDIMST!"}; } else - throw FatalException {"Evaluate::initializeTemporaryTerms: .cod file does not begin with FDIMT or FDIMST!"}; + { + if (instr->op_code == Tags::FDIMT) + return reinterpret_cast(instr)->get_size(); + else + throw FatalException {"Evaluate::getNumberOfTemporaryTerms: dynamic .cod file does not begin with FDIMT!"}; + } } diff --git a/mex/sources/bytecode/Evaluate.hh b/mex/sources/bytecode/Evaluate.hh index b54250995..0dfea3f9b 100644 --- a/mex/sources/bytecode/Evaluate.hh +++ b/mex/sources/bytecode/Evaluate.hh @@ -56,9 +56,6 @@ private: ExpressionType EQN_type; int EQN_equation, EQN_block, EQN_dvar1; int EQN_lag1, EQN_lag2, EQN_lag3; - map TEF; - map, double> TEFD; - map, double> TEFDD; string error_location(it_code_type expr_begin, it_code_type faulty_op, int it_) const; @@ -78,19 +75,7 @@ private: protected: BasicSymbolTable &symbol_table; int EQN_block_number; - double *y, *ya; - int y_size; - double *T; - int nb_row_x; - int y_kmin, y_kmax, periods; - double *x, *params; - double *u; - double *steady_y; - double *g1, *r, *res; - vector jacobian_block, jacobian_exo_block, jacobian_det_exo_block; - mxArray *GlobalTemporaryTerms; - void evaluateBlock(int Per_u_, bool evaluate, bool no_derivatives); - int it_; + void evaluateBlock(int it_, double *y, const double *ya, int y_size, double *x, int nb_row_x, double *params, const double *steady_y, double *u, int Per_u_, double *T, int T_nrows, map &TEF, map, double> &TEFD, map, double> &TEFDD, double *r, double *g1, double *jacob, double *jacob_exo, double *jacob_exo_det, bool evaluate, bool no_derivatives); int block_num; // Index of the current block int size; // Size of the current block @@ -113,7 +98,7 @@ protected: void gotoBlock(int block); - void initializeTemporaryTerms(bool global_temporary_terms); + int getNumberOfTemporaryTerms() const; auto getCurrentBlockExogenous() const @@ -142,7 +127,7 @@ protected: } public: - Evaluate(int y_size_arg, int y_kmin_arg, int y_kmax_arg, bool steady_state_arg, int periods_arg, BasicSymbolTable &symbol_table_arg); + Evaluate(bool steady_state_arg, BasicSymbolTable &symbol_table_arg); // TODO: integrate into the constructor void loadCodeFile(const filesystem::path &codfile); diff --git a/mex/sources/bytecode/Interpreter.cc b/mex/sources/bytecode/Interpreter.cc index cba6fecd0..7b11e5497 100644 --- a/mex/sources/bytecode/Interpreter.cc +++ b/mex/sources/bytecode/Interpreter.cc @@ -914,3 +914,38 @@ Interpreter::compute_blocks(const string &file_name, bool evaluate, int block) mxFree(T); return {true, blocks}; } + +void +Interpreter::initializeTemporaryTerms(bool global_temporary_terms) +{ + int ntt { getNumberOfTemporaryTerms() }; + + if (steady_state) + { + if (T) + mxFree(T); + if (global_temporary_terms) + { + if (!GlobalTemporaryTerms) + { + mexPrintf("GlobalTemporaryTerms is nullptr\n"); + mexEvalString("drawnow;"); + } + if (ntt != static_cast(mxGetNumberOfElements(GlobalTemporaryTerms))) + GlobalTemporaryTerms = mxCreateDoubleMatrix(ntt, 1, mxREAL); + T = mxGetPr(GlobalTemporaryTerms); + } + else + { + T = static_cast(mxMalloc(ntt*sizeof(double))); + test_mxMalloc(T, __LINE__, __FILE__, __func__, ntt*sizeof(double)); + } + } + else + { + if (T) + mxFree(T); + T = static_cast(mxMalloc(ntt*(periods+y_kmin+y_kmax)*sizeof(double))); + test_mxMalloc(T, __LINE__, __FILE__, __func__, ntt*(periods+y_kmin+y_kmax)*sizeof(double)); + } +} diff --git a/mex/sources/bytecode/Interpreter.hh b/mex/sources/bytecode/Interpreter.hh index f38386813..c77291a27 100644 --- a/mex/sources/bytecode/Interpreter.hh +++ b/mex/sources/bytecode/Interpreter.hh @@ -45,6 +45,7 @@ private: void solve_simple_one_periods(); void solve_simple_over_periods(bool forward); void compute_complete_2b(bool no_derivatives, double *_res1, double *_res2, double *_max_res, int *_max_res_idx); + void initializeTemporaryTerms(bool global_temporary_terms); protected: void evaluate_a_block(bool initialization, bool single_block, const string &bin_base_name); int simulate_a_block(const vector_table_conditional_local_type &vector_table_conditional_local, bool single_block, const string &bin_base_name); diff --git a/mex/sources/bytecode/SparseMatrix.cc b/mex/sources/bytecode/SparseMatrix.cc index 1b10ac076..574dacde7 100644 --- a/mex/sources/bytecode/SparseMatrix.cc +++ b/mex/sources/bytecode/SparseMatrix.cc @@ -27,10 +27,14 @@ dynSparseMatrix::dynSparseMatrix(int y_size_arg, int y_kmin_arg, int y_kmax_arg, bool print_it_arg, bool steady_state_arg, bool block_decomposed_arg, int periods_arg, int minimal_solving_periods_arg, BasicSymbolTable &symbol_table_arg, bool print_error_arg) : - Evaluate {y_size_arg, y_kmin_arg, y_kmax_arg, steady_state_arg, periods_arg, symbol_table_arg}, + Evaluate {steady_state_arg, symbol_table_arg}, block_decomposed {block_decomposed_arg}, minimal_solving_periods {minimal_solving_periods_arg}, print_it {print_it_arg}, + y_size {y_size_arg}, + y_kmin {y_kmin_arg}, + y_kmax {y_kmax_arg}, + periods {periods_arg}, print_error {print_error_arg} { pivotva = nullptr; @@ -1823,9 +1827,23 @@ dynSparseMatrix::Sparse_transpose(const mxArray *A_m) void dynSparseMatrix::compute_block_time(int Per_u_, bool evaluate, bool no_derivatives) { +#ifdef DEBUG + mexPrintf("compute_block_time\n"); +#endif + double *jacob {nullptr}, *jacob_exo {nullptr}, *jacob_exo_det {nullptr}; + if (evaluate) + { + jacob = mxGetPr(jacobian_block[block_num]); + if (!steady_state) + { + jacob_exo = mxGetPr(jacobian_exo_block[block_num]); + jacob_exo_det = mxGetPr(jacobian_det_exo_block[block_num]); + } + } + try { - evaluateBlock(Per_u_, evaluate, no_derivatives); + evaluateBlock(it_, y, ya, y_size, x, nb_row_x, params, steady_y, u, Per_u_, T, periods+y_kmin+y_kmax, TEF, TEFD, TEFDD, r, g1, jacob, jacob_exo, jacob_exo_det, evaluate, no_derivatives); } catch (FloatingPointException &e) { diff --git a/mex/sources/bytecode/SparseMatrix.hh b/mex/sources/bytecode/SparseMatrix.hh index 0f018e702..d6cfcf258 100644 --- a/mex/sources/bytecode/SparseMatrix.hh +++ b/mex/sources/bytecode/SparseMatrix.hh @@ -202,6 +202,22 @@ protected: int max_res_idx; int *index_vara; + double *y, *ya; + int y_size; + double *T; + int nb_row_x; + int y_kmin, y_kmax, periods; + double *x, *params; + double *u; + double *steady_y; + double *g1, *r, *res; + vector jacobian_block, jacobian_exo_block, jacobian_det_exo_block; + mxArray *GlobalTemporaryTerms; + int it_; + map TEF; + map, double> TEFD; + map, double> TEFDD; + void compute_block_time(int Per_u_, bool evaluate, bool no_derivatives); bool compute_complete(bool no_derivatives, double &res1, double &res2, double &max_res, int &max_res_idx);