From a4045ae2da485d22d3ad2f1c62d761fc8e773fa4 Mon Sep 17 00:00:00 2001 From: Houtan Bastani Date: Fri, 20 May 2016 15:35:09 +0200 Subject: [PATCH] preprocessor: make modifications to osr_params and osr_params_bounds output. #948 --- ComputingTasks.cc | 43 +++++++++++++++++++++++++++---------------- ComputingTasks.hh | 8 ++++---- ParsingDriver.cc | 4 ++-- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/ComputingTasks.cc b/ComputingTasks.cc index a47e8497..dfdf21b7 100644 --- a/ComputingTasks.cc +++ b/ComputingTasks.cc @@ -1033,8 +1033,9 @@ ObservationTrendsStatement::writeOutput(ostream &output, const string &basename, } } -OsrParamsStatement::OsrParamsStatement(const SymbolList &symbol_list_arg) : - symbol_list(symbol_list_arg) +OsrParamsStatement::OsrParamsStatement(const SymbolList &symbol_list_arg, const SymbolTable &symbol_table_arg) : + symbol_list(symbol_list_arg), + symbol_table(symbol_table_arg) { } @@ -1047,7 +1048,13 @@ OsrParamsStatement::checkPass(ModFileStructure &mod_file_struct, WarningConsolid void OsrParamsStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const { - symbol_list.writeOutput("osr_params_", output); + symbol_list.writeOutput("M_.osr.param_names", output); + output << "M_.osr.param_names = cellstr(M_.osr.param_names);" << endl + << "M_.osr.param_indices = zeros(length(M_.osr.param_names), 1);" << endl; + int i = 0; + vector symbols = symbol_list.get_symbols(); + for (vector::const_iterator it = symbols.begin(); it != symbols.end(); it++) + output << "M_.osr.param_indices(" << ++i <<") = " << symbol_table.getTypeSpecificID(*it) + 1 << ";" << endl; } OsrStatement::OsrStatement(const SymbolList &symbol_list_arg, @@ -1057,27 +1064,31 @@ OsrStatement::OsrStatement(const SymbolList &symbol_list_arg, { } -OsrParamsBoundsStatement::OsrParamsBoundsStatement(const vector &osr_params_list_arg, - const SymbolTable &symbol_table_arg) : - osr_params_list(osr_params_list_arg), - symbol_table(symbol_table_arg) +OsrParamsBoundsStatement::OsrParamsBoundsStatement(const vector &osr_params_list_arg) : + osr_params_list(osr_params_list_arg) { } +void +OsrParamsBoundsStatement::checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings) +{ + if (!mod_file_struct.osr_params_present) + { + cerr << "ERROR: you must have an osr_params statement before the osr_params_bounds block." << endl; + exit(EXIT_FAILURE); + } +} + void OsrParamsBoundsStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const { - int nbnds = osr_params_list.size(); - output << "M_.osr.param_names = cell(" << nbnds << ", 1);" << endl - << "M_.osr.param_indices = zeros(" << nbnds << ", 1);" << endl - << "M_.osr.bounds = zeros(" << nbnds << ", 2);" << endl; - int i = 1; + + output << "M_.osr.param_bounds = [-inf(length(M_.osr.param_names), 1), inf(length(M_.osr.param_names), 1)];" << endl; + for (vector::const_iterator it = osr_params_list.begin(); - it != osr_params_list.end(); it++, i++) + it != osr_params_list.end(); it++) { - output << "M_.osr.param_names{" << i << "} = '" << it->name << "';" << endl - << "M_.osr.param_indices(" << i <<") = " << symbol_table.getTypeSpecificID(it->name) + 1 << ";" << endl - << "M_.osr.bounds(" << i << ", :) = ["; + output << "M_.osr.param_bounds(strcmp(M_.osr.param_names, '" << it->name << "'), :) = ["; it->low_bound->writeOutput(output); output << ", "; it->up_bound->writeOutput(output); diff --git a/ComputingTasks.hh b/ComputingTasks.hh index 7698a83f..8f0e4925 100644 --- a/ComputingTasks.hh +++ b/ComputingTasks.hh @@ -249,8 +249,9 @@ class OsrParamsStatement : public Statement { private: const SymbolList symbol_list; + const SymbolTable &symbol_table; public: - OsrParamsStatement(const SymbolList &symbol_list_arg); + OsrParamsStatement(const SymbolList &symbol_list_arg, const SymbolTable &symbol_table_arg); virtual void checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings); virtual void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const; }; @@ -287,10 +288,9 @@ class OsrParamsBoundsStatement : public Statement { private: const vector osr_params_list; - const SymbolTable &symbol_table; public: - OsrParamsBoundsStatement(const vector &osr_params_list_arg, - const SymbolTable &symbol_table_arg); + OsrParamsBoundsStatement(const vector &osr_params_list_arg); + virtual void checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings); virtual void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const; }; diff --git a/ParsingDriver.cc b/ParsingDriver.cc index 467bf114..8aa1b098 100644 --- a/ParsingDriver.cc +++ b/ParsingDriver.cc @@ -1299,7 +1299,7 @@ ParsingDriver::add_osr_params_element() void ParsingDriver::osr_params_bounds() { - mod_file->addStatement(new OsrParamsBoundsStatement(osr_params_list, mod_file->symbol_table)); + mod_file->addStatement(new OsrParamsBoundsStatement(osr_params_list)); osr_params_list.clear(); } @@ -1773,7 +1773,7 @@ ParsingDriver::optim_weights() void ParsingDriver::set_osr_params() { - mod_file->addStatement(new OsrParamsStatement(symbol_list)); + mod_file->addStatement(new OsrParamsStatement(symbol_list, mod_file->symbol_table)); symbol_list.clear(); }