preprocessor: add joint prior syntax, #824

time-shift
Houtan Bastani 2015-03-03 15:08:33 +01:00
parent 219b4fe547
commit 30428aeb17
8 changed files with 209 additions and 1 deletions

View File

@ -362,6 +362,8 @@ estimation_info.measurement_error_corr.range_index = {};
estimation_info.structural_innovation_corr_prior_index = {};
estimation_info.structural_innovation_corr_options_index = {};
estimation_info.structural_innovation_corr.range_index = {};
estimation_info.joint_parameter_prior_index = {};
estimation_info.joint_parameter = cell2table(cell(0,11));
options_.initial_period = NaN; %dates(1,1);
options_.dataset.file = [];
options_.dataset.series = [];

View File

@ -2082,6 +2082,115 @@ SubsamplesEqualStatement::writeOutput(ostream &output, const string &basename) c
<< endl;
}
JointPriorStatement::JointPriorStatement(const vector<string> joint_parameters_arg,
const PriorDistributions &prior_shape_arg,
const OptionsList &options_list_arg) :
joint_parameters(joint_parameters_arg),
prior_shape(prior_shape_arg),
options_list(options_list_arg)
{
}
void
JointPriorStatement::checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings)
{
if (joint_parameters.size() < 2)
{
cerr << "ERROR: you must pass at least two parameters to the joint prior statement" << endl;
exit(EXIT_FAILURE);
}
if (prior_shape == eNoShape)
{
cerr << "ERROR: You must pass the shape option to the prior statement." << endl;
exit(EXIT_FAILURE);
}
if (options_list.num_options.find("mean") == options_list.num_options.end() &&
options_list.num_options.find("mode") == options_list.num_options.end())
{
cerr << "ERROR: You must pass at least one of mean and mode to the prior statement." << endl;
exit(EXIT_FAILURE);
}
OptionsList::num_options_t::const_iterator it_num = options_list.num_options.find("domain");
if (it_num != options_list.num_options.end())
{
using namespace boost;
vector<string> tokenizedDomain;
split(tokenizedDomain, it_num->second, is_any_of("[ ]"), token_compress_on);
if (tokenizedDomain.size() != 4)
{
cerr << "ERROR: You must pass exactly two values to the domain option." << endl;
exit(EXIT_FAILURE);
}
}
}
void
JointPriorStatement::writeOutput(ostream &output, const string &basename) const
{
for (vector<string>::const_iterator it = joint_parameters.begin() ; it != joint_parameters.end(); it++)
output << "eifind = get_new_or_existing_ei_index('joint_parameter_prior_index', '"
<< *it << "', '');" << endl
<< "estimation_info.joint_parameter_prior_index(eifind) = {'" << *it << "'};" << endl;
output << "key = {[";
for (vector<string>::const_iterator it = joint_parameters.begin() ; it != joint_parameters.end(); it++)
output << "get_new_or_existing_ei_index('joint_parameter_prior_index', '" << *it << "', '') ..."
<< endl << " ";
output << "]};" << endl;
string lhs_field("estimation_info.joint_parameter_tmp");
writeOutputHelper(output, "domain", lhs_field);
writeOutputHelper(output, "interval", lhs_field);
writeOutputHelper(output, "mean", lhs_field);
writeOutputHelper(output, "median", lhs_field);
writeOutputHelper(output, "mode", lhs_field);
assert(prior_shape != eNoShape);
output << lhs_field << ".shape = " << prior_shape << ";" << endl;
writeOutputHelper(output, "shift", lhs_field);
writeOutputHelper(output, "stdev", lhs_field);
writeOutputHelper(output, "truncate", lhs_field);
writeOutputHelper(output, "variance", lhs_field);
output << "estimation_info.joint_parameter_tmp = table(key, ..." << endl
<< " " << lhs_field << ".domain , ..." << endl
<< " " << lhs_field << ".interval , ..." << endl
<< " " << lhs_field << ".mean , ..." << endl
<< " " << lhs_field << ".median , ..." << endl
<< " " << lhs_field << ".mode , ..." << endl
<< " " << lhs_field << ".shape , ..." << endl
<< " " << lhs_field << ".shift , ..." << endl
<< " " << lhs_field << ".stdev , ..." << endl
<< " " << lhs_field << ".truncate , ..." << endl
<< " " << lhs_field << ".variance, ..." << endl
<< " 'VariableNames',{'index','domain','interval','mean','median','mode','shape','shift','stdev','truncate','variance'});" << endl;
output << "if height(estimation_info.joint_parameter)" << endl
<< " estimation_info.joint_parameter = [estimation_info.joint_parameter; estimation_info.joint_parameter_tmp];" << endl
<< "else" << endl
<< " estimation_info.joint_parameter = estimation_info.joint_parameter_tmp;" << endl
<< "end" << endl
<< "clear estimation_info.joint_parameter_tmp;" << endl;
}
void
JointPriorStatement::writeOutputHelper(ostream &output, const string &field, const string &lhs_field) const
{
OptionsList::num_options_t::const_iterator itn = options_list.num_options.find(field);
output << lhs_field << "." << field << " = {";
if (itn != options_list.num_options.end())
output << itn->second;
else
output << "{}";
output << "};" << endl;
}
BasicPriorStatement::~BasicPriorStatement()
{
}

