diff --git a/preprocessor/DynareBison.yy b/preprocessor/DynareBison.yy index 4eed5b498..2bdbf5327 100644 --- a/preprocessor/DynareBison.yy +++ b/preprocessor/DynareBison.yy @@ -416,7 +416,7 @@ expression : '(' expression ')' | MIN '(' expression COMMA expression ')' { $$ = driver.add_min($3, $5); } | symbol { driver.push_external_function_arg_vector_onto_stack(); } '(' comma_expression ')' - { $$ = driver.add_model_var_or_external_function($1); } + { $$ = driver.add_model_var_or_external_function($1,false); } | NORMCDF '(' expression COMMA expression COMMA expression ')' { $$ = driver.add_normcdf($3, $5, $7); } | NORMCDF '(' expression ')' @@ -562,17 +562,17 @@ hand_side : '(' hand_side ')' | SQRT '(' hand_side ')' { $$ = driver.add_sqrt($3); } | MAX '(' hand_side COMMA hand_side ')' - { $$ = driver.add_max($3, $5); } + { $$ = driver.add_max($3, $5); } | MIN '(' hand_side COMMA hand_side ')' - { $$ = driver.add_min($3, $5); } + { $$ = driver.add_min($3, $5); } | symbol { driver.push_external_function_arg_vector_onto_stack(); } '(' comma_hand_side ')' - { $$ = driver.add_model_var_or_external_function($1); } + { $$ = driver.add_model_var_or_external_function($1,true); } | NORMCDF '(' hand_side COMMA hand_side COMMA hand_side ')' - { $$ = driver.add_normcdf($3, $5, $7); } + { $$ = driver.add_normcdf($3, $5, $7); } | NORMCDF '(' hand_side ')' - { $$ = driver.add_normcdf($3); } + { $$ = driver.add_normcdf($3); } | STEADY_STATE '(' hand_side ')' - { $$ = driver.add_steady_state($3); } + { $$ = driver.add_steady_state($3); } ; comma_hand_side : hand_side diff --git a/preprocessor/ExprNode.cc b/preprocessor/ExprNode.cc index 384964234..6965ebd88 100644 --- a/preprocessor/ExprNode.cc +++ b/preprocessor/ExprNode.cc @@ -3206,6 +3206,7 @@ ExternalFunctionNode::prepareForDerivation() NodeID ExternalFunctionNode::computeDerivative(int deriv_id) { + assert(datatree.external_functions_table.getNargs(symb_id) > 0); vector dargs; for (vector::const_iterator it = arguments.begin(); it != arguments.end(); it++) dargs.push_back((*it)->getDerivative(deriv_id)); diff --git a/preprocessor/ExternalFunctionsTable.cc b/preprocessor/ExternalFunctionsTable.cc index 83c03f607..2650a977d 100644 --- a/preprocessor/ExternalFunctionsTable.cc +++ b/preprocessor/ExternalFunctionsTable.cc @@ -31,16 +31,12 @@ ExternalFunctionsTable::ExternalFunctionsTable() } void -ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function_options &external_function_options_arg) +ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function_options &external_function_options_arg, bool track_nargs) { assert(symb_id >= 0); + assert(external_function_options_arg.nargs > 0); - if (external_function_options_arg.nargs <= 0) - { - cerr << "ERROR: The number of arguments passed to an external function must be > 0." << endl; - exit(EXIT_FAILURE); - } - + // Change options to be saved so the table is consistent external_function_options external_function_options_chng = external_function_options_arg; if (external_function_options_arg.firstDerivSymbID == eExtFunSetButNoNameProvided) external_function_options_chng.firstDerivSymbID = symb_id; @@ -48,6 +44,10 @@ ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function if (external_function_options_arg.secondDerivSymbID == eExtFunSetButNoNameProvided) external_function_options_chng.secondDerivSymbID = symb_id; + if (!track_nargs) + external_function_options_chng.nargs = eExtFunNotSet; + + // Ensure 1st & 2nd deriv option consistency if (external_function_options_chng.secondDerivSymbID == symb_id && external_function_options_chng.firstDerivSymbID != symb_id) { @@ -81,29 +81,37 @@ ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function exit(EXIT_FAILURE); } + // Ensure that if we're overwriting something, we mean to do it if (exists(symb_id)) { - if (external_function_options_arg.nargs != getNargs(symb_id)) - { - cerr << "ERROR: The number of arguments passed to the external_function() statement do not " - << "match the number of arguments passed to a previous call or declaration of the top-level function."<< endl; - exit(EXIT_FAILURE); - } + bool ok_to_overwrite = false; + if (getNargs(symb_id) == eExtFunNotSet) // implies that the information stored about this function is not important + ok_to_overwrite = true; - if (external_function_options_chng.firstDerivSymbID != getFirstDerivSymbID(symb_id)) - { - cerr << "ERROR: The first derivative function passed to the external_function() statement does not " - << "match the first derivative function passed to a previous call or declaration of the top-level function."<< endl; - exit(EXIT_FAILURE); - } + if (!ok_to_overwrite) // prevents multiple non-compatible calls to external_function(name=funcname) + { // e.g. e_f(name=a,nargs=1,fd,sd) and e_f(name=a,nargs=2,fd=b,sd=c) should cause an error + if (external_function_options_chng.nargs != getNargs(symb_id)) + { + cerr << "ERROR: The number of arguments passed to the external_function() statement do not " + << "match the number of arguments passed to a previous call or declaration of the top-level function."<< endl; + exit(EXIT_FAILURE); + } - if (external_function_options_chng.secondDerivSymbID != getSecondDerivSymbID(symb_id)) - { - cerr << "ERROR: The second derivative function passed to the external_function() statement does not " - << "match the second derivative function passed to a previous call or declaration of the top-level function."<< endl; - exit(EXIT_FAILURE); + if (external_function_options_chng.firstDerivSymbID != getFirstDerivSymbID(symb_id)) + { + cerr << "ERROR: The first derivative function passed to the external_function() statement does not " + << "match the first derivative function passed to a previous call or declaration of the top-level function."<< endl; + exit(EXIT_FAILURE); + } + + if (external_function_options_chng.secondDerivSymbID != getSecondDerivSymbID(symb_id)) + { + cerr << "ERROR: The second derivative function passed to the external_function() statement does not " + << "match the second derivative function passed to a previous call or declaration of the top-level function."<< endl; + exit(EXIT_FAILURE); + } } } - else - externalFunctionTable[symb_id] = external_function_options_chng; + + externalFunctionTable[symb_id] = external_function_options_chng; } diff --git a/preprocessor/ExternalFunctionsTable.hh b/preprocessor/ExternalFunctionsTable.hh index e1d19cf91..85e97e40f 100644 --- a/preprocessor/ExternalFunctionsTable.hh +++ b/preprocessor/ExternalFunctionsTable.hh @@ -65,7 +65,7 @@ private: public: ExternalFunctionsTable(); //! Adds an external function to the table as well as its derivative functions - void addExternalFunction(int symb_id, const external_function_options &external_function_options_arg); + void addExternalFunction(int symb_id, const external_function_options &external_function_options_arg, bool track_nargs); //! See if the function exists in the External Functions Table inline bool exists(int symb_id) const; //! Get the number of arguments for a given external function @@ -75,7 +75,7 @@ public: //! Get the symbol_id of the second derivative function inline int getSecondDerivSymbID(int symb_id) const throw (UnknownExternalFunctionSymbolIDException); //! Returns the total number of unique external functions declared or used in the .mod file - inline int get_total_number_of_unique_external_functions() const; + inline int get_total_number_of_unique_model_block_external_functions() const; }; inline bool @@ -113,9 +113,15 @@ ExternalFunctionsTable::getSecondDerivSymbID(int symb_id) const throw (UnknownEx } inline int -ExternalFunctionsTable::get_total_number_of_unique_external_functions() const +ExternalFunctionsTable::get_total_number_of_unique_model_block_external_functions() const { - return externalFunctionTable.size(); + int number_of_unique_model_block_external_functions = 0; + for (external_function_table_type::const_iterator it = externalFunctionTable.begin(); + it != externalFunctionTable.end(); it++) + if (it->second.nargs > 0) + number_of_unique_model_block_external_functions++; + + return number_of_unique_model_block_external_functions; } #endif diff --git a/preprocessor/ModFile.cc b/preprocessor/ModFile.cc index 05ffae0b8..409a1cd59 100644 --- a/preprocessor/ModFile.cc +++ b/preprocessor/ModFile.cc @@ -135,7 +135,7 @@ ModFile::checkPass() exit(EXIT_FAILURE); } - if ((use_dll || byte_code) && external_functions_table.get_total_number_of_unique_external_functions()) + if ((use_dll || byte_code) && (external_functions_table.get_total_number_of_unique_model_block_external_functions() > 0)) { cerr << "ERROR: In 'model' block, use of external functions is not compatible with 'use_dll' or 'bytecode'" << endl; exit(EXIT_FAILURE); diff --git a/preprocessor/ParsingDriver.cc b/preprocessor/ParsingDriver.cc index 0f890e227..d240abbf3 100644 --- a/preprocessor/ParsingDriver.cc +++ b/preprocessor/ParsingDriver.cc @@ -1628,6 +1628,7 @@ void ParsingDriver::external_function_option(const string &name_option, string *opt) { external_function_option(name_option, *opt); + delete opt; } void @@ -1680,7 +1681,7 @@ ParsingDriver::external_function() current_external_function_options.firstDerivSymbID != eExtFunSetButNoNameProvided) error("If the second derivative is provided in the top-level function, the first derivative must also be provided in that function."); - mod_file->external_functions_table.addExternalFunction(current_external_function_id, current_external_function_options); + mod_file->external_functions_table.addExternalFunction(current_external_function_id, current_external_function_options, true); reset_current_external_function_options(); } @@ -1698,8 +1699,9 @@ ParsingDriver::add_external_function_arg(NodeID arg) } NodeID -ParsingDriver::add_model_var_or_external_function(string *function_name) +ParsingDriver::add_model_var_or_external_function(string *function_name, bool in_model_block) { + NodeID nid; if (mod_file->symbol_table.exists(*function_name)) { if (mod_file->symbol_table.getType(*function_name) != eExternalFunction) @@ -1734,7 +1736,7 @@ ParsingDriver::add_model_var_or_external_function(string *function_name) if ((double) model_var_arg != model_var_arg_dbl) //make 100% sure int cast didn't lose info error("A model variable is being treated as if it were a function (i.e., takes an argument that is not an integer)."); - NodeID nid = add_model_variable(mod_file->symbol_table.getID(*function_name), model_var_arg); + nid = add_model_variable(mod_file->symbol_table.getID(*function_name), model_var_arg); stack_external_function_args.pop(); delete function_name; return nid; @@ -1748,25 +1750,32 @@ ParsingDriver::add_model_var_or_external_function(string *function_name) int symb_id = mod_file->symbol_table.getID(*function_name); assert(mod_file->external_functions_table.exists(symb_id)); - if ((int)(stack_external_function_args.top().size()) != mod_file->external_functions_table.getNargs(symb_id)) - error("The number of arguments passed to " + *function_name + - " does not match those of a previous call or declaration of this function."); + if (in_model_block) + if (mod_file->external_functions_table.getNargs(symb_id) == eExtFunNotSet) + error("Before using " + *function_name + + "() in the model block, you must first declare it via the external_function() statement"); + else if ((int)(stack_external_function_args.top().size()) != mod_file->external_functions_table.getNargs(symb_id)) + error("The number of arguments passed to " + *function_name + + "() does not match those of a previous call or declaration of this function."); } } else { //First time encountering this external function i.e., not previously declared or encountered + if (in_model_block) + error("To use an external function within the model block, you must first declare it via the external_function() statement."); + declare_symbol(function_name, eExternalFunction, NULL); current_external_function_options.nargs = stack_external_function_args.top().size(); mod_file->external_functions_table.addExternalFunction(mod_file->symbol_table.getID(*function_name), - current_external_function_options); + current_external_function_options, in_model_block); reset_current_external_function_options(); } //By this point, we're sure that this function exists in the External Functions Table and is not a mod var - NodeID id = data_tree->AddExternalFunction(*function_name, stack_external_function_args.top()); + nid = data_tree->AddExternalFunction(*function_name, stack_external_function_args.top()); stack_external_function_args.pop(); delete function_name; - return id; + return nid; } void diff --git a/preprocessor/ParsingDriver.hh b/preprocessor/ParsingDriver.hh index 55d26602c..ce9f1bd2a 100644 --- a/preprocessor/ParsingDriver.hh +++ b/preprocessor/ParsingDriver.hh @@ -484,7 +484,7 @@ public: //! Adds an external function argument void add_external_function_arg(NodeID arg); //! Adds an external function call node - NodeID add_model_var_or_external_function(string *function_name); + NodeID add_model_var_or_external_function(string *function_name, bool in_model_block); //! Adds a native statement void add_native(const char *s); //! Resets data_tree and model_tree pointers to default (i.e. mod_file->expressions_tree)