Bytecode: move storage for variable and parameter values out of the Evaluate class

mr#2134
Sébastien Villemot 2023-04-18 15:49:56 +02:00
parent d0864689d2
commit d789e6a4c5
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
6 changed files with 95 additions and 81 deletions

View File

@ -29,12 +29,8 @@
#include "CommonEnums.hh"
#include "ErrorHandling.hh"
Evaluate::Evaluate(int y_size_arg, int y_kmin_arg, int y_kmax_arg, bool steady_state_arg, int periods_arg, BasicSymbolTable &symbol_table_arg) :
Evaluate::Evaluate(bool steady_state_arg, BasicSymbolTable &symbol_table_arg) :
symbol_table {symbol_table_arg},
y_size {y_size_arg},
y_kmin {y_kmin_arg},
y_kmax {y_kmax_arg},
periods {periods_arg},
steady_state {steady_state_arg}
{
}
@ -989,7 +985,7 @@ Evaluate::print_expression(const Evaluate::it_code_type &expr_begin, const optio
}
void
Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative)
Evaluate::evaluateBlock(int it_, double *y, const double *ya, int y_size, double *x, int nb_row_x, double *params, const double *steady_y, double *u, int Per_u_, double *T, int T_nrows, map<int, double> &TEF, map<pair<int, int>, double> &TEFD, map<tuple<int, int, int>, double> &TEFDD, double *r, double *g1, double *jacob, double *jacob_exo, double *jacob_exo_det, bool evaluate, bool no_derivatives)
{
auto it_code { currentBlockBeginning() };
int var{0}, lag{0};
@ -1004,24 +1000,11 @@ Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative)
bool go_on = true;
double ll;
double rr;
double *jacob = nullptr, *jacob_exo = nullptr, *jacob_exo_det = nullptr;
EQN_block = block_num;
stack<double> Stack;
ExternalFunctionCallType call_type{ExternalFunctionCallType::levelWithoutDerivative};
it_code_type it_code_expr;
#ifdef DEBUG
mexPrintf("compute_block_time\n");
#endif
if (evaluate)
{
jacob = mxGetPr(jacobian_block[block_num]);
if (!steady_state)
{
jacob_exo = mxGetPr(jacobian_exo_block[block_num]);
jacob_exo_det = mxGetPr(jacobian_det_exo_block[block_num]);
}
}
#ifdef MATLAB_MEX_FILE
if (utIsInterruptPending())
throw UserException{};
@ -1221,10 +1204,10 @@ Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative)
//load a temporary variable in the processor
var = static_cast<FLDT_ *>(*it_code)->get_pos();
#ifdef DEBUG
mexPrintf("FLDT T[it_=%d var=%d, y_kmin=%d, y_kmax=%d == %d]=>%f\n", it_, var, y_kmin, y_kmax, var*(periods+y_kmin+y_kmax)+it_, T[var*(periods+y_kmin+y_kmax)+it_]);
tmp_out << " T[" << it_ << ", " << var << "](" << T[var*(periods+y_kmin+y_kmax)+it_] << ")";
mexPrintf("FLDT T[it_=%d var=%d, y_kmin=%d, y_kmax=%d == %d]=>%f\n", it_, var, y_kmin, y_kmax, var*T_nrows+it_, T[var*T_nrows+it_]);
tmp_out << " T[" << it_ << ", " << var << "](" << T[var*T_nrows+it_] << ")";
#endif
Stack.push(T[var*(periods+y_kmin+y_kmax)+it_]);
Stack.push(T[var*T_nrows+it_]);
break;
case Tags::FLDST:
//load a temporary variable in the processor
@ -1369,10 +1352,10 @@ Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative)
mexPrintf("FSTPT\n");
#endif
var = static_cast<FSTPT_ *>(*it_code)->get_pos();
T[var*(periods+y_kmin+y_kmax)+it_] = Stack.top();
T[var*T_nrows+it_] = Stack.top();
#ifdef DEBUG
tmp_out << "=>";
mexPrintf(" T[%d, %d](%f)=%s\n", it_, var, T[var*(periods+y_kmin+y_kmax)+it_], tmp_out.str().c_str());
mexPrintf(" T[%d, %d](%f)=%s\n", it_, var, T[var*T_nrows+it_], tmp_out.str().c_str());
tmp_out.str("");
#endif
@ -2209,7 +2192,7 @@ Evaluate::evaluateBlock(int Per_u_, bool evaluate, bool no_derivative)
mexPrintf("Impossible case in Bytecode\n");
break;
case Tags::FENDEQU:
if (no_derivative)
if (no_derivatives)
go_on = false;
break;
case Tags::FJMPIFEVAL:
@ -2289,46 +2272,22 @@ Evaluate::printCurrentBlock()
}
}
void
Evaluate::initializeTemporaryTerms(bool global_temporary_terms)
int
Evaluate::getNumberOfTemporaryTerms() const
{
BytecodeInstruction *instr {instructions_list.front()};
if (instr->op_code == Tags::FDIMT)
if (steady_state)
{
int ntt {reinterpret_cast<FDIMT_ *>(instr)->get_size()};
#ifdef DEBUG
mexPrintf("FDIMT size=%d\n", ntt);
#endif
if (T)
mxFree(T);
T = static_cast<double *>(mxMalloc(ntt*(periods+y_kmin+y_kmax)*sizeof(double)));
test_mxMalloc(T, __LINE__, __FILE__, __func__, ntt*(periods+y_kmin+y_kmax)*sizeof(double));
}
else if (instr->op_code == Tags::FDIMST)
{
int ntt {reinterpret_cast<FDIMST_ *>(instr)->get_size()};
#ifdef DEBUG
mexPrintf("FDIMST size=%d\n", ntt);
#endif
if (T)
mxFree(T);
if (global_temporary_terms)
{
if (!GlobalTemporaryTerms)
{
mexPrintf("GlobalTemporaryTerms is nullptr\n");
mexEvalString("drawnow;");
}
if (ntt != static_cast<int>(mxGetNumberOfElements(GlobalTemporaryTerms)))
GlobalTemporaryTerms = mxCreateDoubleMatrix(ntt, 1, mxREAL);
T = mxGetPr(GlobalTemporaryTerms);
}
if (instr->op_code == Tags::FDIMST)
return reinterpret_cast<FDIMST_ *>(instr)->get_size();
else
{
T = static_cast<double *>(mxMalloc(ntt*sizeof(double)));
test_mxMalloc(T, __LINE__, __FILE__, __func__, ntt*sizeof(double));
}
throw FatalException {"Evaluate::getNumberOfTemporaryTerms: static .cod file does not begin with FDIMST!"};
}
else
throw FatalException {"Evaluate::initializeTemporaryTerms: .cod file does not begin with FDIMT or FDIMST!"};
{
if (instr->op_code == Tags::FDIMT)
return reinterpret_cast<FDIMT_ *>(instr)->get_size();
else
throw FatalException {"Evaluate::getNumberOfTemporaryTerms: dynamic .cod file does not begin with FDIMT!"};
}
}

View File

@ -56,9 +56,6 @@ private:
ExpressionType EQN_type;
int EQN_equation, EQN_block, EQN_dvar1;
int EQN_lag1, EQN_lag2, EQN_lag3;
map<int, double> TEF;
map<pair<int, int>, double> TEFD;
map<tuple<int, int, int>, double> TEFDD;
string error_location(it_code_type expr_begin, it_code_type faulty_op, int it_) const;
@ -78,19 +75,7 @@ private:
protected:
BasicSymbolTable &symbol_table;
int EQN_block_number;
double *y, *ya;
int y_size;
double *T;
int nb_row_x;
int y_kmin, y_kmax, periods;
double *x, *params;
double *u;
double *steady_y;
double *g1, *r, *res;
vector<mxArray *> jacobian_block, jacobian_exo_block, jacobian_det_exo_block;
mxArray *GlobalTemporaryTerms;
void evaluateBlock(int Per_u_, bool evaluate, bool no_derivatives);
int it_;
void evaluateBlock(int it_, double *y, const double *ya, int y_size, double *x, int nb_row_x, double *params, const double *steady_y, double *u, int Per_u_, double *T, int T_nrows, map<int, double> &TEF, map<pair<int, int>, double> &TEFD, map<tuple<int, int, int>, double> &TEFDD, double *r, double *g1, double *jacob, double *jacob_exo, double *jacob_exo_det, bool evaluate, bool no_derivatives);
int block_num; // Index of the current block
int size; // Size of the current block
@ -113,7 +98,7 @@ protected:
void gotoBlock(int block);
void initializeTemporaryTerms(bool global_temporary_terms);
int getNumberOfTemporaryTerms() const;
auto
getCurrentBlockExogenous() const
@ -142,7 +127,7 @@ protected:
}
public:
Evaluate(int y_size_arg, int y_kmin_arg, int y_kmax_arg, bool steady_state_arg, int periods_arg, BasicSymbolTable &symbol_table_arg);
Evaluate(bool steady_state_arg, BasicSymbolTable &symbol_table_arg);
// TODO: integrate into the constructor
void loadCodeFile(const filesystem::path &codfile);

View File