View File

@ -679,6 +679,22 @@ public:
virtual void writeOutput(ostream &output, const string &basename) const;
};
class JointPriorStatement : public Statement
{
private:
const vector<string> joint_parameters;
const PriorDistributions prior_shape;
const OptionsList options_list;
public:
JointPriorStatement(const vector<string> joint_parameters_arg,
const PriorDistributions &prior_shape_arg,
const OptionsList &options_list_arg);
virtual void checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings);
virtual void writeOutput(ostream &output, const string &basename) const;
void writeOutputHelper(ostream &output, const string &field, const string &lhs_field) const;
};
class BasicPriorStatement : public Statement
{
public:

View File

@ -171,6 +171,7 @@ class ParsingDriver;
%token PARAMETER_CONVERGENCE_CRITERION NUMBER_OF_LARGE_PERTURBATIONS NUMBER_OF_SMALL_PERTURBATIONS
%token NUMBER_OF_POSTERIOR_DRAWS_AFTER_PERTURBATION MAX_NUMBER_OF_STAGES
%token RANDOM_FUNCTION_CONVERGENCE_CRITERION RANDOM_PARAMETER_CONVERGENCE_CRITERION
%token <vector_string_val> SYMBOL_VEC
%type <node_val> expression expression_or_empty
%type <node_val> equation hand_side
@ -1422,6 +1423,8 @@ prior : symbol '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNo
{ driver.set_prior($1, new string ("")); }
| symbol '.' symbol '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNoShape; } '(' prior_options_list ')' ';'
{ driver.set_prior($1, $3); }
| SYMBOL_VEC '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNoShape; } '(' joint_prior_options_list ')' ';'
{ driver.set_joint_prior($1); }
| STD '(' symbol ')' '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNoShape; } '(' prior_options_list ')' ';'
{ driver.set_std_prior($3, new string ("")); }
| STD '(' symbol ')' '.' symbol '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNoShape; } '(' prior_options_list ')' ';'
@ -1448,6 +1451,22 @@ prior_options : o_shift
| o_domain
;
joint_prior_options_list : joint_prior_options_list COMMA joint_prior_options
| joint_prior_options
;
joint_prior_options : o_shift
| o_mean_vec
| o_median
| o_stdev
| o_truncate
| o_variance_mat
| o_mode
| o_interval
| o_shape
| o_domain
;
prior_eq : prior_eq_opt EQUAL prior_eq_opt ';'
{
driver.copy_prior($1->at(0), $1->at(1), $1->at(2), $1->at(3),
@ -2544,6 +2563,7 @@ o_shift : SHIFT EQUAL signed_number { driver.option_num("shift", $3); };
o_shape : SHAPE EQUAL prior_distribution { driver.prior_shape = $3; };
o_mode : MODE EQUAL signed_number { driver.option_num("mode", $3); };
o_mean : MEAN EQUAL signed_number { driver.option_num("mean", $3); };
o_mean_vec : MEAN EQUAL vec_value { driver.option_num("mean", $3); };
o_truncate : TRUNCATE EQUAL vec_value { driver.option_num("truncate", $3); };
o_stdev : STDEV EQUAL non_negative_number { driver.option_num("stdev", $3); };
o_jscale : JSCALE EQUAL non_negative_number { driver.option_num("jscale", $3); };
@ -2552,6 +2572,7 @@ o_bounds : BOUNDS EQUAL vec_value_w_inf { driver.option_num("bounds", $3); };
o_domain : DOMAINN EQUAL vec_value { driver.option_num("domain", $3); };
o_interval : INTERVAL EQUAL vec_value { driver.option_num("interval", $3); };
o_variance : VARIANCE EQUAL expression { driver.set_prior_variance($3); }
o_variance_mat : VARIANCE EQUAL vec_of_vec_value { driver.option_num("variance",$3); }
o_prefilter : PREFILTER EQUAL INT_NUMBER { driver.option_num("prefilter", $3); };
o_presample : PRESAMPLE EQUAL INT_NUMBER { driver.option_num("presample", $3); };
o_lik_algo : LIK_ALGO EQUAL INT_NUMBER { driver.option_num("lik_algo", $3); };

View File

@ -839,6 +839,38 @@ DATE -?[0-9]+([YyAa]|[Mm]([1-9]|1[0-2])|[Qq][1-4]|[Ww]([1-9]{1}|[1-4][0-9]|5[0-2
}
}
/* For joint prior statement, match [symbol, symbol, ...]
If no match, begin native and push everything back on stack
*/
<INITIAL>\[([[:space:]]*[A-Za-z_][A-Za-z0-9_]*[[:space:]]*,{1}[[:space:]]*)*([[:space:]]*[A-Za-z_][A-Za-z0-9_]*[[:space:]]*){1}\] {
string yytextcpy = string(yytext);
yytextcpy.erase(remove(yytextcpy.begin(), yytextcpy.end(), '['), yytextcpy.end());
yytextcpy.erase(remove(yytextcpy.begin(), yytextcpy.end(), ']'), yytextcpy.end());
yytextcpy.erase(remove(yytextcpy.begin(), yytextcpy.end(), ' '), yytextcpy.end());
istringstream ss(yytextcpy);
string token;
yylval->vector_string_val = new vector<string *>;
while(getline(ss, token, ','))
if (driver.symbol_exists_and_is_not_modfile_local_or_external_function(token.c_str()))
yylval->vector_string_val->push_back(new string(token));
else
{
for (vector<string *>::iterator it=yylval->vector_string_val->begin();
it != yylval->vector_string_val->end(); it++)
delete *it;
delete yylval->vector_string_val;
BEGIN NATIVE;
yyless(0);
break;
}
if (yylval->vector_string_val->size() > 0)
{
BEGIN DYNARE_STATEMENT;
return token::SYMBOL_VEC;
}
}
/* Enter a native block */
<INITIAL>. { BEGIN NATIVE; yyless(0); }

View File

@ -1410,6 +1410,26 @@ ParsingDriver::set_prior(string *name, string *subsample_name)
delete subsample_name;
}
void
ParsingDriver::set_joint_prior(vector<string *>*symbol_vec)
{
for (vector<string *>::const_iterator it=symbol_vec->begin(); it != symbol_vec->end(); it++)
add_joint_parameter(*it);
mod_file->addStatement(new JointPriorStatement(joint_parameters, prior_shape, options_list));
joint_parameters.clear();
options_list.clear();
prior_shape = eNoShape;
delete symbol_vec;
}
void
ParsingDriver::add_joint_parameter(string *name)
{
check_symbol_is_parameter(name);
joint_parameters.push_back(*name);
delete name;
}
void
ParsingDriver::set_prior_variance(expr_t variance)
{

View File

@ -186,6 +186,8 @@ private:
//! Temporary storage for argument list of external function
stack<vector<expr_t> > stack_external_function_args;
//! Temporary storage for parameters in joint prior statement
vector<string> joint_parameters;
//! Temporary storage for the symb_id associated with the "name" symbol of the current external_function statement
int current_external_function_id;
//! Temporary storage for option list provided to external_function()
@ -411,6 +413,10 @@ public:
void estimation_data();
//! Sets the prior for a parameter
void set_prior(string *arg1, string *arg2);
//! Sets the joint prior for a set of parameters
void set_joint_prior(vector<string *>*symbol_vec);
//! Adds a parameters to the list of joint parameters
void add_joint_parameter(string *name);
//! Adds the variance option to its temporary holding place
void set_prior_variance(expr_t variance=NULL);
//! Copies the prior from_name to_name

View File

@ -35,4 +35,6 @@ alpha.options(init=1);
rho.options(init=1);
beta.options(init=0.2);
std(u).options(init=3);
corr(y,c).options(init=.02);
corr(y,c).options(init=.02);
[alpha , beta , rho].prior(shape=beta, mean=[2 3 4], variance=[[1 2 3],[2 3 4]]);