Refactor handling of var_expectation_model statements

Creates a VarExpectationModelTable analogous to PacModelTable.
last-simulation-period
Sébastien Villemot 2022-01-20 16:15:43 +01:00
parent e1e5118373
commit aa0e06bc7d
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
9 changed files with 273 additions and 239 deletions

View File

@ -5102,129 +5102,6 @@ GenerateIRFsStatement::writeJsonOutput(ostream &output) const
output << "}";
}
VarExpectationModelStatement::VarExpectationModelStatement(string model_name_arg,
expr_t expression_arg,
string aux_model_name_arg,
string horizon_arg,
expr_t discount_arg,
int time_shift_arg,
const SymbolTable &symbol_table_arg) :
model_name{move(model_name_arg)}, expression{expression_arg},
aux_model_name{move(aux_model_name_arg)}, horizon{move(horizon_arg)},
discount{discount_arg}, time_shift{time_shift_arg}, symbol_table{symbol_table_arg}
{
}
void
VarExpectationModelStatement::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table)
{
vector<BinaryOpNode *> neweqs;
expression = expression->substituteUnaryOpNodes(nodes, subst_table, neweqs);
if (neweqs.size() > 0)
{
cerr << "ERROR: the 'expression' option of var_expectation_model contains a variable with a unary operator that is not present in the VAR model" << endl;
exit(EXIT_FAILURE);
}
}
void
VarExpectationModelStatement::substituteDiff(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table)
{
vector<BinaryOpNode *> neweqs;
expression = expression->substituteDiff(nodes, subst_table, neweqs);
if (neweqs.size() > 0)
{
cerr << "ERROR: the 'expression' option of var_expectation_model contains a diff'd variable that is not present in the VAR model" << endl;
exit(EXIT_FAILURE);
}
}
void
VarExpectationModelStatement::matchExpression()
{
try
{
auto vpc = expression->matchLinearCombinationOfVariables();
for (const auto &it : vpc)
{
if (get<1>(it) != 0)
throw ExprNode::MatchFailureException{"lead/lags are not allowed"};
if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous)
throw ExprNode::MatchFailureException{"Variable is not an endogenous"};
vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it));
}
}
catch (ExprNode::MatchFailureException &e)
{
cerr << "ERROR: expression in var_expectation_model is not of the expected form: " << e.message << endl;
exit(EXIT_FAILURE);
}
}
void
VarExpectationModelStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const
{
string mstruct = "M_.var_expectation." + model_name;
output << mstruct << ".auxiliary_model_name = '" << aux_model_name << "';" << endl
<< mstruct << ".horizon = " << horizon << ';' << endl
<< mstruct << ".time_shift = " << time_shift << ';' << endl;
if (!vars_params_constants.size())
{
cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl;
exit(EXIT_FAILURE);
}
ostringstream vars_list, params_list, constants_list;
for (auto it = vars_params_constants.begin(); it != vars_params_constants.end(); ++it)
{
if (it != vars_params_constants.begin())
{
vars_list << ", ";
params_list << ", ";
constants_list << ", ";
}
vars_list << symbol_table.getTypeSpecificID(get<0>(*it))+1;
if (get<1>(*it) == -1)
params_list << "NaN";
else
params_list << symbol_table.getTypeSpecificID(get<1>(*it))+1;
constants_list << get<2>(*it);
}
output << mstruct << ".expr.vars = [ " << vars_list.str() << " ];" << endl
<< mstruct << ".expr.params = [ " << params_list.str() << " ];" << endl
<< mstruct << ".expr.constants = [ " << constants_list.str() << " ];" << endl;
if (auto disc_var = dynamic_cast<const VariableNode *>(discount);
disc_var)
output << mstruct << ".discount_index = " << symbol_table.getTypeSpecificID(disc_var->symb_id) + 1 << ';' << endl;
else
{
output << mstruct << ".discount_value = ";
discount->writeOutput(output);
output << ';' << endl;
}
output << mstruct << ".param_indices = [ ";
for (int param_id : aux_params_ids)
output << symbol_table.getTypeSpecificID(param_id)+1 << ' ';
output << "];" << endl;
}
void
VarExpectationModelStatement::writeJsonOutput(ostream &output) const
{
output << R"({"statementName": "var_expectation_model",)"
<< R"("model_name": ")" << model_name << R"(", )"
<< R"("expression": ")";
expression->writeOutput(output);
output << R"(", )"
<< R"("auxiliary_model_name": ")" << aux_model_name << R"(", )"
<< R"("horizon": ")" << horizon << R"(", )"
<< R"("discount": ")";
discount->writeOutput(output);
output << R"("})";
}
MatchedMomentsStatement::MatchedMomentsStatement(const SymbolTable &symbol_table_arg,
vector<tuple<vector<int>, vector<int>, vector<int>>> moments_arg) :
symbol_table{symbol_table_arg}, moments{move(moments_arg)}