@ -914,3 +914,38 @@ Interpreter::compute_blocks(const string &file_name, bool evaluate, int block)
mxFree(T);
return {true, blocks};
}
void
Interpreter::initializeTemporaryTerms(bool global_temporary_terms)
{
int ntt { getNumberOfTemporaryTerms() };
if (steady_state)
{
if (T)
mxFree(T);
if (global_temporary_terms)
{
if (!GlobalTemporaryTerms)
{
mexPrintf("GlobalTemporaryTerms is nullptr\n");
mexEvalString("drawnow;");
}
if (ntt != static_cast<int>(mxGetNumberOfElements(GlobalTemporaryTerms)))
GlobalTemporaryTerms = mxCreateDoubleMatrix(ntt, 1, mxREAL);
T = mxGetPr(GlobalTemporaryTerms);
}
else
{
T = static_cast<double *>(mxMalloc(ntt*sizeof(double)));
test_mxMalloc(T, __LINE__, __FILE__, __func__, ntt*sizeof(double));
}
}
else
{
if (T)
mxFree(T);
T = static_cast<double *>(mxMalloc(ntt*(periods+y_kmin+y_kmax)*sizeof(double)));
test_mxMalloc(T, __LINE__, __FILE__, __func__, ntt*(periods+y_kmin+y_kmax)*sizeof(double));
}
}

View File

@ -45,6 +45,7 @@ private:
void solve_simple_one_periods();
void solve_simple_over_periods(bool forward);
void compute_complete_2b(bool no_derivatives, double *_res1, double *_res2, double *_max_res, int *_max_res_idx);
void initializeTemporaryTerms(bool global_temporary_terms);
protected:
void evaluate_a_block(bool initialization, bool single_block, const string &bin_base_name);
int simulate_a_block(const vector_table_conditional_local_type &vector_table_conditional_local, bool single_block, const string &bin_base_name);

View File

@ -27,10 +27,14 @@
dynSparseMatrix::dynSparseMatrix(int y_size_arg, int y_kmin_arg, int y_kmax_arg, bool print_it_arg, bool steady_state_arg, bool block_decomposed_arg, int periods_arg,
int minimal_solving_periods_arg, BasicSymbolTable &symbol_table_arg,
bool print_error_arg) :
Evaluate {y_size_arg, y_kmin_arg, y_kmax_arg, steady_state_arg, periods_arg, symbol_table_arg},
Evaluate {steady_state_arg, symbol_table_arg},
block_decomposed {block_decomposed_arg},
minimal_solving_periods {minimal_solving_periods_arg},
print_it {print_it_arg},
y_size {y_size_arg},
y_kmin {y_kmin_arg},
y_kmax {y_kmax_arg},
periods {periods_arg},
print_error {print_error_arg}
{
pivotva = nullptr;
@ -1823,9 +1827,23 @@ dynSparseMatrix::Sparse_transpose(const mxArray *A_m)
void
dynSparseMatrix::compute_block_time(int Per_u_, bool evaluate, bool no_derivatives)
{
#ifdef DEBUG
mexPrintf("compute_block_time\n");
#endif
double *jacob {nullptr}, *jacob_exo {nullptr}, *jacob_exo_det {nullptr};
if (evaluate)
{
jacob = mxGetPr(jacobian_block[block_num]);
if (!steady_state)
{
jacob_exo = mxGetPr(jacobian_exo_block[block_num]);
jacob_exo_det = mxGetPr(jacobian_det_exo_block[block_num]);
}
}
try
{
evaluateBlock(Per_u_, evaluate, no_derivatives);
evaluateBlock(it_, y, ya, y_size, x, nb_row_x, params, steady_y, u, Per_u_, T, periods+y_kmin+y_kmax, TEF, TEFD, TEFDD, r, g1, jacob, jacob_exo, jacob_exo_det, evaluate, no_derivatives);
}
catch (FloatingPointException &e)
{

View File

@ -202,6 +202,22 @@ protected:
int max_res_idx;
int *index_vara;
double *y, *ya;
int y_size;
double *T;
int nb_row_x;
int y_kmin, y_kmax, periods;
double *x, *params;
double *u;
double *steady_y;
double *g1, *r, *res;
vector<mxArray *> jacobian_block, jacobian_exo_block, jacobian_det_exo_block;
mxArray *GlobalTemporaryTerms;
int it_;
map<int, double> TEF;
map<pair<int, int>, double> TEFD;
map<tuple<int, int, int>, double> TEFDD;
void compute_block_time(int Per_u_, bool evaluate, bool no_derivatives);
bool compute_complete(bool no_derivatives, double &res1, double &res2, double &max_res, int &max_res_idx);