Allow temporary terms to work with external functions

time-shift
Houtan Bastani 2010-03-04 16:40:07 +01:00
parent 38640c4af8
commit 30ea396a08
4 changed files with 352 additions and 135 deletions

View File

@ -1,61 +0,0 @@
function deriv = subscript_get(nargsout, func, args, varargin)
% function deriv = subscript_get(nargsout, func, args, varargin)
% returns the appropriate entry of the return argument from a user-defined function
% which returns either the jacobian or hessian (or both)
%
% INPUTS
% nargsout [int] integer indicating the number of the return argument containing the jac/hess
% func [function handle] associated with the function
% args [cell array] arguments provided to func
% varargin [cell array] arguments showing the index (or indices) of the element to be returned
%
% OUTPUTS
% deriv [double] the (element1,element2) entry of the hessian
%
% SPECIAL REQUIREMENTS
% none
% Copyright (C) 2010 Dynare Team
%
% This file is part of Dynare.
%
% Dynare is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% Dynare is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with Dynare. If not, see <http://www.gnu.org/licenses/>.
switch size(varargin,2)
case 1 %first deriv
switch nargsout
case 1
[outargs{1}] = func(args{:});
deriv = outargs{1}(varargin{1});
case 2
[outargs{1} outargs{2}] = func(args{:});
deriv = outargs{2}(varargin{1});
otherwise
error('Wrong number of output arguments (%d) passed to subscript_get().',nargsout);
end
case 2 %second deriv
switch nargsout
case 1
[outargs{1}] = func(args{:});
deriv = outargs{1}(varargin{1},varargin{2});
case 3
[outargs{1} outargs{2} outargs{3}] = func(args{:});
deriv = outargs{3}(varargin{1},varargin{2});
otherwise
error('Wrong number of output arguments (%d) passed to subscript_get().',nargsout);
end
otherwise
error('Wrong number of indices (%d) was passed to subscript_get().',size(varargin,2));
end
end

View File

@ -144,6 +144,21 @@ ExprNode::writeOutput(ostream &output)
writeOutput(output, oMatlabOutsideModel, temporary_terms_type());
}
void
ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
{
deriv_node_temp_terms_type tef_terms;
writeOutput(output, output_type, temporary_terms, tef_terms);
}
void
ExprNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
// Nothing to do
}
VariableNode *
ExprNode::createEndoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
@ -247,7 +262,8 @@ NumConstNode::collectTemporary_terms(const temporary_terms_type &temporary_terms
void
NumConstNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms) const
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<NumConstNode *>(this));
if (it != temporary_terms.end())
@ -433,7 +449,8 @@ VariableNode::collectTemporary_terms(const temporary_terms_type &temporary_terms
void
VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms) const
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
// If node is a temporary term
temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<VariableNode *>(this));
@ -1234,7 +1251,8 @@ UnaryOpNode::collectTemporary_terms(const temporary_terms_type &temporary_terms,
void
UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms) const
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
// If node is a temporary term
temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<UnaryOpNode *>(this));
@ -1359,6 +1377,14 @@ UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
output << RIGHT_PAR(output_type);
}
void
UnaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
arg->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms);
}
double
UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) throw (EvalException)
{
@ -2126,9 +2152,17 @@ BinaryOpNode::collectTemporary_terms(const temporary_terms_type &temporary_terms
}
}
void
BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
{
deriv_node_temp_terms_type tef_terms;
writeOutput(output, output_type, temporary_terms, tef_terms);
}
void
BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms) const
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
// If current node is a temporary term
temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<BinaryOpNode *>(this));
@ -2286,6 +2320,15 @@ BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
output << RIGHT_PAR(output_type);
}
void
BinaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
arg1->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms);
arg2->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms);
}
void
BinaryOpNode::collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
@ -3005,7 +3048,8 @@ TrinaryOpNode::collectTemporary_terms(const temporary_terms_type &temporary_term
void
TrinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms) const
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
// If current node is a temporary term
temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<TrinaryOpNode *>(this));
@ -3043,6 +3087,16 @@ TrinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
}
}
void
TrinaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
arg1->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms);
arg2->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms);
arg3->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms);
}
void
TrinaryOpNode::collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const
{
@ -3242,22 +3296,81 @@ ExternalFunctionNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
bool is_matlab) const
{
temporary_terms.insert(const_cast<ExternalFunctionNode *>(this));
}
void
ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms) const
ExternalFunctionNode::writeExternalFunctionArguments(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
output << datatree.symbol_table.getName(symb_id) << "(";
for (vector<NodeID>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
{
if (it != arguments.begin())
output << ",";
(*it)->writeOutput(output, output_type, temporary_terms);
(*it)->writeOutput(output, output_type, temporary_terms, tef_terms);
}
}
void
ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
if (output_type == oMatlabOutsideModel)
{
output << datatree.symbol_table.getName(symb_id) << "(";
writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms);
output << ")";
return;
}
temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<ExternalFunctionNode *>(this));
if (it != temporary_terms.end())
{
if (output_type == oMatlabDynamicModelSparse)
output << "T" << idx << "(it_)";
else
output << "T" << idx;
return;
}
output << "TEF_" << getIndxInTefTerms(symb_id, tef_terms);
}
void
ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
for (vector<NodeID>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
(*it)->writeExternalFunctionOutput(output, output_type, temporary_terms, tef_terms);
if (!alreadyWrittenAsTefTerm(symb_id, tef_terms))
{
tef_terms[make_pair(symb_id, arguments)] = (int) tef_terms.size();
int indx = getIndxInTefTerms(symb_id, tef_terms);
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
if (symb_id == first_deriv_symb_id &&
symb_id == second_deriv_symb_id)
output << "[TEF_" << indx << " TEFD_"<< indx << " TEFDD_"<< indx << "] = ";
else if (symb_id == first_deriv_symb_id)
output << "[TEF_" << indx << " TEFD_"<< indx << "] = ";
else
output << "TEF_" << indx << " = ";
output << datatree.symbol_table.getName(symb_id) << "(";
writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms);
output << ");" << endl;
}
output << ")";
}
void
@ -3423,6 +3536,24 @@ ExternalFunctionNode::buildSimilarExternalFunctionNode(vector<NodeID> &alt_args,
return alt_datatree.AddExternalFunction(symb_id, alt_args);
}
bool
ExternalFunctionNode::alreadyWrittenAsTefTerm(int the_symb_id, deriv_node_temp_terms_type &tef_terms) const
{
deriv_node_temp_terms_type::const_iterator it = tef_terms.find(make_pair(the_symb_id, arguments));
if (it != tef_terms.end())
return true;
return false;
}
int
ExternalFunctionNode::getIndxInTefTerms(int the_symb_id, deriv_node_temp_terms_type &tef_terms) const throw (UnknownFunctionNameAndArgs)
{
deriv_node_temp_terms_type::const_iterator it = tef_terms.find(make_pair(the_symb_id, arguments));
if (it != tef_terms.end())
return it->second;
throw UnknownFunctionNameAndArgs();
}
FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatree_arg,
int top_level_symb_id_arg,
const vector<NodeID> &arguments_arg,
@ -3434,6 +3565,26 @@ FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatre
datatree.first_deriv_external_function_node_map[make_pair(make_pair(arguments,inputIndex),symb_id)] = this;
}
void
FirstDerivExternalFunctionNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
bool is_matlab) const
{
temporary_terms.insert(const_cast<FirstDerivExternalFunctionNode *>(this));
}
void
FirstDerivExternalFunctionNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
map<NodeID, pair<int, int> > &first_occurence,
int Curr_block,
vector< vector<temporary_terms_type> > &v_temporary_terms,
int equation) const
{
cerr << "FirstDerivExternalFunctionNode::computeTemporaryTerms: not implemented" << endl;
exit(EXIT_FAILURE);
}
NodeID
FirstDerivExternalFunctionNode::composeDerivatives(const vector<NodeID> &dargs)
{
@ -3442,7 +3593,6 @@ FirstDerivExternalFunctionNode::composeDerivatives(const vector<NodeID> &dargs)
if (dargs.at(i) != 0)
dNodes.push_back(datatree.AddTimes(dargs.at(i),
datatree.AddSecondDerivExternalFunctionNode(symb_id, arguments, inputIndex, i+1)));
NodeID theDeriv = datatree.Zero;
for (vector<NodeID>::const_iterator it = dNodes.begin(); it != dNodes.end(); it++)
theDeriv = datatree.AddPlus(theDeriv, *it);
@ -3451,40 +3601,63 @@ FirstDerivExternalFunctionNode::composeDerivatives(const vector<NodeID> &dargs)
void
FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms) const
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
switch (first_deriv_symb_id)
assert(output_type != oMatlabOutsideModel);
// If current node is a temporary term
temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<FirstDerivExternalFunctionNode *>(this));
if (it != temporary_terms.end())
{
case eExtFunSetButNoNameProvided:
cerr << "ERROR in: FirstDerivExternalFunctionNode::writeOutput(). Please inform Dynare Team." << endl;
exit(EXIT_FAILURE);
case eExtFunNotSet:
output << "jacob_element(@" << datatree.symbol_table.getName(symb_id) << ","
<< inputIndex << ",{";
break;
default:
int numOutArgs;
if (first_deriv_symb_id==symb_id)
numOutArgs = 2; // means that the external function also returns at least the first derivative
if (output_type == oMatlabDynamicModelSparse)
output << "T" << idx << "(it_)";
else
numOutArgs = 1; // means that there is a function that returns only the second derivative
output << "subscript_get(" << numOutArgs << ",@" << datatree.symbol_table.getName(first_deriv_symb_id) << ",{";
output << "T" << idx;
return;
}
for (vector<NodeID>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
if (first_deriv_symb_id == symb_id)
output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms) << "(" << inputIndex << ")";
else if (first_deriv_symb_id == eExtFunNotSet)
output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
else
output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
<< "(" << inputIndex << ")";
}
void
FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
assert(output_type != oMatlabOutsideModel);
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
if (!alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
{
if (it != arguments.begin())
output << ",";
if (first_deriv_symb_id == symb_id)
return;
else if (first_deriv_symb_id == eExtFunNotSet)
output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << " = jacob_element(@"
<< datatree.symbol_table.getName(symb_id) << "," << inputIndex << ",{";
else
{
tef_terms[make_pair(first_deriv_symb_id, arguments)] = (int) tef_terms.size();
output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
<< " = " << datatree.symbol_table.getName(first_deriv_symb_id) << "(";
}
(*it)->writeOutput(output, output_type, temporary_terms);
writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms);
if (first_deriv_symb_id == eExtFunNotSet)
output << "}";
output << ");" << endl;
}
output << "}";
if (first_deriv_symb_id!=eExtFunNotSet)
output << "," << inputIndex;
output << ")";
}
SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datatree_arg,
@ -3500,6 +3673,26 @@ SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datat
datatree.second_deriv_external_function_node_map[make_pair(make_pair(arguments,make_pair(inputIndex1,inputIndex2)),symb_id)] = this;
}
void
SecondDerivExternalFunctionNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
bool is_matlab) const
{
temporary_terms.insert(const_cast<SecondDerivExternalFunctionNode *>(this));
}
void
SecondDerivExternalFunctionNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
map<NodeID, pair<int, int> > &first_occurence,
int Curr_block,
vector< vector<temporary_terms_type> > &v_temporary_terms,
int equation) const
{
cerr << "SecondDerivExternalFunctionNode::computeTemporaryTerms: not implemented" << endl;
exit(EXIT_FAILURE);
}
NodeID
SecondDerivExternalFunctionNode::computeDerivative(int deriv_id)
{
@ -3509,37 +3702,62 @@ SecondDerivExternalFunctionNode::computeDerivative(int deriv_id)
void
SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms) const
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
switch (second_deriv_symb_id)
assert(output_type != oMatlabOutsideModel);
// If current node is a temporary term
temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<SecondDerivExternalFunctionNode *>(this));
if (it != temporary_terms.end())
{
case eExtFunSetButNoNameProvided:
cerr << "ERROR in: FirstDerivExternalFunctionNode::writeOutput(). Please inform Dynare Team." << endl;
exit(EXIT_FAILURE);
case eExtFunNotSet:
output << "hess_element(@" << datatree.symbol_table.getName(symb_id) << ","
<< inputIndex1 << "," << inputIndex2 << ",{";
break;
default:
int numOutArgs;
if (second_deriv_symb_id==symb_id)
numOutArgs = 3; // means that the external function also returns the first and second derivatives
if (output_type == oMatlabDynamicModelSparse)
output << "T" << idx << "(it_)";
else
numOutArgs = 1; // means that there is a function that returns only the second derivative
output << "subscript_get(" << numOutArgs << ",@" << datatree.symbol_table.getName(second_deriv_symb_id) << ",{";
output << "T" << idx;
return;
}
for (vector<NodeID>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
{
if (it != arguments.begin())
output << ",";
(*it)->writeOutput(output, output_type, temporary_terms);
}
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
output << "}";
if (second_deriv_symb_id!=eExtFunNotSet)
output << "," << inputIndex1 << "," << inputIndex2;
output << ")";
if (second_deriv_symb_id == symb_id)
output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms) << "(" << inputIndex1 << "," << inputIndex2 << ")";
else if (second_deriv_symb_id == eExtFunNotSet)
output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
else
output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
<< "(" << inputIndex1 << "," << inputIndex2 << ")";
}
void
SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const
{
assert(output_type != oMatlabOutsideModel);
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
if (!alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms))
{
if (second_deriv_symb_id == symb_id)
return;
else if (second_deriv_symb_id == eExtFunNotSet)
output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2
<< " = hess_element(@" << datatree.symbol_table.getName(symb_id) << ","
<< inputIndex1 << "," << inputIndex2 << ",{";
else
{
tef_terms[make_pair(second_deriv_symb_id, arguments)] = (int) tef_terms.size();
output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
<< " = " << datatree.symbol_table.getName(second_deriv_symb_id) << "(";
}
writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms);
if (second_deriv_symb_id == eExtFunNotSet)
output << "}";
output << ");" << endl;
}
}