View File

@ -1211,35 +1211,6 @@ public:
void writeJsonOutput(ostream &output) const override;
};
class VarExpectationModelStatement : public Statement
{
public:
const string model_name;
private:
expr_t expression;
public:
const string aux_model_name, horizon;
const expr_t discount;
const int time_shift;
const SymbolTable &symbol_table;
// List of generated auxiliary param ids, in variable-major order
vector<int> aux_params_ids; // TODO: move this to some new VarModelTable object
private:
vector<tuple<int, int, double>> vars_params_constants;
public:
VarExpectationModelStatement(string model_name_arg, expr_t expression_arg, string aux_model_name_arg,
string horizon_arg, expr_t discount_arg, int time_shift_arg,
const SymbolTable &symbol_table_arg);
void substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table);
void substituteDiff(const lag_equivalence_table_t &diff_table, ExprNode::subst_table_t &subst_table);
// Analyzes the linear combination contained in the 'expression' option
/* Must be called after substituteUnaryOpNodes() and substituteDiff() (in
that order) */
void matchExpression();
void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override;
void writeJsonOutput(ostream &output) const override;
};
class MatchedMomentsStatement : public Statement
{
private:

View File

@ -5587,15 +5587,15 @@ DynamicModel::findPacExpectationEquationNumbers() const
}
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
DynamicModel::substituteUnaryOps(PacModelTable &pac_model_table)
DynamicModel::substituteUnaryOps(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table)
{
vector<int> eqnumbers(equations.size());
iota(eqnumbers.begin(), eqnumbers.end(), 0);
return substituteUnaryOps(set<int>(eqnumbers.begin(), eqnumbers.end()), pac_model_table);
return substituteUnaryOps(set<int>(eqnumbers.begin(), eqnumbers.end()), var_expectation_model_table, pac_model_table);
}
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
DynamicModel::substituteUnaryOps(const set<int> &eqnumbers, PacModelTable &pac_model_table)
DynamicModel::substituteUnaryOps(const set<int> &eqnumbers, VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table)
{
lag_equivalence_table_t nodes;
ExprNode::subst_table_t subst_table;
@ -5625,6 +5625,8 @@ DynamicModel::substituteUnaryOps(const set<int> &eqnumbers, PacModelTable &pac_m
equations[eq] = substeq;
}
// Substitute in expressions of var_expectation_model
var_expectation_model_table.substituteUnaryOpsInExpression(nodes, subst_table, neweqs);
// Substitute in growth terms in pac_model and pac_target_info
pac_model_table.substituteUnaryOpsInGrowth(nodes, subst_table, neweqs);
@ -5642,7 +5644,7 @@ DynamicModel::substituteUnaryOps(const set<int> &eqnumbers, PacModelTable &pac_m
}
pair<lag_equivalence_table_t, ExprNode::subst_table_t>
DynamicModel::substituteDiff(PacModelTable &pac_model_table)
DynamicModel::substituteDiff(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table)
{
/* Note: at this point, we know that there is no diff operator with a lead,
because they have been expanded by DataTree::AddDiff().
@ -5680,6 +5682,7 @@ DynamicModel::substituteDiff(PacModelTable &pac_model_table)
equation = substeq;
}
var_expectation_model_table.substituteDiffNodesInExpression(diff_nodes, diff_subst_table, neweqs);
pac_model_table.substituteDiffNodesInGrowth(diff_nodes, diff_subst_table, neweqs);
// Add new equations

View File

@ -521,15 +521,17 @@ public:
void substituteModelLocalVariables();
/* Creates aux vars for all unary operators in all equations. Also makes the
substitution in growth terms of pac_model/pac_target_info. */
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(PacModelTable &pac_model_table);
substitution in growth terms of pac_model/pac_target_info and in
expressions of var_expectation_model. */
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table);
/* Creates aux vars for all unary operators in specified equations. Also makes the
substitution in growth terms of pac_model/pac_target_info. */
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(const set<int> &eqnumbers, PacModelTable &pac_model_table);
substitution in growth terms of pac_model/pac_target_info and in
expressions of var_expectation_model. */
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps(const set<int> &eqnumbers, VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table);
//! Substitutes diff operator
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteDiff(PacModelTable &pac_model_table);
pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteDiff(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table);
//! Substitute VarExpectation operators
void substituteVarExpectation(const map<string, expr_t> &subst_table);

View File

@ -34,6 +34,7 @@
ModFile::ModFile(WarningConsolidation &warnings_arg)
: var_model_table{symbol_table},
trend_component_model_table{symbol_table},
var_expectation_model_table{symbol_table},
pac_model_table{symbol_table},
expressions_tree{symbol_table, num_constants, external_functions_table},
original_model{symbol_table, num_constants, external_functions_table,
@ -430,13 +431,13 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool
lag_equivalence_table_t unary_ops_nodes;
ExprNode::subst_table_t unary_ops_subst_table;
if (transform_unary_ops)
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(pac_model_table);
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(var_expectation_model_table, pac_model_table);
else
// substitute only those unary ops that appear in VAR, TCM and PAC model equations
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(unary_ops_eqs, pac_model_table);
tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(unary_ops_eqs, var_expectation_model_table, pac_model_table);
// Create auxiliary variable and equations for Diff operators
auto [diff_nodes, diff_subst_table] = dynamic_model.substituteDiff(pac_model_table);
auto [diff_nodes, diff_subst_table] = dynamic_model.substituteDiff(var_expectation_model_table, pac_model_table);
// Fill trend component and VAR model tables
dynamic_model.fillTrendComponentModelTable();
@ -445,6 +446,10 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool
dynamic_model.fillVarModelTable();
original_model.fillVarModelTableFromOrigModel();
// VAR expectation models
var_expectation_model_table.transformPass(diff_subst_table, dynamic_model, var_model_table,
trend_component_model_table);
// PAC model
pac_model_table.transformPass(unary_ops_nodes, unary_ops_subst_table,
diff_nodes, diff_subst_table,
@ -494,77 +499,6 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool
mod_file_struct.ramsey_eq_nbr = dynamic_model.equation_number() - mod_file_struct.orig_eq_nbr;
}
/* Handle var_expectation_model statements: collect information about them,
create the new corresponding parameters, and the expressions to replace
the var_expectation statements.
TODO: move information collection to checkPass(), within a new
VarModelTable class */
map<string, expr_t> var_expectation_subst_table;
for (auto &statement : statements)
{
auto vems = dynamic_cast<VarExpectationModelStatement *>(statement.get());
if (!vems)
continue;
int max_lag;
vector<int> lhs;
auto &model_name = vems->model_name;
if (var_model_table.isExistingVarModelName(vems->aux_model_name))
{
max_lag = var_model_table.getMaxLag(vems->aux_model_name);
lhs = var_model_table.getLhs(vems->aux_model_name);
}
else if (trend_component_model_table.isExistingTrendComponentModelName(vems->aux_model_name))
{
max_lag = trend_component_model_table.getMaxLag(vems->aux_model_name) + 1;
lhs = dynamic_model.getUndiffLHSForPac(vems->aux_model_name, diff_subst_table);
}
else
{
cerr << "ERROR: var_expectation_model " << model_name
<< " refers to nonexistent auxiliary model " << vems->aux_model_name << endl;
exit(EXIT_FAILURE);
}
/* Substitute unary and diff operators in the 'expression' option, then
match the linear combination in the expression option */
vems->substituteUnaryOpNodes(unary_ops_nodes, unary_ops_subst_table);
vems->substituteDiff(diff_nodes, diff_subst_table);
vems->matchExpression();
/* Create auxiliary parameters and the expression to be substituted into
the var_expectations statement */
expr_t subst_expr = dynamic_model.Zero;
if (var_model_table.isExistingVarModelName(vems->aux_model_name))
{
/* If the auxiliary model is a VAR, add a parameter corresponding to
the constant. */
string constant_param_name = "var_expectation_model_" + model_name + "_constant";
int constant_param_id = symbol_table.addSymbol(constant_param_name, SymbolType::parameter);
vems->aux_params_ids.push_back(constant_param_id);
subst_expr = dynamic_model.AddPlus(subst_expr, dynamic_model.AddVariable(constant_param_id));
}
for (int lag = 0; lag < max_lag; lag++)
for (auto variable : lhs)
{
string param_name = "var_expectation_model_" + model_name + '_' + symbol_table.getName(variable) + '_' + to_string(lag);
int new_param_id = symbol_table.addSymbol(param_name, SymbolType::parameter);
vems->aux_params_ids.push_back(new_param_id);
subst_expr = dynamic_model.AddPlus(subst_expr,
dynamic_model.AddTimes(dynamic_model.AddVariable(new_param_id),
dynamic_model.AddVariable(variable, -lag + vems->time_shift)));
}
if (var_expectation_subst_table.find(model_name) != var_expectation_subst_table.end())
{
cerr << "ERROR: model name '" << model_name << "' is used by several var_expectation_model statements" << endl;
exit(EXIT_FAILURE);
}
var_expectation_subst_table[model_name] = subst_expr;
}
// And finally perform the substitutions
dynamic_model.substituteVarExpectation(var_expectation_subst_table);
dynamic_model.createVariableMapping();
/* Create auxiliary vars for leads and lags greater than 2, on both endos and
@ -929,6 +863,7 @@ ModFile::writeMOutput(const string &basename, bool clear_all, bool clear_global,
var_model_table.writeOutput(basename, mOutputFile);
trend_component_model_table.writeOutput(basename, mOutputFile);
var_expectation_model_table.writeOutput(basename, mOutputFile);
pac_model_table.writeOutput(basename, mOutputFile);
// Initialize M_.Sigma_e, M_.Correlation_matrix, M_.H, and M_.Correlation_matrix_ME
@ -1223,6 +1158,12 @@ ModFile::writeJsonOutputParsingCheck(const string &basename, JsonFileOutputType
output << ", ";
}
if (!var_expectation_model_table.empty())
{
var_expectation_model_table.writeJsonOutput(output);
output << ", ";
}
if (!pac_model_table.empty())
{
pac_model_table.writeJsonOutput(output);

View File

@ -55,6 +55,8 @@ public:
VarModelTable var_model_table;
//! Trend Component Model Table used for storing info about trend component models
TrendComponentModelTable trend_component_model_table;
//! Table for storing the models declared with var_expectation_model
VarExpectationModelTable var_expectation_model_table;
//! PAC Model Table used for storing info about pac models
PacModelTable pac_model_table;
//! Expressions outside model block

View File

@ -3457,10 +3457,9 @@ ParsingDriver::var_expectation_model()
if (time_shift > 0)
error("The 'time_shift' option must be a non-positive integer");
mod_file->addStatement(make_unique<VarExpectationModelStatement>(model_name, var_expectation_model_expression,
var_model_name, horizon,
var_expectation_model_discount, time_shift,
mod_file->symbol_table));
mod_file->var_expectation_model_table.addVarExpectationModel(model_name, var_expectation_model_expression,
var_model_name, horizon,
var_expectation_model_discount, time_shift);
options_list.clear();
var_expectation_model_discount = nullptr;

View File

@ -754,6 +754,215 @@ VarModelTable::getLhsExprT(const string &name_arg) const
return lhs_expr_t.find(name_arg)->second;
}
VarExpectationModelTable::VarExpectationModelTable(SymbolTable &symbol_table_arg) :
symbol_table{symbol_table_arg}
{
}
void
VarExpectationModelTable::addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg, string horizon_arg, expr_t discount_arg, int time_shift_arg)
{
if (isExistingVarExpectationModelName(name_arg))
{
cerr << "Error: a var_expectation_model already exists with the name " << name_arg << endl;
exit(EXIT_FAILURE);
}
expression[name_arg] = expression_arg;
aux_model_name[name_arg] = move(aux_model_name_arg);
horizon[name_arg] = move(horizon_arg);
discount[name_arg] = discount_arg;
time_shift[name_arg] = time_shift_arg;
names.insert(move(name_arg));
}
bool
VarExpectationModelTable::isExistingVarExpectationModelName(const string &name_arg) const
{
return names.find(name_arg) != names.end();
}
bool
VarExpectationModelTable::empty() const
{
return names.empty();
}
void
VarExpectationModelTable::writeOutput(const string &basename, ostream &output) const
{
for (const auto &name : names)
{
string mstruct = "M_.var_expectation." + name;
output << mstruct << ".auxiliary_model_name = '" << aux_model_name.at(name) << "';" << endl
<< mstruct << ".horizon = " << horizon.at(name) << ';' << endl
<< mstruct << ".time_shift = " << time_shift.at(name) << ';' << endl;
auto &vpc = vars_params_constants.at(name);
if (!vpc.size())
{
cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl;
exit(EXIT_FAILURE);
}
ostringstream vars_list, params_list, constants_list;
for (auto it = vpc.begin(); it != vpc.end(); ++it)
{
if (it != vpc.begin())
{
vars_list << ", ";
params_list << ", ";
constants_list << ", ";
}
vars_list << symbol_table.getTypeSpecificID(get<0>(*it))+1;
if (get<1>(*it) == -1)
params_list << "NaN";
else
params_list << symbol_table.getTypeSpecificID(get<1>(*it))+1;
constants_list << get<2>(*it);
}
output << mstruct << ".expr.vars = [ " << vars_list.str() << " ];" << endl
<< mstruct << ".expr.params = [ " << params_list.str() << " ];" << endl
<< mstruct << ".expr.constants = [ " << constants_list.str() << " ];" << endl;
if (auto disc_var = dynamic_cast<const VariableNode *>(discount.at(name));
disc_var)
output << mstruct << ".discount_index = " << symbol_table.getTypeSpecificID(disc_var->symb_id) + 1 << ';' << endl;
else
{
output << mstruct << ".discount_value = ";
discount.at(name)->writeOutput(output);
output << ';' << endl;
}
output << mstruct << ".param_indices = [ ";
for (int param_id : aux_param_symb_ids.at(name))
output << symbol_table.getTypeSpecificID(param_id)+1 << ' ';
output << "];" << endl;
}
}
void
VarExpectationModelTable::substituteUnaryOpsInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs)
{
for (const auto &name : names)
expression[name] = expression[name]->substituteUnaryOpNodes(nodes, subst_table, neweqs);
}
void
VarExpectationModelTable::substituteDiffNodesInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs)
{
for (const auto &name : names)
expression[name] = expression[name]->substituteDiff(nodes, subst_table, neweqs);
}
void
VarExpectationModelTable::transformPass(ExprNode::subst_table_t &diff_subst_table,
DynamicModel &dynamic_model, const VarModelTable &var_model_table,
const TrendComponentModelTable &trend_component_model_table)
{
map<string, expr_t> var_expectation_subst_table;
for (const auto &name : names)
{
// Collect information about the auxiliary model
int max_lag;
vector<int> lhs;
if (var_model_table.isExistingVarModelName(aux_model_name[name]))
{
max_lag = var_model_table.getMaxLag(aux_model_name[name]);
lhs = var_model_table.getLhs(aux_model_name[name]);
}
else if (trend_component_model_table.isExistingTrendComponentModelName(aux_model_name[name]))
{
max_lag = trend_component_model_table.getMaxLag(aux_model_name[name]) + 1;
lhs = dynamic_model.getUndiffLHSForPac(aux_model_name[name], diff_subst_table);
}
else
{
cerr << "ERROR: var_expectation_model " << name
<< " refers to nonexistent auxiliary model " << aux_model_name[name] << endl;
exit(EXIT_FAILURE);
}
// Match the linear combination in the expression option
try
{
auto vpc = expression[name]->matchLinearCombinationOfVariables();
for (const auto &[variable_id, lag, param_id, constant] : vpc)
{
if (lag != 0)
throw ExprNode::MatchFailureException{"lead/lags are not allowed"};
if (symbol_table.getType(variable_id) != SymbolType::endogenous)
throw ExprNode::MatchFailureException{"Variable is not an endogenous"};
vars_params_constants[name].emplace_back(variable_id, param_id, constant);
}
}
catch (ExprNode::MatchFailureException &e)
{
cerr << "ERROR: expression in var_expectation_model " << name << " is not of the expected form: " << e.message << endl;
exit(EXIT_FAILURE);
}
/* Create auxiliary parameters and the expression to be substituted into
the var_expectations statement */
expr_t subst_expr = dynamic_model.Zero;
if (var_model_table.isExistingVarModelName(aux_model_name[name]))
{
/* If the auxiliary model is a VAR, add a parameter corresponding to
the constant. */
string constant_param_name = "var_expectation_model_" + name + "_constant";
int constant_param_id = symbol_table.addSymbol(constant_param_name, SymbolType::parameter);
aux_param_symb_ids[name].push_back(constant_param_id);
subst_expr = dynamic_model.AddPlus(subst_expr, dynamic_model.AddVariable(constant_param_id));
}
for (int lag = 0; lag < max_lag; lag++)
for (auto variable : lhs)
{
string param_name = "var_expectation_model_" + name + '_' + symbol_table.getName(variable) + '_' + to_string(lag);
int new_param_id = symbol_table.addSymbol(param_name, SymbolType::parameter);
aux_param_symb_ids[name].push_back(new_param_id);
subst_expr = dynamic_model.AddPlus(subst_expr,
dynamic_model.AddTimes(dynamic_model.AddVariable(new_param_id),
dynamic_model.AddVariable(variable, -lag + time_shift[name])));
}
if (var_expectation_subst_table.find(name) != var_expectation_subst_table.end())
{
cerr << "ERROR: model name '" << name << "' is used by several var_expectation_model statements" << endl;
exit(EXIT_FAILURE);
}
var_expectation_subst_table[name] = subst_expr;
}
// Actually substitute var_expectation statements
dynamic_model.substituteVarExpectation(var_expectation_subst_table);
/* At this point, we know that all var_expectation operators have been
substituted, because of the error check performed in
VarExpectationNode::substituteVarExpectation(). */
}
void
VarExpectationModelTable::writeJsonOutput(ostream &output) const
{
for (const auto &name : names)
{
output << R"({"statementName": "var_expectation_model",)"
<< R"("model_name": ")" << name << R"(", )"
<< R"("expression": ")";
expression.at(name)->writeOutput(output);
output << R"(", )"
<< R"("auxiliary_model_name": ")" << aux_model_name.at(name) << R"(", )"
<< R"("horizon": ")" << horizon.at(name) << R"(", )"
<< R"("discount": ")";
discount.at(name)->writeOutput(output);
output << R"("})";
}
}
PacModelTable::PacModelTable(SymbolTable &symbol_table_arg) :
symbol_table{symbol_table_arg}
{

View File

@ -190,6 +190,36 @@ VarModelTable::empty() const
return names.empty();
}
class VarExpectationModelTable
{
private:
SymbolTable &symbol_table;
set<string> names;
map<string, expr_t> expression;
map<string, string> aux_model_name;
map<string, string> horizon;
map<string, expr_t> discount;
map<string, int> time_shift;
// For each model, list of generated auxiliary param ids, in variable-major order
map<string, vector<int>> aux_param_symb_ids;
// Decomposition of the expression
map<string, vector<tuple<int, int, double>>> vars_params_constants;
public:
explicit VarExpectationModelTable(SymbolTable &symbol_table_arg);
void addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg,
string horizon_arg, expr_t discount_arg, int time_shift_arg);
bool isExistingVarExpectationModelName(const string &name_arg) const;
bool empty() const;
void substituteUnaryOpsInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs);
// Called by DynamicModel::substituteDiff()
void substituteDiffNodesInExpression(const lag_equivalence_table_t &diff_nodes, ExprNode::subst_table_t &diff_subst_table, vector<BinaryOpNode *> &neweqs);
void transformPass(ExprNode::subst_table_t &diff_subst_table,
DynamicModel &dynamic_model, const VarModelTable &var_model_table,
const TrendComponentModelTable &trend_component_model_table);
void writeOutput(const string &basename, ostream &output) const;
void writeJsonOutput(ostream &output) const;
};
class PacModelTable
{
private: