diff --git a/matlab/subscript_get.m b/matlab/subscript_get.m deleted file mode 100644 index bd8e052f6..000000000 --- a/matlab/subscript_get.m +++ /dev/null @@ -1,61 +0,0 @@ -function deriv = subscript_get(nargsout, func, args, varargin) -% function deriv = subscript_get(nargsout, func, args, varargin) -% returns the appropriate entry of the return argument from a user-defined function -% which returns either the jacobian or hessian (or both) -% -% INPUTS -% nargsout [int] integer indicating the number of the return argument containing the jac/hess -% func [function handle] associated with the function -% args [cell array] arguments provided to func -% varargin [cell array] arguments showing the index (or indices) of the element to be returned -% -% OUTPUTS -% deriv [double] the (element1,element2) entry of the hessian -% -% SPECIAL REQUIREMENTS -% none - -% Copyright (C) 2010 Dynare Team -% -% This file is part of Dynare. -% -% Dynare is free software: you can redistribute it and/or modify -% it under the terms of the GNU General Public License as published by -% the Free Software Foundation, either version 3 of the License, or -% (at your option) any later version. -% -% Dynare is distributed in the hope that it will be useful, -% but WITHOUT ANY WARRANTY; without even the implied warranty of -% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -% GNU General Public License for more details. -% -% You should have received a copy of the GNU General Public License -% along with Dynare. If not, see . - -switch size(varargin,2) - case 1 %first deriv - switch nargsout - case 1 - [outargs{1}] = func(args{:}); - deriv = outargs{1}(varargin{1}); - case 2 - [outargs{1} outargs{2}] = func(args{:}); - deriv = outargs{2}(varargin{1}); - otherwise - error('Wrong number of output arguments (%d) passed to subscript_get().',nargsout); - end - case 2 %second deriv - switch nargsout - case 1 - [outargs{1}] = func(args{:}); - deriv = outargs{1}(varargin{1},varargin{2}); - case 3 - [outargs{1} outargs{2} outargs{3}] = func(args{:}); - deriv = outargs{3}(varargin{1},varargin{2}); - otherwise - error('Wrong number of output arguments (%d) passed to subscript_get().',nargsout); - end - otherwise - error('Wrong number of indices (%d) was passed to subscript_get().',size(varargin,2)); -end -end diff --git a/preprocessor/ExprNode.cc b/preprocessor/ExprNode.cc index 3ed558c9d..d64a00f9a 100644 --- a/preprocessor/ExprNode.cc +++ b/preprocessor/ExprNode.cc @@ -144,6 +144,21 @@ ExprNode::writeOutput(ostream &output) writeOutput(output, oMatlabOutsideModel, temporary_terms_type()); } +void +ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const +{ + deriv_node_temp_terms_type tef_terms; + writeOutput(output, output_type, temporary_terms, tef_terms); +} + +void +ExprNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const +{ + // Nothing to do +} + VariableNode * ExprNode::createEndoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector &neweqs) const { @@ -247,7 +262,8 @@ NumConstNode::collectTemporary_terms(const temporary_terms_type &temporary_terms void NumConstNode::writeOutput(ostream &output, ExprNodeOutputType output_type, - const temporary_terms_type &temporary_terms) const + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const { temporary_terms_type::const_iterator it = temporary_terms.find(const_cast(this)); if (it != temporary_terms.end()) @@ -433,7 +449,8 @@ VariableNode::collectTemporary_terms(const temporary_terms_type &temporary_terms void VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type, - const temporary_terms_type &temporary_terms) const + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const { // If node is a temporary term temporary_terms_type::const_iterator it = temporary_terms.find(const_cast(this)); @@ -1234,7 +1251,8 @@ UnaryOpNode::collectTemporary_terms(const temporary_terms_type &temporary_terms, void UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, - const temporary_terms_type &temporary_terms) const + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const { // If node is a temporary term temporary_terms_type::const_iterator it = temporary_terms.find(const_cast(this)); @@ -1359,6 +1377,14 @@ UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, output << RIGHT_PAR(output_type); } +void +UnaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const +{ + arg->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms); +} + double UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) throw (EvalException) { @@ -2126,9 +2152,17 @@ BinaryOpNode::collectTemporary_terms(const temporary_terms_type &temporary_terms } } +void +BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const +{ + deriv_node_temp_terms_type tef_terms; + writeOutput(output, output_type, temporary_terms, tef_terms); +} + void BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, - const temporary_terms_type &temporary_terms) const + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const { // If current node is a temporary term temporary_terms_type::const_iterator it = temporary_terms.find(const_cast(this)); @@ -2286,6 +2320,15 @@ BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, output << RIGHT_PAR(output_type); } +void +BinaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const +{ + arg1->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms); + arg2->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms); +} + void BinaryOpNode::collectVariables(SymbolType type_arg, set > &result) const { @@ -3005,7 +3048,8 @@ TrinaryOpNode::collectTemporary_terms(const temporary_terms_type &temporary_term void TrinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, - const temporary_terms_type &temporary_terms) const + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const { // If current node is a temporary term temporary_terms_type::const_iterator it = temporary_terms.find(const_cast(this)); @@ -3043,6 +3087,16 @@ TrinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, } } +void +TrinaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const +{ + arg1->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms); + arg2->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms); + arg3->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms); +} + void TrinaryOpNode::collectVariables(SymbolType type_arg, set > &result) const { @@ -3242,22 +3296,81 @@ ExternalFunctionNode::computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const { + temporary_terms.insert(const_cast(this)); } void -ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type, - const temporary_terms_type &temporary_terms) const +ExternalFunctionNode::writeExternalFunctionArguments(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const { - output << datatree.symbol_table.getName(symb_id) << "("; for (vector::const_iterator it = arguments.begin(); it != arguments.end(); it++) { if (it != arguments.begin()) output << ","; - (*it)->writeOutput(output, output_type, temporary_terms); + (*it)->writeOutput(output, output_type, temporary_terms, tef_terms); + } +} + +void +ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const +{ + if (output_type == oMatlabOutsideModel) + { + output << datatree.symbol_table.getName(symb_id) << "("; + writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms); + output << ")"; + return; + } + + temporary_terms_type::const_iterator it = temporary_terms.find(const_cast(this)); + if (it != temporary_terms.end()) + { + if (output_type == oMatlabDynamicModelSparse) + output << "T" << idx << "(it_)"; + else + output << "T" << idx; + return; + } + + output << "TEF_" << getIndxInTefTerms(symb_id, tef_terms); +} + +void +ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const +{ + int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id); + assert(first_deriv_symb_id != eExtFunSetButNoNameProvided); + + for (vector::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + (*it)->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms); + + if (!alreadyWrittenAsTefTerm(symb_id, tef_terms)) + { + tef_terms[make_pair(symb_id, arguments)] = (int) tef_terms.size(); + int indx = getIndxInTefTerms(symb_id, tef_terms); + int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); + assert(second_deriv_symb_id != eExtFunSetButNoNameProvided); + + if (symb_id == first_deriv_symb_id && + symb_id == second_deriv_symb_id) + output << "[TEF_" << indx << " TEFD_"<< indx << " TEFDD_"<< indx << "] = "; + else if (symb_id == first_deriv_symb_id) + output << "[TEF_" << indx << " TEFD_"<< indx << "] = "; + else + output << "TEF_" << indx << " = "; + + output << datatree.symbol_table.getName(symb_id) << "("; + writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms); + output << ");" << endl; } - output << ")"; } void @@ -3423,6 +3536,24 @@ ExternalFunctionNode::buildSimilarExternalFunctionNode(vector &alt_args, return alt_datatree.AddExternalFunction(symb_id, alt_args); } +bool +ExternalFunctionNode::alreadyWrittenAsTefTerm(int the_symb_id, deriv_node_temp_terms_type &tef_terms) const +{ + deriv_node_temp_terms_type::const_iterator it = tef_terms.find(make_pair(the_symb_id, arguments)); + if (it != tef_terms.end()) + return true; + return false; +} + +int +ExternalFunctionNode::getIndxInTefTerms(int the_symb_id, deriv_node_temp_terms_type &tef_terms) const throw (UnknownFunctionNameAndArgs) +{ + deriv_node_temp_terms_type::const_iterator it = tef_terms.find(make_pair(the_symb_id, arguments)); + if (it != tef_terms.end()) + return it->second; + throw UnknownFunctionNameAndArgs(); +} + FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatree_arg, int top_level_symb_id_arg, const vector &arguments_arg, @@ -3434,6 +3565,26 @@ FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatre datatree.first_deriv_external_function_node_map[make_pair(make_pair(arguments,inputIndex),symb_id)] = this; } +void +FirstDerivExternalFunctionNode::computeTemporaryTerms(map &reference_count, + temporary_terms_type &temporary_terms, + bool is_matlab) const +{ + temporary_terms.insert(const_cast(this)); +} + +void +FirstDerivExternalFunctionNode::computeTemporaryTerms(map &reference_count, + temporary_terms_type &temporary_terms, + map > &first_occurence, + int Curr_block, + vector< vector > &v_temporary_terms, + int equation) const +{ + cerr << "FirstDerivExternalFunctionNode::computeTemporaryTerms: not implemented" << endl; + exit(EXIT_FAILURE); +} + NodeID FirstDerivExternalFunctionNode::composeDerivatives(const vector &dargs) { @@ -3442,7 +3593,6 @@ FirstDerivExternalFunctionNode::composeDerivatives(const vector &dargs) if (dargs.at(i) != 0) dNodes.push_back(datatree.AddTimes(dargs.at(i), datatree.AddSecondDerivExternalFunctionNode(symb_id, arguments, inputIndex, i+1))); - NodeID theDeriv = datatree.Zero; for (vector::const_iterator it = dNodes.begin(); it != dNodes.end(); it++) theDeriv = datatree.AddPlus(theDeriv, *it); @@ -3451,40 +3601,63 @@ FirstDerivExternalFunctionNode::composeDerivatives(const vector &dargs) void FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type, - const temporary_terms_type &temporary_terms) const + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const { - int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id); - switch (first_deriv_symb_id) + assert(output_type != oMatlabOutsideModel); + + // If current node is a temporary term + temporary_terms_type::const_iterator it = temporary_terms.find(const_cast(this)); + if (it != temporary_terms.end()) { - case eExtFunSetButNoNameProvided: - cerr << "ERROR in: FirstDerivExternalFunctionNode::writeOutput(). Please inform Dynare Team." << endl; - exit(EXIT_FAILURE); - case eExtFunNotSet: - output << "jacob_element(@" << datatree.symbol_table.getName(symb_id) << "," - << inputIndex << ",{"; - break; - default: - int numOutArgs; - if (first_deriv_symb_id==symb_id) - numOutArgs = 2; // means that the external function also returns at least the first derivative + if (output_type == oMatlabDynamicModelSparse) + output << "T" << idx << "(it_)"; else - numOutArgs = 1; // means that there is a function that returns only the second derivative - output << "subscript_get(" << numOutArgs << ",@" << datatree.symbol_table.getName(first_deriv_symb_id) << ",{"; + output << "T" << idx; + return; } - for (vector::const_iterator it = arguments.begin(); - it != arguments.end(); it++) + int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id); + assert(first_deriv_symb_id != eExtFunSetButNoNameProvided); + + if (first_deriv_symb_id == symb_id) + output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms) << "(" << inputIndex << ")"; + else if (first_deriv_symb_id == eExtFunNotSet) + output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex; + else + output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms) + << "(" << inputIndex << ")"; +} + +void +FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const +{ + assert(output_type != oMatlabOutsideModel); + int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id); + assert(first_deriv_symb_id != eExtFunSetButNoNameProvided); + + if (!alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms)) { - if (it != arguments.begin()) - output << ","; + if (first_deriv_symb_id == symb_id) + return; + else if (first_deriv_symb_id == eExtFunNotSet) + output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << " = jacob_element(@" + << datatree.symbol_table.getName(symb_id) << "," << inputIndex << ",{"; + else + { + tef_terms[make_pair(first_deriv_symb_id, arguments)] = (int) tef_terms.size(); + output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms) + << " = " << datatree.symbol_table.getName(first_deriv_symb_id) << "("; + } - (*it)->writeOutput(output, output_type, temporary_terms); + writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms); + + if (first_deriv_symb_id == eExtFunNotSet) + output << "}"; + output << ");" << endl; } - - output << "}"; - if (first_deriv_symb_id!=eExtFunNotSet) - output << "," << inputIndex; - output << ")"; } SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datatree_arg, @@ -3500,6 +3673,26 @@ SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datat datatree.second_deriv_external_function_node_map[make_pair(make_pair(arguments,make_pair(inputIndex1,inputIndex2)),symb_id)] = this; } +void +SecondDerivExternalFunctionNode::computeTemporaryTerms(map &reference_count, + temporary_terms_type &temporary_terms, + bool is_matlab) const +{ + temporary_terms.insert(const_cast(this)); +} + +void +SecondDerivExternalFunctionNode::computeTemporaryTerms(map &reference_count, + temporary_terms_type &temporary_terms, + map > &first_occurence, + int Curr_block, + vector< vector > &v_temporary_terms, + int equation) const +{ + cerr << "SecondDerivExternalFunctionNode::computeTemporaryTerms: not implemented" << endl; + exit(EXIT_FAILURE); +} + NodeID SecondDerivExternalFunctionNode::computeDerivative(int deriv_id) { @@ -3509,37 +3702,62 @@ SecondDerivExternalFunctionNode::computeDerivative(int deriv_id) void SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type, - const temporary_terms_type &temporary_terms) const + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const { - int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); - switch (second_deriv_symb_id) + assert(output_type != oMatlabOutsideModel); + + // If current node is a temporary term + temporary_terms_type::const_iterator it = temporary_terms.find(const_cast(this)); + if (it != temporary_terms.end()) { - case eExtFunSetButNoNameProvided: - cerr << "ERROR in: FirstDerivExternalFunctionNode::writeOutput(). Please inform Dynare Team." << endl; - exit(EXIT_FAILURE); - case eExtFunNotSet: - output << "hess_element(@" << datatree.symbol_table.getName(symb_id) << "," - << inputIndex1 << "," << inputIndex2 << ",{"; - break; - default: - int numOutArgs; - if (second_deriv_symb_id==symb_id) - numOutArgs = 3; // means that the external function also returns the first and second derivatives + if (output_type == oMatlabDynamicModelSparse) + output << "T" << idx << "(it_)"; else - numOutArgs = 1; // means that there is a function that returns only the second derivative - output << "subscript_get(" << numOutArgs << ",@" << datatree.symbol_table.getName(second_deriv_symb_id) << ",{"; + output << "T" << idx; + return; } - for (vector::const_iterator it = arguments.begin(); - it != arguments.end(); it++) - { - if (it != arguments.begin()) - output << ","; - (*it)->writeOutput(output, output_type, temporary_terms); - } + int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); + assert(second_deriv_symb_id != eExtFunSetButNoNameProvided); - output << "}"; - if (second_deriv_symb_id!=eExtFunNotSet) - output << "," << inputIndex1 << "," << inputIndex2; - output << ")"; + if (second_deriv_symb_id == symb_id) + output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms) << "(" << inputIndex1 << "," << inputIndex2 << ")"; + else if (second_deriv_symb_id == eExtFunNotSet) + output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2; + else + output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms) + << "(" << inputIndex1 << "," << inputIndex2 << ")"; +} + +void +SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const +{ + assert(output_type != oMatlabOutsideModel); + int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); + assert(second_deriv_symb_id != eExtFunSetButNoNameProvided); + + if (!alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms)) + { + if (second_deriv_symb_id == symb_id) + return; + else if (second_deriv_symb_id == eExtFunNotSet) + output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 + << " = hess_element(@" << datatree.symbol_table.getName(symb_id) << "," + << inputIndex1 << "," << inputIndex2 << ",{"; + else + { + tef_terms[make_pair(second_deriv_symb_id, arguments)] = (int) tef_terms.size(); + output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms) + << " = " << datatree.symbol_table.getName(second_deriv_symb_id) << "("; + } + + writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms); + + if (second_deriv_symb_id == eExtFunNotSet) + output << "}"; + output << ");" << endl; + } } diff --git a/preprocessor/ExprNode.hh b/preprocessor/ExprNode.hh index e6729ca9c..7f9e1c4c5 100644 --- a/preprocessor/ExprNode.hh +++ b/preprocessor/ExprNode.hh @@ -54,6 +54,9 @@ typedef map map_idx_type; /*! The key is a symbol id. Lags are assumed to be null */ typedef map eval_context_type; +//! Type for tracking first/second derivative functions that have already been written as temporary terms +typedef map >, int> deriv_node_temp_terms_type; + //! Possible types of output when writing ExprNode(s) enum ExprNodeOutputType { @@ -170,11 +173,19 @@ public: virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; //! Writes output of node, using a Txxx notation for nodes in temporary_terms - virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const = 0; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const = 0; //! Writes output of node (with no temporary terms and with "outside model" output type) void writeOutput(ostream &output); + //! Overloads main writeOutput method to pass an empty value to the tef_terms argument + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + + //! Writes the output for an external function, ensuring that the external function is called as few times as possible using temporary terms + virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const; + //! Computes the set of all variables of a given symbol type in the expression /*! Variables are stored as integer pairs of the form (symb_id, lag). @@ -341,7 +352,7 @@ public: return id; }; virtual void prepareForDerivation(); - virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; virtual void collectVariables(SymbolType type_arg, set > &result) const; virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, temporary_terms_inuse_type &temporary_terms_inuse, int Curr_Block) const; virtual double eval(const eval_context_type &eval_context) const throw (EvalException); @@ -372,7 +383,7 @@ private: public: VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg); virtual void prepareForDerivation(); - virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; virtual void collectVariables(SymbolType type_arg, set > &result) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, @@ -421,7 +432,10 @@ public: UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg, const int expectation_information_set_arg, const string &expectation_information_set_name_arg); virtual void prepareForDerivation(); virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; - virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; + virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, map > &first_occurence, @@ -478,6 +492,10 @@ public: virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; + virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, map > &first_occurence, @@ -541,7 +559,10 @@ public: virtual void prepareForDerivation(); virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; - virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; + virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, map > &first_occurence, @@ -576,14 +597,27 @@ private: virtual NodeID computeDerivative(int deriv_id); virtual NodeID composeDerivatives(const vector &dargs); protected: + //! Thrown when trying to access an unknown entry in external_function_node_map + class UnknownFunctionNameAndArgs + { + }; const int symb_id; const vector arguments; + //! Returns true if the given external function has been written as a temporary term + bool alreadyWrittenAsTefTerm(int the_symb_id, deriv_node_temp_terms_type &tef_terms) const; + //! Returns the index in the tef_terms map of this external function + int getIndxInTefTerms(int the_symb_id, deriv_node_temp_terms_type &tef_terms) const throw (UnknownFunctionNameAndArgs); + //! Helper function to write output arguments of any given external function + void writeExternalFunctionArguments(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; public: ExternalFunctionNode(DataTree &datatree_arg, int symb_id_arg, const vector &arguments_arg); virtual void prepareForDerivation(); virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; - virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; + virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const; virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, map > &first_occurence, @@ -619,7 +653,17 @@ public: int top_level_symb_id_arg, const vector &arguments_arg, int inputIndex_arg); - virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; + virtual void computeTemporaryTerms(map &reference_count, + temporary_terms_type &temporary_terms, + map > &first_occurence, + int Curr_block, + vector< vector > &v_temporary_terms, + int equation) const; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; + virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const; }; class SecondDerivExternalFunctionNode : public ExternalFunctionNode @@ -634,7 +678,17 @@ public: const vector &arguments_arg, int inputIndex1_arg, int inputIndex2_arg); - virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; + virtual void computeTemporaryTerms(map &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; + virtual void computeTemporaryTerms(map &reference_count, + temporary_terms_type &temporary_terms, + map > &first_occurence, + int Curr_block, + vector< vector > &v_temporary_terms, + int equation) const; + virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const; + virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, + const temporary_terms_type &temporary_terms, + deriv_node_temp_terms_type &tef_terms) const; }; #endif diff --git a/preprocessor/ModelTree.cc b/preprocessor/ModelTree.cc index ebc55c94e..44452c16b 100644 --- a/preprocessor/ModelTree.cc +++ b/preprocessor/ModelTree.cc @@ -992,6 +992,9 @@ ModelTree::writeTemporaryTerms(const temporary_terms_type &tt, ostream &output, // Local var used to keep track of temp nodes already written temporary_terms_type tt2; + // To store the functions that have already been written in the form TEF* = ext_fun(); + deriv_node_temp_terms_type tef_terms; + if (tt.size() > 0 && (IS_C(output_type))) output << "double" << endl; @@ -1001,10 +1004,13 @@ ModelTree::writeTemporaryTerms(const temporary_terms_type &tt, ostream &output, if (IS_C(output_type) && it != tt.begin()) output << "," << endl; - (*it)->writeOutput(output, output_type, tt); + if (dynamic_cast(*it) != NULL) + (*it)->writeExternalFunctionOutput(output, output_type, tt2, tef_terms); + + (*it)->writeOutput(output, output_type, tt, tef_terms); output << " = "; - (*it)->writeOutput(output, output_type, tt2); + (*it)->writeOutput(output, output_type, tt2, tef_terms); // Insert current node into tt2 tt2.insert(*it);