Bytecode: various simplifications

Also improve the naming of the enum class used for identifying the type of
external function call.
master
Sébastien Villemot 2022-07-08 16:00:02 +02:00
parent a7dc96516b
commit 4b30342dc2
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
5 changed files with 67 additions and 93 deletions

View File

@ -46,7 +46,7 @@ operator<<(BytecodeWriter &code_file, const FCALL_ &instr)
code_file.write(reinterpret_cast<const char *>(&instr.add_input_arguments), sizeof instr.add_input_arguments);
code_file.write(reinterpret_cast<const char *>(&instr.row), sizeof instr.row);
code_file.write(reinterpret_cast<const char *>(&instr.col), sizeof instr.col);
code_file.write(reinterpret_cast<const char *>(&instr.function_type), sizeof instr.function_type);
code_file.write(reinterpret_cast<const char *>(&instr.call_type), sizeof instr.call_type);
int size = static_cast<int>(instr.func_name.size());
code_file.write(reinterpret_cast<char *>(&size), sizeof size);

View File

@ -100,6 +100,17 @@ enum class ExpressionType
FirstExodetDerivative,
};
enum class ExternalFunctionCallType
{
levelWithoutDerivative,
levelWithFirstDerivative,
levelWithFirstAndSecondDerivative,
separatelyProvidedFirstDerivative,
numericalFirstDerivative,
separatelyProvidedSecondDerivative,
numericalSecondDerivative
};
struct Block_contain_type
{
int Equation, Variable, Own_Derivative;
@ -672,9 +683,6 @@ public:
class FLDV_ : public TagWithThreeArguments<SymbolType, int, int>
{
public:
FLDV_(SymbolType type_arg, int pos_arg) : TagWithThreeArguments::TagWithThreeArguments{Tags::FLDV, type_arg, pos_arg, 0}
{
};
FLDV_(SymbolType type_arg, int pos_arg, int lead_lag_arg) :
TagWithThreeArguments::TagWithThreeArguments{Tags::FLDV, type_arg, pos_arg, lead_lag_arg}
{
@ -699,10 +707,6 @@ public:
class FSTPV_ : public TagWithThreeArguments<SymbolType, int, int>
{
public:
FSTPV_(SymbolType type_arg, int pos_arg) :
TagWithThreeArguments::TagWithThreeArguments{Tags::FSTPV, type_arg, pos_arg, 0}
{
};
FSTPV_(SymbolType type_arg, int pos_arg, int lead_lag_arg) :
TagWithThreeArguments::TagWithThreeArguments{Tags::FSTPV, type_arg, pos_arg, lead_lag_arg}
{
@ -733,17 +737,18 @@ private:
string func_name;
string arg_func_name;
int add_input_arguments{0}, row{0}, col{0};
ExternalFunctionType function_type{ExternalFunctionType::withoutDerivative};
ExternalFunctionCallType call_type;
public:
FCALL_() : BytecodeInstruction{Tags::FCALL}
{
};
FCALL_(int nb_output_arguments_arg, int nb_input_arguments_arg, string func_name_arg, int indx_arg) :
FCALL_(int nb_output_arguments_arg, int nb_input_arguments_arg, string func_name_arg, int indx_arg, ExternalFunctionCallType call_type_arg) :
BytecodeInstruction{Tags::FCALL},
nb_output_arguments{nb_output_arguments_arg},
nb_input_arguments{nb_input_arguments_arg},
indx{indx_arg},
func_name{move(func_name_arg)}
func_name{move(func_name_arg)},
call_type{call_type_arg}
{
};
string
@ -807,15 +812,10 @@ public:
{
return col;
};
void
set_function_type(ExternalFunctionType arg_function_type)
ExternalFunctionCallType
get_call_type()
{
function_type = arg_function_type;
};
ExternalFunctionType
get_function_type()
{
return function_type;
return call_type;
}
#ifdef BYTECODE_MEX
@ -829,7 +829,7 @@ public:
memcpy(&add_input_arguments, code, sizeof(add_input_arguments)); code += sizeof(add_input_arguments);
memcpy(&row, code, sizeof(row)); code += sizeof(row);
memcpy(&col, code, sizeof(col)); code += sizeof(col);
memcpy(&function_type, code, sizeof(function_type)); code += sizeof(function_type);
memcpy(&call_type, code, sizeof(call_type)); code += sizeof(call_type);
int size;
memcpy(&size, code, sizeof(size)); code += sizeof(size);
char *name = static_cast<char *>(mxMalloc((size+1)*sizeof(char)));

View File

@ -103,17 +103,6 @@ enum class TrinaryOpcode
normpdf
};
enum class ExternalFunctionType
{
withoutDerivative,
withFirstDerivative,
withFirstAndSecondDerivative,
numericalFirstDerivative,
firstDerivative,
numericalSecondDerivative,
secondDerivative
};
enum class PriorDistributions
{
noShape = 0,

View File

@ -1321,10 +1321,7 @@ VariableNode::writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOut
switch (output_type)
{
case ExprNodeBytecodeOutputType::dynamicModel:
if (type == SymbolType::parameter)
code_file << FLDV_{type, tsid};
else
code_file << FLDV_{type, tsid, lag};
code_file << FLDV_{type, tsid, lag};
break;
case ExprNodeBytecodeOutputType::staticModel:
code_file << FLDSV_{type, tsid};
@ -1333,10 +1330,7 @@ VariableNode::writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOut
code_file << FLDVS_{type, tsid};
break;
case ExprNodeBytecodeOutputType::dynamicAssignmentLHS:
if (type == SymbolType::parameter)
code_file << FSTPV_{type, tsid};
else
code_file << FSTPV_{type, tsid, lag};
code_file << FSTPV_{type, tsid, lag};
break;
case ExprNodeBytecodeOutputType::staticAssignmentLHS:
code_file << FSTPSV_{type, tsid};
@ -6688,7 +6682,7 @@ AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int
return composeDerivatives(dargs);
}
int
void
AbstractExternalFunctionNode::writeBytecodeExternalFunctionArguments(BytecodeWriter &code_file,
ExprNodeBytecodeOutputType output_type,
const temporary_terms_t &temporary_terms,
@ -6698,7 +6692,6 @@ AbstractExternalFunctionNode::writeBytecodeExternalFunctionArguments(BytecodeWri
for (auto argument : arguments)
argument->writeBytecodeOutput(code_file, output_type, temporary_terms,
temporary_terms_idxs, tef_terms);
return static_cast<int>(arguments.size());
}
void
@ -7286,31 +7279,30 @@ ExternalFunctionNode::writeBytecodeExternalFunctionOutput(BytecodeWriter &code_f
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
int nb_output_arguments{0};
writeBytecodeExternalFunctionArguments(code_file, output_type, temporary_terms,
temporary_terms_idxs, tef_terms);
int nb_output_arguments;
ExternalFunctionCallType call_type;
if (symb_id == first_deriv_symb_id
&& symb_id == second_deriv_symb_id)
nb_output_arguments = 3;
else if (symb_id == first_deriv_symb_id)
nb_output_arguments = 2;
else
nb_output_arguments = 1;
int nb_input_arguments{writeBytecodeExternalFunctionArguments(code_file, output_type, temporary_terms,
temporary_terms_idxs, tef_terms)};
FCALL_ fcall{nb_output_arguments, nb_input_arguments, datatree.symbol_table.getName(symb_id), indx};
switch (nb_output_arguments)
{
case 1:
fcall.set_function_type(ExternalFunctionType::withoutDerivative);
break;
case 2:
fcall.set_function_type(ExternalFunctionType::withFirstDerivative);
break;
case 3:
fcall.set_function_type(ExternalFunctionType::withFirstAndSecondDerivative);
break;
nb_output_arguments = 3;
call_type = ExternalFunctionCallType::levelWithFirstAndSecondDerivative;
}
code_file << fcall << FSTPTEF_{indx};
else if (symb_id == first_deriv_symb_id)
{
nb_output_arguments = 2;
call_type = ExternalFunctionCallType::levelWithFirstDerivative;
}
else
{
nb_output_arguments = 1;
call_type = ExternalFunctionCallType::levelWithoutDerivative;
}
code_file << FCALL_{nb_output_arguments, static_cast<int>(arguments.size()), datatree.symbol_table.getName(symb_id), indx, call_type}
<< FSTPTEF_{indx};
}
}
@ -7807,32 +7799,31 @@ FirstDerivExternalFunctionNode::writeBytecodeExternalFunctionOutput(BytecodeWrit
if (alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
return;
int nb_add_input_arguments{writeBytecodeExternalFunctionArguments(code_file, output_type, temporary_terms,
temporary_terms_idxs, tef_terms)};
if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
writeBytecodeExternalFunctionArguments(code_file, output_type, temporary_terms,
temporary_terms_idxs, tef_terms);
if (int indx = getIndxInTefTerms(symb_id, tef_terms);
first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
{
int nb_input_arguments{0};
int nb_output_arguments{1};
int indx = getIndxInTefTerms(symb_id, tef_terms);
FCALL_ fcall{nb_output_arguments, nb_input_arguments, "jacob_element", indx};
FCALL_ fcall{nb_output_arguments, nb_input_arguments, "jacob_element", indx,
ExternalFunctionCallType::numericalFirstDerivative};
fcall.set_arg_func_name(datatree.symbol_table.getName(symb_id));
fcall.set_row(inputIndex);
fcall.set_nb_add_input_arguments(nb_add_input_arguments);
fcall.set_function_type(ExternalFunctionType::numericalFirstDerivative);
fcall.set_nb_add_input_arguments(static_cast<int>(arguments.size()));
code_file << fcall << FSTPTEFD_{indx, inputIndex};
}
else
{
tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
int indx = getIndxInTefTerms(symb_id, tef_terms);
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
int nb_output_arguments{1};
FCALL_ fcall{nb_output_arguments, nb_add_input_arguments, datatree.symbol_table.getName(first_deriv_symb_id), indx};
fcall.set_function_type(ExternalFunctionType::firstDerivative);
code_file << fcall << FSTPTEFD_{indx, inputIndex};
code_file << FCALL_{nb_output_arguments, static_cast<int>(arguments.size()), datatree.symbol_table.getName(first_deriv_symb_id), indx, ExternalFunctionCallType::separatelyProvidedFirstDerivative}
<< FSTPTEFD_{indx, inputIndex};
}
}
@ -8212,31 +8203,25 @@ SecondDerivExternalFunctionNode::writeBytecodeExternalFunctionOutput(BytecodeWri
if (alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms))
return;
int nb_add_input_arguments{writeBytecodeExternalFunctionArguments(code_file, output_type, temporary_terms,
temporary_terms_idxs, tef_terms)};
if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
writeBytecodeExternalFunctionArguments(code_file, output_type, temporary_terms,
temporary_terms_idxs, tef_terms);
if (int indx = getIndxInTefTerms(symb_id, tef_terms);
second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
{
int nb_input_arguments{0};
int nb_output_arguments{1};
int indx = getIndxInTefTerms(symb_id, tef_terms);
FCALL_ fcall{nb_output_arguments, nb_input_arguments, "hess_element", indx};
FCALL_ fcall{1, 0, "hess_element", indx, ExternalFunctionCallType::numericalSecondDerivative};
fcall.set_arg_func_name(datatree.symbol_table.getName(symb_id));
fcall.set_row(inputIndex1);
fcall.set_col(inputIndex2);
fcall.set_nb_add_input_arguments(nb_add_input_arguments);
fcall.set_function_type(ExternalFunctionType::numericalSecondDerivative);
fcall.set_nb_add_input_arguments(static_cast<int>(arguments.size()));
code_file << fcall << FSTPTEFDD_{indx, inputIndex1, inputIndex2};
}
else
{
tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
int indx = getIndxInTefTerms(symb_id, tef_terms);
int nb_output_arguments{1};
FCALL_ fcall{nb_output_arguments, nb_add_input_arguments, datatree.symbol_table.getName(second_deriv_symb_id), indx};
fcall.set_function_type(ExternalFunctionType::secondDerivative);
code_file << fcall << FSTPTEFDD_{indx, inputIndex1, inputIndex2};
code_file << FCALL_{1, static_cast<int>(arguments.size()), datatree.symbol_table.getName(second_deriv_symb_id), indx, ExternalFunctionCallType::separatelyProvidedSecondDerivative}
<< FSTPTEFDD_{indx, inputIndex1, inputIndex2};
}
}

View File

@ -1320,11 +1320,11 @@ protected:
void writeExternalFunctionArguments(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const;
void writeJsonASTExternalFunctionArguments(ostream &output) const;
void writeJsonExternalFunctionArguments(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, bool isdynamic) const;
int writeBytecodeExternalFunctionArguments(BytecodeWriter &code_file,
ExprNodeBytecodeOutputType output_type,
const temporary_terms_t &temporary_terms,
const temporary_terms_idxs_t &temporary_terms_idxs,
const deriv_node_temp_terms_t &tef_terms) const;
void writeBytecodeExternalFunctionArguments(BytecodeWriter &code_file,
ExprNodeBytecodeOutputType output_type,
const temporary_terms_t &temporary_terms,
const temporary_terms_idxs_t &temporary_terms_idxs,
const deriv_node_temp_terms_t &tef_terms) const;
/*! Returns a predicate that tests whether an other ExprNode is an external
function which is computed by the same external function call (i.e. it has
the same so-called "Tef" index) */