diff --git a/ExprNode.cc b/ExprNode.cc index 861f479c..40c1cd2a 100644 --- a/ExprNode.cc +++ b/ExprNode.cc @@ -3602,6 +3602,22 @@ ExternalFunctionNode::writeExternalFunctionArguments(ostream &output, ExprNodeOu } } +void +ExternalFunctionNode::writePrhs(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_t &temporary_terms, + deriv_node_temp_terms_t &tef_terms, const string &ending) const +{ + output << "mxArray *prhs"<< ending << "[nrhs"<< ending << "];" << endl; + int i = 0; + for (vector::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + { + output << "prhs" << ending << "[" << i++ << "] = mxCreateDoubleScalar("; // All external_function arguments are scalars + (*it)->writeOutput(output, output_type, temporary_terms, tef_terms); + output << ");" << endl; + } +} + void ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, @@ -3625,6 +3641,8 @@ ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_typ return; } + if (IS_C(output_type)) + output << "*"; output << "TEF_" << getIndxInTefTerms(symb_id, tef_terms); } @@ -3647,17 +3665,61 @@ ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutpu int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); assert(second_deriv_symb_id != eExtFunSetButNoNameProvided); - if (symb_id == first_deriv_symb_id && - symb_id == second_deriv_symb_id) - output << "[TEF_" << indx << " TEFD_"<< indx << " TEFDD_"<< indx << "] = "; - else if (symb_id == first_deriv_symb_id) - output << "[TEF_" << indx << " TEFD_"<< indx << "] = "; - else - output << "TEF_" << indx << " = "; + if (IS_C(output_type)) + { + stringstream ending; + ending << "_tef_" << getIndxInTefTerms(symb_id, tef_terms); + if (symb_id == first_deriv_symb_id && + symb_id == second_deriv_symb_id) + output << "int nlhs" << ending.str() << " = 3;" << endl + << "double *TEF_" << indx << ", " + << "*TEFD_" << indx << ", " + << "*TEFDD_" << indx << ";" << endl; + else if (symb_id == first_deriv_symb_id) + output << "int nlhs" << ending.str() << " = 2;" << endl + << "double *TEF_" << indx << ", " + << "*TEFD_" << indx << "; " << endl; + else + output << "int nlhs" << ending.str() << " = 1;" << endl + << "double *TEF_" << indx << ";" << endl; - output << datatree.symbol_table.getName(symb_id) << "("; - writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms); - output << ");" << endl; + output << "mxArray *plhs" << ending.str()<< "[nlhs"<< ending.str() << "];" << endl; + output << "int nrhs" << ending.str()<< " = " << arguments.size() << ";" << endl; + writePrhs(output, output_type, temporary_terms, tef_terms, ending.str()); + + output << "mexCallMATLAB(" + << "nlhs" << ending.str() << ", " + << "plhs" << ending.str() << ", " + << "nrhs" << ending.str() << ", " + << "prhs" << ending.str() << ", \"" + << datatree.symbol_table.getName(symb_id) << "\");" << endl; + + if (symb_id == first_deriv_symb_id && + symb_id == second_deriv_symb_id) + output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl + << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl + << "TEFDD_" << indx << " = mxGetPr(plhs" << ending.str() << "[2]);" << endl + << "int TEFDD_" << indx << "_nrows = (int)mxGetM(plhs" << ending.str()<< "[2]);" << endl; + else if (symb_id == first_deriv_symb_id) + output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl + << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl; + else + output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + } + else + { + if (symb_id == first_deriv_symb_id && + symb_id == second_deriv_symb_id) + output << "[TEF_" << indx << " TEFD_"<< indx << " TEFDD_"<< indx << "] = "; + else if (symb_id == first_deriv_symb_id) + output << "[TEF_" << indx << " TEFD_"<< indx << "] = "; + else + output << "TEF_" << indx << " = "; + + output << datatree.symbol_table.getName(symb_id) << "("; + writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms); + output << ");" << endl; + } } } @@ -3940,13 +4002,22 @@ FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id); assert(first_deriv_symb_id != eExtFunSetButNoNameProvided); + int tmpIndx = inputIndex; + if (IS_C(output_type)) + tmpIndx = tmpIndx - 1; + if (first_deriv_symb_id == symb_id) - output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms) << "(" << inputIndex << ")"; + output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms) + << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndx << RIGHT_ARRAY_SUBSCRIPT(output_type); else if (first_deriv_symb_id == eExtFunNotSet) - output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex; + { + if (IS_C(output_type)) + output << "*"; + output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex; + } else output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms) - << "(" << inputIndex << ")"; + << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndx << RIGHT_ARRAY_SUBSCRIPT(output_type); } void @@ -3958,11 +4029,73 @@ FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Exp int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id); assert(first_deriv_symb_id != eExtFunSetButNoNameProvided); - if (!alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms)) + if (first_deriv_symb_id == symb_id || alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms)) + return; + + if (IS_C(output_type)) + if (first_deriv_symb_id == eExtFunNotSet) + { + stringstream ending; + ending << "_tefd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex; + output << "int nlhs" << ending.str() << " = 1;" << endl + << "double *TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << ";" << endl + << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl + << "int nrhs" << ending.str() << " = 3;" << endl + << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl + << "mwSize dims" << ending.str() << "[2];" << endl; + + output << "dims" << ending.str() << "[0] = 1;" << endl + << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl; + + output << "prhs" << ending.str() << "[0] = mxCreateString(\"" << datatree.symbol_table.getName(symb_id) << "\");" << endl + << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex << ");"<< endl + << "prhs" << ending.str() << "[2] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl; + + int i = 0; + for (vector::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + { + output << "mxSetCell(prhs" << ending.str() << "[2], " + << i++ << ", " + << "mxCreateDoubleScalar("; // All external_function arguments are scalars + (*it)->writeOutput(output, output_type, temporary_terms, tef_terms); + output << "));" << endl; + } + + output << "mexCallMATLAB(" + << "nlhs" << ending.str() << ", " + << "plhs" << ending.str() << ", " + << "nrhs" << ending.str() << ", " + << "prhs" << ending.str() << ", \"" + << "jacob_element\");" << endl; + + output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex + << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + } + else + { + tef_terms[make_pair(first_deriv_symb_id, arguments)] = (int) tef_terms.size(); + int indx = getIndxInTefTerms(first_deriv_symb_id, tef_terms); + stringstream ending; + ending << "_tefd_def_" << indx; + output << "int nlhs" << ending.str() << " = 1;" << endl + << "double *TEFD_def_" << indx << ";" << endl + << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl + << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl; + writePrhs(output, output_type, temporary_terms, tef_terms, ending.str()); + + output << "mexCallMATLAB(" + << "nlhs" << ending.str() << ", " + << "plhs" << ending.str() << ", " + << "nrhs" << ending.str() << ", " + << "prhs" << ending.str() << ", \"" + << datatree.symbol_table.getName(first_deriv_symb_id) << "\");" << endl; + + output << "TEFD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + } + else { - if (first_deriv_symb_id == symb_id) - return; - else if (first_deriv_symb_id == eExtFunNotSet) + if (first_deriv_symb_id == eExtFunNotSet) output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << " = jacob_element('" << datatree.symbol_table.getName(symb_id) << "'," << inputIndex << ",{"; else @@ -4041,13 +4174,37 @@ SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); assert(second_deriv_symb_id != eExtFunSetButNoNameProvided); + int tmpIndex1 = inputIndex1; + int tmpIndex2 = inputIndex2; + if (IS_C(output_type)) + { + tmpIndex1 = tmpIndex1 - 1; + tmpIndex2 = tmpIndex2 - 1; + } + + int indx = getIndxInTefTerms(symb_id, tef_terms); if (second_deriv_symb_id == symb_id) - output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms) << "(" << inputIndex1 << "," << inputIndex2 << ")"; + if (IS_C(output_type)) + output << "TEFDD_" << indx + << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << " * TEFDD_" << indx << "_nrows + " + << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type); + else + output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms) + << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << "," << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type); else if (second_deriv_symb_id == eExtFunNotSet) - output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2; + { + if (IS_C(output_type)) + output << "*"; + output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2; + } else - output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms) - << "(" << inputIndex1 << "," << inputIndex2 << ")"; + if (IS_C(output_type)) + output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms) + << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << " * PROBLEM_" << indx << "_nrows" + << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type); + else + output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms) + << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << "," << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type); } void @@ -4059,11 +4216,76 @@ SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Ex int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); assert(second_deriv_symb_id != eExtFunSetButNoNameProvided); - if (!alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms)) + if (alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms) || + second_deriv_symb_id == symb_id) + return; + + if (IS_C(output_type)) + if (second_deriv_symb_id == eExtFunNotSet) + { + stringstream ending; + ending << "_tefdd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2; + output << "int nlhs" << ending.str() << " = 1;" << endl + << "double *TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << ";" << endl + << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl + << "int nrhs" << ending.str() << " = 4;" << endl + << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl + << "mwSize dims" << ending.str() << "[2];" << endl; + + output << "dims" << ending.str() << "[0] = 1;" << endl + << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl; + + output << "prhs" << ending.str() << "[0] = mxCreateString(\"" << datatree.symbol_table.getName(symb_id) << "\");" << endl + << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex1 << ");"<< endl + << "prhs" << ending.str() << "[2] = mxCreateDoubleScalar(" << inputIndex2 << ");"<< endl + << "prhs" << ending.str() << "[3] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl; + + int i = 0; + for (vector::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + { + output << "mxSetCell(prhs" << ending.str() << "[3], " + << i++ << ", " + << "mxCreateDoubleScalar("; // All external_function arguments are scalars + (*it)->writeOutput(output, output_type, temporary_terms, tef_terms); + output << "));" << endl; + } + + output << "mexCallMATLAB(" + << "nlhs" << ending.str() << ", " + << "plhs" << ending.str() << ", " + << "nrhs" << ending.str() << ", " + << "prhs" << ending.str() << ", \"" + << "hess_element\");" << endl; + + output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 + << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + } + else + { + tef_terms[make_pair(second_deriv_symb_id, arguments)] = (int) tef_terms.size(); + int indx = getIndxInTefTerms(second_deriv_symb_id, tef_terms); + stringstream ending; + ending << "_tefdd_def_" << indx; + + output << "int nlhs" << ending.str() << " = 1;" << endl + << "double *TEFDD_def_" << indx << ";" << endl + << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl + << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl; + writePrhs(output, output_type, temporary_terms, tef_terms, ending.str()); + + output << "mexCallMATLAB(" + << "nlhs" << ending.str() << ", " + << "plhs" << ending.str() << ", " + << "nrhs" << ending.str() << ", " + << "prhs" << ending.str() << ", \"" + << datatree.symbol_table.getName(second_deriv_symb_id) << "\");" << endl; + + output << "TEFDD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + } + else { - if (second_deriv_symb_id == symb_id) - return; - else if (second_deriv_symb_id == eExtFunNotSet) + if (second_deriv_symb_id == eExtFunNotSet) output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << " = hess_element('" << datatree.symbol_table.getName(symb_id) << "'," << inputIndex1 << "," << inputIndex2 << ",{"; diff --git a/ExprNode.hh b/ExprNode.hh index a01eb2e3..9a4df0e2 100644 --- a/ExprNode.hh +++ b/ExprNode.hh @@ -702,6 +702,7 @@ public: virtual expr_t decreaseLeadsLagsPredeterminedVariables() const; virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; + virtual void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, deriv_node_temp_terms_t &tef_terms, const string &ending) const; }; class FirstDerivExternalFunctionNode : public ExternalFunctionNode diff --git a/ModFile.cc b/ModFile.cc index c346f784..f2da960c 100644 --- a/ModFile.cc +++ b/ModFile.cc @@ -145,9 +145,9 @@ ModFile::checkPass() exit(EXIT_FAILURE); } - if ((use_dll || byte_code) && (external_functions_table.get_total_number_of_unique_model_block_external_functions() > 0)) + if (byte_code && (external_functions_table.get_total_number_of_unique_model_block_external_functions() > 0)) { - cerr << "ERROR: In 'model' block, use of external functions is not compatible with 'use_dll' or 'bytecode'" << endl; + cerr << "ERROR: In 'model' block, use of external functions is not compatible with 'bytecode'" << endl; exit(EXIT_FAILURE); } diff --git a/ModelTree.cc b/ModelTree.cc index dbfc62e9..e4beba9d 100644 --- a/ModelTree.cc +++ b/ModelTree.cc @@ -1085,31 +1085,29 @@ ModelTree::writeTemporaryTerms(const temporary_terms_t &tt, ostream &output, // To store the functions that have already been written in the form TEF* = ext_fun(); deriv_node_temp_terms_t tef_terms; - if (tt.size() > 0 && (IS_C(output_type))) - output << "double" << endl; for (temporary_terms_t::const_iterator it = tt.begin(); it != tt.end(); it++) { - if (IS_C(output_type) && it != tt.begin()) - output << "," << endl; - if (dynamic_cast(*it) != NULL) (*it)->writeExternalFunctionOutput(output, output_type, tt2, tef_terms); + if (IS_C(output_type)) + output << "double "; + (*it)->writeOutput(output, output_type, tt, tef_terms); output << " = "; - (*it)->writeOutput(output, output_type, tt2, tef_terms); + if (IS_C(output_type)) + output << ";" << endl; + // Insert current node into tt2 tt2.insert(*it); if (IS_MATLAB(output_type)) output << ";" << endl; } - if (IS_C(output_type)) - output << ";" << endl; } void