Refactor handling of var_expectation_model statements
Creates a VarExpectationModelTable analogous to PacModelTable.last-simulation-period
parent
e1e5118373
commit
aa0e06bc7d
|
@ -5102,129 +5102,6 @@ GenerateIRFsStatement::writeJsonOutput(ostream &output) const
|
|||
output << "}";
|
||||
}
|
||||
|
||||
VarExpectationModelStatement::VarExpectationModelStatement(string model_name_arg,
|
||||
expr_t expression_arg,
|
||||
string aux_model_name_arg,
|
||||
string horizon_arg,
|
||||
expr_t discount_arg,
|
||||
int time_shift_arg,
|
||||
const SymbolTable &symbol_table_arg) :
|
||||
model_name{move(model_name_arg)}, expression{expression_arg},
|
||||
aux_model_name{move(aux_model_name_arg)}, horizon{move(horizon_arg)},
|
||||
discount{discount_arg}, time_shift{time_shift_arg}, symbol_table{symbol_table_arg}
|
||||
{
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelStatement::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table)
|
||||
{
|
||||
vector<BinaryOpNode *> neweqs;
|
||||
expression = expression->substituteUnaryOpNodes(nodes, subst_table, neweqs);
|
||||
if (neweqs.size() > 0)
|
||||
{
|
||||
cerr << "ERROR: the 'expression' option of var_expectation_model contains a variable with a unary operator that is not present in the VAR model" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelStatement::substituteDiff(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table)
|
||||
{
|
||||
vector<BinaryOpNode *> neweqs;
|
||||
expression = expression->substituteDiff(nodes, subst_table, neweqs);
|
||||
if (neweqs.size() > 0)
|
||||
{
|
||||
cerr << "ERROR: the 'expression' option of var_expectation_model contains a diff'd variable that is not present in the VAR model" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelStatement::matchExpression()
|
||||
{
|
||||
try
|
||||
{
|
||||
auto vpc = expression->matchLinearCombinationOfVariables();
|
||||
for (const auto &it : vpc)
|
||||
{
|
||||
if (get<1>(it) != 0)
|
||||
throw ExprNode::MatchFailureException{"lead/lags are not allowed"};
|
||||
if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous)
|
||||
throw ExprNode::MatchFailureException{"Variable is not an endogenous"};
|
||||
vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it));
|
||||
}
|
||||
}
|
||||
catch (ExprNode::MatchFailureException &e)
|
||||
{
|
||||
cerr << "ERROR: expression in var_expectation_model is not of the expected form: " << e.message << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const
|
||||
{
|
||||
string mstruct = "M_.var_expectation." + model_name;
|
||||
output << mstruct << ".auxiliary_model_name = '" << aux_model_name << "';" << endl
|
||||
<< mstruct << ".horizon = " << horizon << ';' << endl
|
||||
<< mstruct << ".time_shift = " << time_shift << ';' << endl;
|
||||
|
||||
if (!vars_params_constants.size())
|
||||
{
|
||||
cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
ostringstream vars_list, params_list, constants_list;
|
||||
for (auto it = vars_params_constants.begin(); it != vars_params_constants.end(); ++it)
|
||||
{
|
||||
if (it != vars_params_constants.begin())
|
||||
{
|
||||
vars_list << ", ";
|
||||
params_list << ", ";
|
||||
constants_list << ", ";
|
||||
}
|
||||
vars_list << symbol_table.getTypeSpecificID(get<0>(*it))+1;
|
||||
if (get<1>(*it) == -1)
|
||||
params_list << "NaN";
|
||||
else
|
||||
params_list << symbol_table.getTypeSpecificID(get<1>(*it))+1;
|
||||
constants_list << get<2>(*it);
|
||||
}
|
||||
output << mstruct << ".expr.vars = [ " << vars_list.str() << " ];" << endl
|
||||
<< mstruct << ".expr.params = [ " << params_list.str() << " ];" << endl
|
||||
<< mstruct << ".expr.constants = [ " << constants_list.str() << " ];" << endl;
|
||||
|
||||
if (auto disc_var = dynamic_cast<const VariableNode *>(discount);
|
||||
disc_var)
|
||||
output << mstruct << ".discount_index = " << symbol_table.getTypeSpecificID(disc_var->symb_id) + 1 << ';' << endl;
|
||||
else
|
||||
{
|
||||
output << mstruct << ".discount_value = ";
|
||||
discount->writeOutput(output);
|
||||
output << ';' << endl;
|
||||
}
|
||||
output << mstruct << ".param_indices = [ ";
|
||||
for (int param_id : aux_params_ids)
|
||||
output << symbol_table.getTypeSpecificID(param_id)+1 << ' ';
|
||||
output << "];" << endl;
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelStatement::writeJsonOutput(ostream &output) const
|
||||
{
|
||||
output << R"({"statementName": "var_expectation_model",)"
|
||||
<< R"("model_name": ")" << model_name << R"(", )"
|
||||
<< R"("expression": ")";
|
||||
expression->writeOutput(output);
|
||||
output << R"(", )"
|
||||
<< R"("auxiliary_model_name": ")" << aux_model_name << R"(", )"
|
||||
<< R"("horizon": ")" << horizon << R"(", )"
|
||||
<< R"("discount": ")";
|
||||
discount->writeOutput(output);
|
||||
output << R"("})";
|
||||
}
|
||||
|
||||
MatchedMomentsStatement::MatchedMomentsStatement(const SymbolTable &symbol_table_arg,
|
||||
vector<tuple<vector<int>, vector<int>, vector<int>>> moments_arg) :
|
||||
symbol_table{symbol_table_arg}, moments{move(moments_arg)}
|
||||
|
|
|
@ -1211,35 +1211,6 @@ public:
|
|||
void writeJsonOutput(ostream &output) const override;
|
||||
};
|
||||
|
||||
class VarExpectationModelStatement : public Statement
|
||||
{
|
||||
public:
|
||||
const string model_name;
|
||||
private:
|
||||
expr_t expression;
|
||||
public:
|
||||
const string aux_model_name, horizon;
|
||||
const expr_t discount;
|
||||
const int time_shift;
|
||||
const SymbolTable &symbol_table;
|
||||
// List of generated auxiliary param ids, in variable-major order
|
||||
vector<int> aux_params_ids; // TODO: move this to some new VarModelTable object
|
||||
private:
|
||||
vector<tuple<int, int, double>> vars_params_constants;
|
||||
public:
|
||||
VarExpectationModelStatement(string model_name_arg, expr_t expression_arg, string aux_model_name_arg,
|
||||
string horizon_arg, expr_t discount_arg, int time_shift_arg,
|
||||
const SymbolTable &symbol_table_arg);
|
||||
void substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table);
|
||||
void substituteDiff(const lag_equivalence_table_t &diff_table, ExprNode::subst_table_t &subst_table);
|
||||
// Analyzes the linear combination contained in the 'expression' option
|
||||
/* Must be called after substituteUnaryOpNodes() and substituteDiff() (in
|
||||
that order) */
|
||||
void matchExpression();
|
||||
void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override;
|
||||
void writeJsonOutput(ostream &output) const override;
|
||||
};
|
||||
|
||||
class MatchedMomentsStatement : public Statement
|
||||
{
|
||||
private:
|
||||
|
|
|
@ -5587,15 +5587,15 @@ DynamicModel::findPacExpectationEquationNumbers() const
|
|||
}
|
||||
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
|
||||
DynamicModel::substituteUnaryOps(PacModelTable &pac_model_table)
|
||||
DynamicModel::substituteUnaryOps(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table)
|
||||
{
|
||||
vector<int> eqnumbers(equations.size());
|
||||
iota(eqnumbers.begin(), eqnumbers.end(), 0);
|
||||
return substituteUnaryOps(set<int>(eqnumbers.begin(), eqnumbers.end()), pac_model_table);
|
||||
return substituteUnaryOps(set<int>(eqnumbers.begin(), eqnumbers.end()), var_expectation_model_table, pac_model_table);
|
||||
}
|
||||
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
|
||||
DynamicModel::substituteUnaryOps(const set<int> &eqnumbers, PacModelTable &pac_model_table)
|
||||
DynamicModel::substituteUnaryOps(const set<int> &eqnumbers, VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table)
|
||||
{
|
||||
lag_equivalence_table_t nodes;
|
||||
ExprNode::subst_table_t subst_table;
|
||||
|
@ -5625,6 +5625,8 @@ DynamicModel::substituteUnaryOps(const set<int> &eqnumbers, PacModelTable &pac_m
|
|||
equations[eq] = substeq;
|
||||
}
|
||||
|
||||
// Substitute in expressions of var_expectation_model
|
||||
var_expectation_model_table.substituteUnaryOpsInExpression(nodes, subst_table, neweqs);
|
||||
// Substitute in growth terms in pac_model and pac_target_info
|
||||
pac_model_table.substituteUnaryOpsInGrowth(nodes, subst_table, neweqs);
|
||||
|
||||
|
@ -5642,7 +5644,7 @@ DynamicModel::substituteUnaryOps(const set<int> &eqnumbers, PacModelTable &pac_m
|
|||
}
|
||||
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
|
||||
DynamicModel::substituteDiff(PacModelTable &pac_model_table)
|
||||
DynamicModel::substituteDiff(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table)
|
||||
{
|
||||
/* Note: at this point, we know that there is no diff operator with a lead,
|
||||
because they have been expanded by DataTree::AddDiff().
|
||||
|
@ -5680,6 +5682,7 @@ DynamicModel::substituteDiff(PacModelTable &pac_model_table)
|
|||
equation = substeq;
|
||||
}
|
||||
|
||||
var_expectation_model_table.substituteDiffNodesInExpression(diff_nodes, diff_subst_table, neweqs);
|
||||
pac_model_table.substituteDiffNodesInGrowth(diff_nodes, diff_subst_table, neweqs);
|
||||
|
||||
// Add new equations
|
||||
|
|
|
@ -521,15 +521,17 @@ public:
|
|||
void substituteModelLocalVariables();
|
||||
|
||||
/* Creates aux vars for all unary operators in all equations. Also makes the
|
||||
substitution in growth terms of pac_model/pac_target_info. */
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(PacModelTable &pac_model_table);
|
||||
substitution in growth terms of pac_model/pac_target_info and in
|
||||
expressions of var_expectation_model. */
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table);
|
||||
|
||||
/* Creates aux vars for all unary operators in specified equations. Also makes the
|
||||
substitution in growth terms of pac_model/pac_target_info. */
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(const set<int> &eqnumbers, PacModelTable &pac_model_table);
|
||||
substitution in growth terms of pac_model/pac_target_info and in
|
||||
expressions of var_expectation_model. */
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(const set<int> &eqnumbers, VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table);
|
||||
|
||||
//! Substitutes diff operator
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteDiff(PacModelTable &pac_model_table);
|
||||
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteDiff(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table);
|
||||
|
||||
//! Substitute VarExpectation operators
|
||||
void substituteVarExpectation(const map<string, expr_t> &subst_table);
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
ModFile::ModFile(WarningConsolidation &warnings_arg)
|
||||
: var_model_table{symbol_table},
|
||||
trend_component_model_table{symbol_table},
|
||||
var_expectation_model_table{symbol_table},
|
||||
pac_model_table{symbol_table},
|
||||
expressions_tree{symbol_table, num_constants, external_functions_table},
|
||||
original_model{symbol_table, num_constants, external_functions_table,
|
||||
|
@ -430,13 +431,13 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool
|
|||
lag_equivalence_table_t unary_ops_nodes;
|
||||
ExprNode::subst_table_t unary_ops_subst_table;
|
||||
if (transform_unary_ops)
|
||||
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(pac_model_table);
|
||||
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(var_expectation_model_table, pac_model_table);
|
||||
else
|
||||
// substitute only those unary ops that appear in VAR, TCM and PAC model equations
|
||||
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(unary_ops_eqs, pac_model_table);
|
||||
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(unary_ops_eqs, var_expectation_model_table, pac_model_table);
|
||||
|
||||
// Create auxiliary variable and equations for Diff operators
|
||||
auto [diff_nodes, diff_subst_table] = dynamic_model.substituteDiff(pac_model_table);
|
||||
auto [diff_nodes, diff_subst_table] = dynamic_model.substituteDiff(var_expectation_model_table, pac_model_table);
|
||||
|
||||
// Fill trend component and VAR model tables
|
||||
dynamic_model.fillTrendComponentModelTable();
|
||||
|
@ -445,6 +446,10 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool
|
|||
dynamic_model.fillVarModelTable();
|
||||
original_model.fillVarModelTableFromOrigModel();
|
||||
|
||||
// VAR expectation models
|
||||
var_expectation_model_table.transformPass(diff_subst_table, dynamic_model, var_model_table,
|
||||
trend_component_model_table);
|
||||
|
||||
// PAC model
|
||||
pac_model_table.transformPass(unary_ops_nodes, unary_ops_subst_table,
|
||||
diff_nodes, diff_subst_table,
|
||||
|
@ -494,77 +499,6 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool
|
|||
mod_file_struct.ramsey_eq_nbr = dynamic_model.equation_number() - mod_file_struct.orig_eq_nbr;
|
||||
}
|
||||
|
||||
/* Handle var_expectation_model statements: collect information about them,
|
||||
create the new corresponding parameters, and the expressions to replace
|
||||
the var_expectation statements.
|
||||
TODO: move information collection to checkPass(), within a new
|
||||
VarModelTable class */
|
||||
map<string, expr_t> var_expectation_subst_table;
|
||||
for (auto &statement : statements)
|
||||
{
|
||||
auto vems = dynamic_cast<VarExpectationModelStatement *>(statement.get());
|
||||
if (!vems)
|
||||
continue;
|
||||
|
||||
int max_lag;
|
||||
vector<int> lhs;
|
||||
auto &model_name = vems->model_name;
|
||||
if (var_model_table.isExistingVarModelName(vems->aux_model_name))
|
||||
{
|
||||
max_lag = var_model_table.getMaxLag(vems->aux_model_name);
|
||||
lhs = var_model_table.getLhs(vems->aux_model_name);
|
||||
}
|
||||
else if (trend_component_model_table.isExistingTrendComponentModelName(vems->aux_model_name))
|
||||
{
|
||||
max_lag = trend_component_model_table.getMaxLag(vems->aux_model_name) + 1;
|
||||
lhs = dynamic_model.getUndiffLHSForPac(vems->aux_model_name, diff_subst_table);
|
||||
}
|
||||
else
|
||||
{
|
||||
cerr << "ERROR: var_expectation_model " << model_name
|
||||
<< " refers to nonexistent auxiliary model " << vems->aux_model_name << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
/* Substitute unary and diff operators in the 'expression' option, then
|
||||
match the linear combination in the expression option */
|
||||
vems->substituteUnaryOpNodes(unary_ops_nodes, unary_ops_subst_table);
|
||||
vems->substituteDiff(diff_nodes, diff_subst_table);
|
||||
vems->matchExpression();
|
||||
|
||||
/* Create auxiliary parameters and the expression to be substituted into
|
||||
the var_expectations statement */
|
||||
expr_t subst_expr = dynamic_model.Zero;
|
||||
if (var_model_table.isExistingVarModelName(vems->aux_model_name))
|
||||
{
|
||||
/* If the auxiliary model is a VAR, add a parameter corresponding to
|
||||
the constant. */
|
||||
string constant_param_name = "var_expectation_model_" + model_name + "_constant";
|
||||
int constant_param_id = symbol_table.addSymbol(constant_param_name, SymbolType::parameter);
|
||||
vems->aux_params_ids.push_back(constant_param_id);
|
||||
subst_expr = dynamic_model.AddPlus(subst_expr, dynamic_model.AddVariable(constant_param_id));
|
||||
}
|
||||
for (int lag = 0; lag < max_lag; lag++)
|
||||
for (auto variable : lhs)
|
||||
{
|
||||
string param_name = "var_expectation_model_" + model_name + '_' + symbol_table.getName(variable) + '_' + to_string(lag);
|
||||
int new_param_id = symbol_table.addSymbol(param_name, SymbolType::parameter);
|
||||
vems->aux_params_ids.push_back(new_param_id);
|
||||
|
||||
subst_expr = dynamic_model.AddPlus(subst_expr,
|
||||
dynamic_model.AddTimes(dynamic_model.AddVariable(new_param_id),
|
||||
dynamic_model.AddVariable(variable, -lag + vems->time_shift)));
|
||||
}
|
||||
|
||||
if (var_expectation_subst_table.find(model_name) != var_expectation_subst_table.end())
|
||||
{
|
||||
cerr << "ERROR: model name '" << model_name << "' is used by several var_expectation_model statements" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
var_expectation_subst_table[model_name] = subst_expr;
|
||||
}
|
||||
// And finally perform the substitutions
|
||||
dynamic_model.substituteVarExpectation(var_expectation_subst_table);
|
||||
dynamic_model.createVariableMapping();
|
||||
|
||||
/* Create auxiliary vars for leads and lags greater than 2, on both endos and
|
||||
|
@ -929,6 +863,7 @@ ModFile::writeMOutput(const string &basename, bool clear_all, bool clear_global,
|
|||
|
||||
var_model_table.writeOutput(basename, mOutputFile);
|
||||
trend_component_model_table.writeOutput(basename, mOutputFile);
|
||||
var_expectation_model_table.writeOutput(basename, mOutputFile);
|
||||
pac_model_table.writeOutput(basename, mOutputFile);
|
||||
|
||||
// Initialize M_.Sigma_e, M_.Correlation_matrix, M_.H, and M_.Correlation_matrix_ME
|
||||
|
@ -1223,6 +1158,12 @@ ModFile::writeJsonOutputParsingCheck(const string &basename, JsonFileOutputType
|
|||
output << ", ";
|
||||
}
|
||||
|
||||
if (!var_expectation_model_table.empty())
|
||||
{
|
||||
var_expectation_model_table.writeJsonOutput(output);
|
||||
output << ", ";
|
||||
}
|
||||
|
||||
if (!pac_model_table.empty())
|
||||
{
|
||||
pac_model_table.writeJsonOutput(output);
|
||||
|
|
|
@ -55,6 +55,8 @@ public:
|
|||
VarModelTable var_model_table;
|
||||
//! Trend Component Model Table used for storing info about trend component models
|
||||
TrendComponentModelTable trend_component_model_table;
|
||||
//! Table for storing the models declared with var_expectation_model
|
||||
VarExpectationModelTable var_expectation_model_table;
|
||||
//! PAC Model Table used for storing info about pac models
|
||||
PacModelTable pac_model_table;
|
||||
//! Expressions outside model block
|
||||
|
|
|
@ -3457,10 +3457,9 @@ ParsingDriver::var_expectation_model()
|
|||
if (time_shift > 0)
|
||||
error("The 'time_shift' option must be a non-positive integer");
|
||||
|
||||
mod_file->addStatement(make_unique<VarExpectationModelStatement>(model_name, var_expectation_model_expression,
|
||||
var_model_name, horizon,
|
||||
var_expectation_model_discount, time_shift,
|
||||
mod_file->symbol_table));
|
||||
mod_file->var_expectation_model_table.addVarExpectationModel(model_name, var_expectation_model_expression,
|
||||
var_model_name, horizon,
|
||||
var_expectation_model_discount, time_shift);
|
||||
|
||||
options_list.clear();
|
||||
var_expectation_model_discount = nullptr;
|
||||
|
|
209
src/SubModel.cc
209
src/SubModel.cc
|
@ -754,6 +754,215 @@ VarModelTable::getLhsExprT(const string &name_arg) const
|
|||
return lhs_expr_t.find(name_arg)->second;
|
||||
}
|
||||
|
||||
|
||||
VarExpectationModelTable::VarExpectationModelTable(SymbolTable &symbol_table_arg) :
|
||||
symbol_table{symbol_table_arg}
|
||||
{
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelTable::addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg, string horizon_arg, expr_t discount_arg, int time_shift_arg)
|
||||
{
|
||||
if (isExistingVarExpectationModelName(name_arg))
|
||||
{
|
||||
cerr << "Error: a var_expectation_model already exists with the name " << name_arg << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
expression[name_arg] = expression_arg;
|
||||
aux_model_name[name_arg] = move(aux_model_name_arg);
|
||||
horizon[name_arg] = move(horizon_arg);
|
||||
discount[name_arg] = discount_arg;
|
||||
time_shift[name_arg] = time_shift_arg;
|
||||
names.insert(move(name_arg));
|
||||
}
|
||||
|
||||
bool
|
||||
VarExpectationModelTable::isExistingVarExpectationModelName(const string &name_arg) const
|
||||
{
|
||||
return names.find(name_arg) != names.end();
|
||||
}
|
||||
|
||||
bool
|
||||
VarExpectationModelTable::empty() const
|
||||
{
|
||||
return names.empty();
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelTable::writeOutput(const string &basename, ostream &output) const
|
||||
{
|
||||
for (const auto &name : names)
|
||||
{
|
||||
string mstruct = "M_.var_expectation." + name;
|
||||
output << mstruct << ".auxiliary_model_name = '" << aux_model_name.at(name) << "';" << endl
|
||||
<< mstruct << ".horizon = " << horizon.at(name) << ';' << endl
|
||||
<< mstruct << ".time_shift = " << time_shift.at(name) << ';' << endl;
|
||||
|
||||
auto &vpc = vars_params_constants.at(name);
|
||||
if (!vpc.size())
|
||||
{
|
||||
cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
ostringstream vars_list, params_list, constants_list;
|
||||
for (auto it = vpc.begin(); it != vpc.end(); ++it)
|
||||
{
|
||||
if (it != vpc.begin())
|
||||
{
|
||||
vars_list << ", ";
|
||||
params_list << ", ";
|
||||
constants_list << ", ";
|
||||
}
|
||||
vars_list << symbol_table.getTypeSpecificID(get<0>(*it))+1;
|
||||
if (get<1>(*it) == -1)
|
||||
params_list << "NaN";
|
||||
else
|
||||
params_list << symbol_table.getTypeSpecificID(get<1>(*it))+1;
|
||||
constants_list << get<2>(*it);
|
||||
}
|
||||
output << mstruct << ".expr.vars = [ " << vars_list.str() << " ];" << endl
|
||||
<< mstruct << ".expr.params = [ " << params_list.str() << " ];" << endl
|
||||
<< mstruct << ".expr.constants = [ " << constants_list.str() << " ];" << endl;
|
||||
|
||||
if (auto disc_var = dynamic_cast<const VariableNode *>(discount.at(name));
|
||||
disc_var)
|
||||
output << mstruct << ".discount_index = " << symbol_table.getTypeSpecificID(disc_var->symb_id) + 1 << ';' << endl;
|
||||
else
|
||||
{
|
||||
output << mstruct << ".discount_value = ";
|
||||
discount.at(name)->writeOutput(output);
|
||||
output << ';' << endl;
|
||||
}
|
||||
output << mstruct << ".param_indices = [ ";
|
||||
for (int param_id : aux_param_symb_ids.at(name))
|
||||
output << symbol_table.getTypeSpecificID(param_id)+1 << ' ';
|
||||
output << "];" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelTable::substituteUnaryOpsInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs)
|
||||
{
|
||||
for (const auto &name : names)
|
||||
expression[name] = expression[name]->substituteUnaryOpNodes(nodes, subst_table, neweqs);
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelTable::substituteDiffNodesInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs)
|
||||
{
|
||||
for (const auto &name : names)
|
||||
expression[name] = expression[name]->substituteDiff(nodes, subst_table, neweqs);
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelTable::transformPass(ExprNode::subst_table_t &diff_subst_table,
|
||||
DynamicModel &dynamic_model, const VarModelTable &var_model_table,
|
||||
const TrendComponentModelTable &trend_component_model_table)
|
||||
{
|
||||
map<string, expr_t> var_expectation_subst_table;
|
||||
|
||||
for (const auto &name : names)
|
||||
{
|
||||
// Collect information about the auxiliary model
|
||||
|
||||
int max_lag;
|
||||
vector<int> lhs;
|
||||
if (var_model_table.isExistingVarModelName(aux_model_name[name]))
|
||||
{
|
||||
max_lag = var_model_table.getMaxLag(aux_model_name[name]);
|
||||
lhs = var_model_table.getLhs(aux_model_name[name]);
|
||||
}
|
||||
else if (trend_component_model_table.isExistingTrendComponentModelName(aux_model_name[name]))
|
||||
{
|
||||
max_lag = trend_component_model_table.getMaxLag(aux_model_name[name]) + 1;
|
||||
lhs = dynamic_model.getUndiffLHSForPac(aux_model_name[name], diff_subst_table);
|
||||
}
|
||||
else
|
||||
{
|
||||
cerr << "ERROR: var_expectation_model " << name
|
||||
<< " refers to nonexistent auxiliary model " << aux_model_name[name] << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// Match the linear combination in the expression option
|
||||
try
|
||||
{
|
||||
auto vpc = expression[name]->matchLinearCombinationOfVariables();
|
||||
for (const auto &[variable_id, lag, param_id, constant] : vpc)
|
||||
{
|
||||
if (lag != 0)
|
||||
throw ExprNode::MatchFailureException{"lead/lags are not allowed"};
|
||||
if (symbol_table.getType(variable_id) != SymbolType::endogenous)
|
||||
throw ExprNode::MatchFailureException{"Variable is not an endogenous"};
|
||||
vars_params_constants[name].emplace_back(variable_id, param_id, constant);
|
||||
}
|
||||
}
|
||||
catch (ExprNode::MatchFailureException &e)
|
||||
{
|
||||
cerr << "ERROR: expression in var_expectation_model " << name << " is not of the expected form: " << e.message << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
/* Create auxiliary parameters and the expression to be substituted into
|
||||
the var_expectations statement */
|
||||
expr_t subst_expr = dynamic_model.Zero;
|
||||
if (var_model_table.isExistingVarModelName(aux_model_name[name]))
|
||||
{
|
||||
/* If the auxiliary model is a VAR, add a parameter corresponding to
|
||||
the constant. */
|
||||
string constant_param_name = "var_expectation_model_" + name + "_constant";
|
||||
int constant_param_id = symbol_table.addSymbol(constant_param_name, SymbolType::parameter);
|
||||
aux_param_symb_ids[name].push_back(constant_param_id);
|
||||
subst_expr = dynamic_model.AddPlus(subst_expr, dynamic_model.AddVariable(constant_param_id));
|
||||
}
|
||||
for (int lag = 0; lag < max_lag; lag++)
|
||||
for (auto variable : lhs)
|
||||
{
|
||||
string param_name = "var_expectation_model_" + name + '_' + symbol_table.getName(variable) + '_' + to_string(lag);
|
||||
int new_param_id = symbol_table.addSymbol(param_name, SymbolType::parameter);
|
||||
aux_param_symb_ids[name].push_back(new_param_id);
|
||||
|
||||
subst_expr = dynamic_model.AddPlus(subst_expr,
|
||||
dynamic_model.AddTimes(dynamic_model.AddVariable(new_param_id),
|
||||
dynamic_model.AddVariable(variable, -lag + time_shift[name])));
|
||||
}
|
||||
|
||||
if (var_expectation_subst_table.find(name) != var_expectation_subst_table.end())
|
||||
{
|
||||
cerr << "ERROR: model name '" << name << "' is used by several var_expectation_model statements" << endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
var_expectation_subst_table[name] = subst_expr;
|
||||
}
|
||||
|
||||
// Actually substitute var_expectation statements
|
||||
dynamic_model.substituteVarExpectation(var_expectation_subst_table);
|
||||
/* At this point, we know that all var_expectation operators have been
|
||||
substituted, because of the error check performed in
|
||||
VarExpectationNode::substituteVarExpectation(). */
|
||||
}
|
||||
|
||||
void
|
||||
VarExpectationModelTable::writeJsonOutput(ostream &output) const
|
||||
{
|
||||
for (const auto &name : names)
|
||||
{
|
||||
output << R"({"statementName": "var_expectation_model",)"
|
||||
<< R"("model_name": ")" << name << R"(", )"
|
||||
<< R"("expression": ")";
|
||||
expression.at(name)->writeOutput(output);
|
||||
output << R"(", )"
|
||||
<< R"("auxiliary_model_name": ")" << aux_model_name.at(name) << R"(", )"
|
||||
<< R"("horizon": ")" << horizon.at(name) << R"(", )"
|
||||
<< R"("discount": ")";
|
||||
discount.at(name)->writeOutput(output);
|
||||
output << R"("})";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
PacModelTable::PacModelTable(SymbolTable &symbol_table_arg) :
|
||||
symbol_table{symbol_table_arg}
|
||||
{
|
||||
|
|
|
@ -190,6 +190,36 @@ VarModelTable::empty() const
|
|||
return names.empty();
|
||||
}
|
||||
|
||||
class VarExpectationModelTable
|
||||
{
|
||||
private:
|
||||
SymbolTable &symbol_table;
|
||||
set<string> names;
|
||||
map<string, expr_t> expression;
|
||||
map<string, string> aux_model_name;
|
||||
map<string, string> horizon;
|
||||
map<string, expr_t> discount;
|
||||
map<string, int> time_shift;
|
||||
// For each model, list of generated auxiliary param ids, in variable-major order
|
||||
map<string, vector<int>> aux_param_symb_ids;
|
||||
// Decomposition of the expression
|
||||
map<string, vector<tuple<int, int, double>>> vars_params_constants;
|
||||
public:
|
||||
explicit VarExpectationModelTable(SymbolTable &symbol_table_arg);
|
||||
void addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg,
|
||||
string horizon_arg, expr_t discount_arg, int time_shift_arg);
|
||||
bool isExistingVarExpectationModelName(const string &name_arg) const;
|
||||
bool empty() const;
|
||||
void substituteUnaryOpsInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs);
|
||||
// Called by DynamicModel::substituteDiff()
|
||||
void substituteDiffNodesInExpression(const lag_equivalence_table_t &diff_nodes, ExprNode::subst_table_t &diff_subst_table, vector<BinaryOpNode *> &neweqs);
|
||||
void transformPass(ExprNode::subst_table_t &diff_subst_table,
|
||||
DynamicModel &dynamic_model, const VarModelTable &var_model_table,
|
||||
const TrendComponentModelTable &trend_component_model_table);
|
||||
void writeOutput(const string &basename, ostream &output) const;
|
||||
void writeJsonOutput(ostream &output) const;
|
||||
};
|
||||
|
||||
class PacModelTable
|
||||
{
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue