Preprocessor: in steady_state_model block, allow MATLAB functions which return several arguments (closes #37)

issue#70
Sébastien Villemot 2011-01-26 13:55:01 -05:00
parent 6fba82c3a5
commit 9b3d611a0b
7 changed files with 104 additions and 27 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright (C) 2003-2010 Dynare Team
* Copyright (C) 2003-2011 Dynare Team
*
* This file is part of Dynare.
*
@ -1651,6 +1651,8 @@ steady_state_equation_list : steady_state_equation_list steady_state_equation
steady_state_equation : symbol EQUAL expression ';'
{ driver.add_steady_state_model_equal($1, $3); }
| '[' symbol_list ']' EQUAL expression ';'
{ driver.add_steady_state_model_equal_multiple($5); }
;
o_dr_algo : DR_ALGO EQUAL INT_NUMBER {

View File

@ -4075,7 +4075,7 @@ ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_typ
const temporary_terms_t &temporary_terms,
deriv_node_temp_terms_t &tef_terms) const
{
if (output_type == oMatlabOutsideModel)
if (output_type == oMatlabOutsideModel || output_type == oSteadyStateFile)
{
output << datatree.symbol_table.getName(symb_id) << "(";
writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms);

View File

@ -1899,3 +1899,32 @@ ParsingDriver::add_steady_state_model_equal(string *varname, expr_t expr)
delete varname;
}
void
ParsingDriver::add_steady_state_model_equal_multiple(expr_t expr)
{
const vector<string> &symbs = symbol_list.get_symbols();
vector<int> ids;
for (size_t i = 0; i < symbs.size(); i++)
{
int id;
try
{
id = mod_file->symbol_table.getID(symbs[i]);
}
catch (SymbolTable::UnknownSymbolNameException &e)
{
// Unknown symbol, declare it as a ModFileLocalVariable
id = mod_file->symbol_table.addSymbol(symbs[i], eModFileLocalVariable);
}
SymbolType type = mod_file->symbol_table.getType(id);
if (type != eEndogenous && type != eModFileLocalVariable && type != eParameter)
error(symbs[i] + " has incorrect type");
ids.push_back(id);
}
mod_file->steady_state_model.addMultipleDefinitions(ids, expr);
symbol_list.clear();
}

View File

@ -496,6 +496,8 @@ public:
void begin_steady_state_model();
//! Add an assignment equation in steady_state_model block
void add_steady_state_model_equal(string *varname, expr_t expr);
//! Add a multiple assignment equation in steady_state_model block
void add_steady_state_model_equal_multiple(expr_t expr);
//! Switches datatree
void begin_trend();
//! Declares a trend variable with its growth factor

View File

@ -1,5 +1,5 @@
/*
* Copyright (C) 2010 Dynare Team
* Copyright (C) 2010-2011 Dynare Team
*
* This file is part of Dynare.
*
@ -30,44 +30,70 @@ SteadyStateModel::SteadyStateModel(SymbolTable &symbol_table_arg, NumericalConst
void
SteadyStateModel::addDefinition(int symb_id, expr_t expr)
{
AddVariable(symb_id); // Create the variable node to be used in write method
assert(symbol_table.getType(symb_id) == eEndogenous
|| symbol_table.getType(symb_id) == eModFileLocalVariable
|| symbol_table.getType(symb_id) == eParameter);
// Add the variable
recursive_order.push_back(symb_id);
def_table[symb_id] = AddEqual(AddVariable(symb_id), expr);
vector<int> v;
v.push_back(symb_id);
recursive_order.push_back(v);
def_table[v] = expr;
}
void
SteadyStateModel::addMultipleDefinitions(const vector<int> &symb_ids, expr_t expr)
{
for (size_t i = 0; i < symb_ids.size(); i++)
{
AddVariable(symb_ids[i]); // Create the variable nodes to be used in write method
assert(symbol_table.getType(symb_ids[i]) == eEndogenous
|| symbol_table.getType(symb_ids[i]) == eModFileLocalVariable
|| symbol_table.getType(symb_ids[i]) == eParameter);
}
recursive_order.push_back(symb_ids);
def_table[symb_ids] = expr;
}
void
SteadyStateModel::checkPass(bool ramsey_policy) const
{
for (vector<int>::const_iterator it = recursive_order.begin();
it != recursive_order.end(); ++it)
vector<int> so_far_defined;
for (size_t i = 0; i < recursive_order.size(); i++)
{
// Check that symbol is not already defined
if (find(recursive_order.begin(), it, *it) != it)
{
cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(*it) << "' is declared twice" << endl;
exit(EXIT_FAILURE);
}
const vector<int> &symb_ids = recursive_order[i];
// Check that symbols are not already defined
for (size_t j = 0; j < symb_ids.size(); j++)
if (find(so_far_defined.begin(), so_far_defined.end(), symb_ids[j])
!= so_far_defined.end())
{
cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(symb_ids[j]) << "' is declared twice" << endl;
exit(EXIT_FAILURE);
}
// Check that expression has no undefined symbol
if (!ramsey_policy)
{
set<pair<int, int> > used_symbols;
expr_t expr = def_table.find(*it)->second;
expr_t expr = def_table.find(symb_ids)->second;
expr->collectVariables(eEndogenous, used_symbols);
expr->collectVariables(eModFileLocalVariable, used_symbols);
for(set<pair<int, int> >::const_iterator it2 = used_symbols.begin();
it2 != used_symbols.end(); ++it2)
if (find(recursive_order.begin(), it, it2->first) == it
&& *it != it2->first)
for(set<pair<int, int> >::const_iterator it = used_symbols.begin();
it != used_symbols.end(); ++it)
if (find(so_far_defined.begin(), so_far_defined.end(), it->first)
== so_far_defined.end())
{
cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(it2->first) << "' is undefined in the declaration of variable '" << symbol_table.getName(*it) << "'" << endl;
cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(it->first)
<< "' is undefined in the declaration of variable '" << symbol_table.getName(symb_ids[0]) << "'" << endl;
exit(EXIT_FAILURE);
}
}
copy(symb_ids.begin(), symb_ids.end(), back_inserter(so_far_defined));
}
}
@ -98,11 +124,25 @@ SteadyStateModel::writeSteadyStateFile(const string &basename, bool ramsey_polic
output << " ys_=zeros(" << symbol_table.orig_endo_nbr() << ",1);" << endl;
output << " global M_" << endl;
for(size_t i = 0; i < recursive_order.size(); i++)
for (size_t i = 0; i < recursive_order.size(); i++)
{
const vector<int> &symb_ids = recursive_order[i];
output << " ";
map<int, expr_t>::const_iterator it = def_table.find(recursive_order[i]);
it->second->writeOutput(output, oSteadyStateFile);
if (symb_ids.size() > 1)
output << "[";
for (size_t j = 0; j < symb_ids.size(); j++)
{
variable_node_map_t::const_iterator it = variable_node_map.find(make_pair(symb_ids[j], 0));
assert(it != variable_node_map.end());
dynamic_cast<ExprNode *>(it->second)->writeOutput(output, oSteadyStateFile);
if (j < symb_ids.size()-1)
output << ",";
}
if (symb_ids.size() > 1)
output << "]";
output << "=";
def_table.find(symb_ids)->second->writeOutput(output, oSteadyStateFile);
output << ";" << endl;
}
output << " % Auxiliary equations" << endl;

View File

@ -1,5 +1,5 @@
/*
* Copyright (C) 2010 Dynare Team
* Copyright (C) 2010-2011 Dynare Team
*
* This file is part of Dynare.
*
@ -26,9 +26,9 @@
class SteadyStateModel : public DataTree
{
private:
//! Associates a symbol ID to an expression of the form "var = expr"
map<int, expr_t> def_table;
vector<int> recursive_order;
//! Associates a set of symbol IDs (the variable(s) assigned in a given statement) to an expression (their assigned value)
map<vector<int>, expr_t> def_table;
vector<vector<int> > recursive_order;
//! Reference to static model (for writing auxiliary equations)
const StaticModel &static_model;
@ -37,6 +37,8 @@ public:
SteadyStateModel(SymbolTable &symbol_table_arg, NumericalConstants &num_constants, ExternalFunctionsTable &external_functions_table_arg, const StaticModel &static_model_arg);
//! Add an expression of the form "var = expr;"
void addDefinition(int symb_id, expr_t expr);
//! Add an expression of the form "[ var1, var2, ... ] = expr;"
void addMultipleDefinitions(const vector<int> &symb_ids, expr_t expr);
//! Checks that definitions are in a recursive order, and that no variable is declared twice
/*!
\param[in] ramsey_policy Is there a ramsey_policy statement in the MOD file? If yes, then disable the check on the recursivity of the declarations

View File

@ -1,5 +1,5 @@
/*
* Copyright (C) 2003-2008 Dynare Team
* Copyright (C) 2003-2011 Dynare Team
*
* This file is part of Dynare.
*
@ -41,6 +41,8 @@ public:
void writeOutput(const string &varname, ostream &output) const;
//! Clears all content
void clear();
//! Get a copy of the string vector
vector<string> get_symbols() const { return symbols; };
};
#endif