View File

@ -54,6 +54,9 @@ typedef map<int, int> map_idx_type;
/*! The key is a symbol id. Lags are assumed to be null */
typedef map<int, double> eval_context_type;
//! Type for tracking first/second derivative functions that have already been written as temporary terms
typedef map<pair<int, vector<NodeID> >, int> deriv_node_temp_terms_type;
//! Possible types of output when writing ExprNode(s)
enum ExprNodeOutputType
{
@ -170,11 +173,19 @@ public:
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
//! Writes output of node, using a Txxx notation for nodes in temporary_terms
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const = 0;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const = 0;
//! Writes output of node (with no temporary terms and with "outside model" output type)
void writeOutput(ostream &output);
//! Overloads main writeOutput method to pass an empty value to the tef_terms argument
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
//! Writes the output for an external function, ensuring that the external function is called as few times as possible using temporary terms
virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const;
//! Computes the set of all variables of a given symbol type in the expression
/*!
Variables are stored as integer pairs of the form (symb_id, lag).
@ -341,7 +352,7 @@ public:
return id;
};
virtual void prepareForDerivation();
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, temporary_terms_inuse_type &temporary_terms_inuse, int Curr_Block) const;
virtual double eval(const eval_context_type &eval_context) const throw (EvalException);
@ -372,7 +383,7 @@ private:
public:
VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg);
virtual void prepareForDerivation();
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
@ -421,7 +432,10 @@ public:
UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg, const int expectation_information_set_arg, const string &expectation_information_set_name_arg);
virtual void prepareForDerivation();
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
map<NodeID, pair<int, int> > &first_occurence,
@ -478,6 +492,10 @@ public:
virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
map<NodeID, pair<int, int> > &first_occurence,
@ -541,7 +559,10 @@ public:
virtual void prepareForDerivation();
virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
map<NodeID, pair<int, int> > &first_occurence,
@ -576,14 +597,27 @@ private:
virtual NodeID computeDerivative(int deriv_id);
virtual NodeID composeDerivatives(const vector<NodeID> &dargs);
protected:
//! Thrown when trying to access an unknown entry in external_function_node_map
class UnknownFunctionNameAndArgs
{
};
const int symb_id;
const vector<NodeID> arguments;
//! Returns true if the given external function has been written as a temporary term
bool alreadyWrittenAsTefTerm(int the_symb_id, deriv_node_temp_terms_type &tef_terms) const;
//! Returns the index in the tef_terms map of this external function
int getIndxInTefTerms(int the_symb_id, deriv_node_temp_terms_type &tef_terms) const throw (UnknownFunctionNameAndArgs);
//! Helper function to write output arguments of any given external function
void writeExternalFunctionArguments(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
public:
ExternalFunctionNode(DataTree &datatree_arg, int symb_id_arg,
const vector<NodeID> &arguments_arg);
virtual void prepareForDerivation();
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
map<NodeID, pair<int, int> > &first_occurence,
@ -619,7 +653,17 @@ public:
int top_level_symb_id_arg,
const vector<NodeID> &arguments_arg,
int inputIndex_arg);
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
map<NodeID, pair<int, int> > &first_occurence,
int Curr_block,
vector< vector<temporary_terms_type> > &v_temporary_terms,
int equation) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const;
};
class SecondDerivExternalFunctionNode : public ExternalFunctionNode
@ -634,7 +678,17 @@ public:
const vector<NodeID> &arguments_arg,
int inputIndex1_arg,
int inputIndex2_arg);
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
temporary_terms_type &temporary_terms,
map<NodeID, pair<int, int> > &first_occurence,
int Curr_block,
vector< vector<temporary_terms_type> > &v_temporary_terms,
int equation) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms, deriv_node_temp_terms_type &tef_terms) const;
virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_type &temporary_terms,
deriv_node_temp_terms_type &tef_terms) const;
};
#endif

View File

@ -992,6 +992,9 @@ ModelTree::writeTemporaryTerms(const temporary_terms_type &tt, ostream &output,
// Local var used to keep track of temp nodes already written
temporary_terms_type tt2;
// To store the functions that have already been written in the form TEF* = ext_fun();
deriv_node_temp_terms_type tef_terms;
if (tt.size() > 0 && (IS_C(output_type)))
output << "double" << endl;
@ -1001,10 +1004,13 @@ ModelTree::writeTemporaryTerms(const temporary_terms_type &tt, ostream &output,
if (IS_C(output_type) && it != tt.begin())
output << "," << endl;
(*it)->writeOutput(output, output_type, tt);
if (dynamic_cast<ExternalFunctionNode *>(*it) != NULL)
(*it)->writeExternalFunctionOutput(output, output_type, tt2, tef_terms);
(*it)->writeOutput(output, output_type, tt, tef_terms);
output << " = ";
(*it)->writeOutput(output, output_type, tt2);
(*it)->writeOutput(output, output_type, tt2, tef_terms);
// Insert current node into tt2
tt2.insert(*it);