diff --git a/src/ComputingTasks.cc b/src/ComputingTasks.cc index b13c815d..adf9a222 100644 --- a/src/ComputingTasks.cc +++ b/src/ComputingTasks.cc @@ -5102,129 +5102,6 @@ GenerateIRFsStatement::writeJsonOutput(ostream &output) const output << "}"; } -VarExpectationModelStatement::VarExpectationModelStatement(string model_name_arg, - expr_t expression_arg, - string aux_model_name_arg, - string horizon_arg, - expr_t discount_arg, - int time_shift_arg, - const SymbolTable &symbol_table_arg) : - model_name{move(model_name_arg)}, expression{expression_arg}, - aux_model_name{move(aux_model_name_arg)}, horizon{move(horizon_arg)}, - discount{discount_arg}, time_shift{time_shift_arg}, symbol_table{symbol_table_arg} -{ -} - -void -VarExpectationModelStatement::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table) -{ - vector neweqs; - expression = expression->substituteUnaryOpNodes(nodes, subst_table, neweqs); - if (neweqs.size() > 0) - { - cerr << "ERROR: the 'expression' option of var_expectation_model contains a variable with a unary operator that is not present in the VAR model" << endl; - exit(EXIT_FAILURE); - } -} - -void -VarExpectationModelStatement::substituteDiff(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table) -{ - vector neweqs; - expression = expression->substituteDiff(nodes, subst_table, neweqs); - if (neweqs.size() > 0) - { - cerr << "ERROR: the 'expression' option of var_expectation_model contains a diff'd variable that is not present in the VAR model" << endl; - exit(EXIT_FAILURE); - } -} - -void -VarExpectationModelStatement::matchExpression() -{ - try - { - auto vpc = expression->matchLinearCombinationOfVariables(); - for (const auto &it : vpc) - { - if (get<1>(it) != 0) - throw ExprNode::MatchFailureException{"lead/lags are not allowed"}; - if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous) - throw ExprNode::MatchFailureException{"Variable is not an endogenous"}; - vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it)); - } - } - catch (ExprNode::MatchFailureException &e) - { - cerr << "ERROR: expression in var_expectation_model is not of the expected form: " << e.message << endl; - exit(EXIT_FAILURE); - } -} - -void -VarExpectationModelStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const -{ - string mstruct = "M_.var_expectation." + model_name; - output << mstruct << ".auxiliary_model_name = '" << aux_model_name << "';" << endl - << mstruct << ".horizon = " << horizon << ';' << endl - << mstruct << ".time_shift = " << time_shift << ';' << endl; - - if (!vars_params_constants.size()) - { - cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl; - exit(EXIT_FAILURE); - } - - ostringstream vars_list, params_list, constants_list; - for (auto it = vars_params_constants.begin(); it != vars_params_constants.end(); ++it) - { - if (it != vars_params_constants.begin()) - { - vars_list << ", "; - params_list << ", "; - constants_list << ", "; - } - vars_list << symbol_table.getTypeSpecificID(get<0>(*it))+1; - if (get<1>(*it) == -1) - params_list << "NaN"; - else - params_list << symbol_table.getTypeSpecificID(get<1>(*it))+1; - constants_list << get<2>(*it); - } - output << mstruct << ".expr.vars = [ " << vars_list.str() << " ];" << endl - << mstruct << ".expr.params = [ " << params_list.str() << " ];" << endl - << mstruct << ".expr.constants = [ " << constants_list.str() << " ];" << endl; - - if (auto disc_var = dynamic_cast(discount); - disc_var) - output << mstruct << ".discount_index = " << symbol_table.getTypeSpecificID(disc_var->symb_id) + 1 << ';' << endl; - else - { - output << mstruct << ".discount_value = "; - discount->writeOutput(output); - output << ';' << endl; - } - output << mstruct << ".param_indices = [ "; - for (int param_id : aux_params_ids) - output << symbol_table.getTypeSpecificID(param_id)+1 << ' '; - output << "];" << endl; -} - -void -VarExpectationModelStatement::writeJsonOutput(ostream &output) const -{ - output << R"({"statementName": "var_expectation_model",)" - << R"("model_name": ")" << model_name << R"(", )" - << R"("expression": ")"; - expression->writeOutput(output); - output << R"(", )" - << R"("auxiliary_model_name": ")" << aux_model_name << R"(", )" - << R"("horizon": ")" << horizon << R"(", )" - << R"("discount": ")"; - discount->writeOutput(output); - output << R"("})"; -} - MatchedMomentsStatement::MatchedMomentsStatement(const SymbolTable &symbol_table_arg, vector, vector, vector>> moments_arg) : symbol_table{symbol_table_arg}, moments{move(moments_arg)} diff --git a/src/ComputingTasks.hh b/src/ComputingTasks.hh index 0f789182..0b03881a 100644 --- a/src/ComputingTasks.hh +++ b/src/ComputingTasks.hh @@ -1211,35 +1211,6 @@ public: void writeJsonOutput(ostream &output) const override; }; -class VarExpectationModelStatement : public Statement -{ -public: - const string model_name; -private: - expr_t expression; -public: - const string aux_model_name, horizon; - const expr_t discount; - const int time_shift; - const SymbolTable &symbol_table; - // List of generated auxiliary param ids, in variable-major order - vector aux_params_ids; // TODO: move this to some new VarModelTable object -private: - vector> vars_params_constants; -public: - VarExpectationModelStatement(string model_name_arg, expr_t expression_arg, string aux_model_name_arg, - string horizon_arg, expr_t discount_arg, int time_shift_arg, - const SymbolTable &symbol_table_arg); - void substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table); - void substituteDiff(const lag_equivalence_table_t &diff_table, ExprNode::subst_table_t &subst_table); - // Analyzes the linear combination contained in the 'expression' option - /* Must be called after substituteUnaryOpNodes() and substituteDiff() (in - that order) */ - void matchExpression(); - void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override; - void writeJsonOutput(ostream &output) const override; -}; - class MatchedMomentsStatement : public Statement { private: diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index fe3cc898..e4dd30b3 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -5587,15 +5587,15 @@ DynamicModel::findPacExpectationEquationNumbers() const } pair -DynamicModel::substituteUnaryOps(PacModelTable &pac_model_table) +DynamicModel::substituteUnaryOps(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table) { vector eqnumbers(equations.size()); iota(eqnumbers.begin(), eqnumbers.end(), 0); - return substituteUnaryOps(set(eqnumbers.begin(), eqnumbers.end()), pac_model_table); + return substituteUnaryOps(set(eqnumbers.begin(), eqnumbers.end()), var_expectation_model_table, pac_model_table); } pair -DynamicModel::substituteUnaryOps(const set &eqnumbers, PacModelTable &pac_model_table) +DynamicModel::substituteUnaryOps(const set &eqnumbers, VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table) { lag_equivalence_table_t nodes; ExprNode::subst_table_t subst_table; @@ -5625,6 +5625,8 @@ DynamicModel::substituteUnaryOps(const set &eqnumbers, PacModelTable &pac_m equations[eq] = substeq; } + // Substitute in expressions of var_expectation_model + var_expectation_model_table.substituteUnaryOpsInExpression(nodes, subst_table, neweqs); // Substitute in growth terms in pac_model and pac_target_info pac_model_table.substituteUnaryOpsInGrowth(nodes, subst_table, neweqs); @@ -5642,7 +5644,7 @@ DynamicModel::substituteUnaryOps(const set &eqnumbers, PacModelTable &pac_m } pair -DynamicModel::substituteDiff(PacModelTable &pac_model_table) +DynamicModel::substituteDiff(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table) { /* Note: at this point, we know that there is no diff operator with a lead, because they have been expanded by DataTree::AddDiff(). @@ -5680,6 +5682,7 @@ DynamicModel::substituteDiff(PacModelTable &pac_model_table) equation = substeq; } + var_expectation_model_table.substituteDiffNodesInExpression(diff_nodes, diff_subst_table, neweqs); pac_model_table.substituteDiffNodesInGrowth(diff_nodes, diff_subst_table, neweqs); // Add new equations diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index 50cd6d47..e0776c00 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -521,15 +521,17 @@ public: void substituteModelLocalVariables(); /* Creates aux vars for all unary operators in all equations. Also makes the - substitution in growth terms of pac_model/pac_target_info. */ - pair substituteUnaryOps(PacModelTable &pac_model_table); + substitution in growth terms of pac_model/pac_target_info and in + expressions of var_expectation_model. */ + pair substituteUnaryOps(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table); /* Creates aux vars for all unary operators in specified equations. Also makes the - substitution in growth terms of pac_model/pac_target_info. */ - pair substituteUnaryOps(const set &eqnumbers, PacModelTable &pac_model_table); + substitution in growth terms of pac_model/pac_target_info and in + expressions of var_expectation_model. */ + pair substituteUnaryOps(const set &eqnumbers, VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table); //! Substitutes diff operator - pair substituteDiff(PacModelTable &pac_model_table); + pair substituteDiff(VarExpectationModelTable &var_expectation_model_table, PacModelTable &pac_model_table); //! Substitute VarExpectation operators void substituteVarExpectation(const map &subst_table); diff --git a/src/ModFile.cc b/src/ModFile.cc index 3c797720..a033be96 100644 --- a/src/ModFile.cc +++ b/src/ModFile.cc @@ -34,6 +34,7 @@ ModFile::ModFile(WarningConsolidation &warnings_arg) : var_model_table{symbol_table}, trend_component_model_table{symbol_table}, + var_expectation_model_table{symbol_table}, pac_model_table{symbol_table}, expressions_tree{symbol_table, num_constants, external_functions_table}, original_model{symbol_table, num_constants, external_functions_table, @@ -430,13 +431,13 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool lag_equivalence_table_t unary_ops_nodes; ExprNode::subst_table_t unary_ops_subst_table; if (transform_unary_ops) - tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(pac_model_table); + tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(var_expectation_model_table, pac_model_table); else // substitute only those unary ops that appear in VAR, TCM and PAC model equations - tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(unary_ops_eqs, pac_model_table); + tie(unary_ops_nodes, unary_ops_subst_table) = dynamic_model.substituteUnaryOps(unary_ops_eqs, var_expectation_model_table, pac_model_table); // Create auxiliary variable and equations for Diff operators - auto [diff_nodes, diff_subst_table] = dynamic_model.substituteDiff(pac_model_table); + auto [diff_nodes, diff_subst_table] = dynamic_model.substituteDiff(var_expectation_model_table, pac_model_table); // Fill trend component and VAR model tables dynamic_model.fillTrendComponentModelTable(); @@ -445,6 +446,10 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool dynamic_model.fillVarModelTable(); original_model.fillVarModelTableFromOrigModel(); + // VAR expectation models + var_expectation_model_table.transformPass(diff_subst_table, dynamic_model, var_model_table, + trend_component_model_table); + // PAC model pac_model_table.transformPass(unary_ops_nodes, unary_ops_subst_table, diff_nodes, diff_subst_table, @@ -494,77 +499,6 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool mod_file_struct.ramsey_eq_nbr = dynamic_model.equation_number() - mod_file_struct.orig_eq_nbr; } - /* Handle var_expectation_model statements: collect information about them, - create the new corresponding parameters, and the expressions to replace - the var_expectation statements. - TODO: move information collection to checkPass(), within a new - VarModelTable class */ - map var_expectation_subst_table; - for (auto &statement : statements) - { - auto vems = dynamic_cast(statement.get()); - if (!vems) - continue; - - int max_lag; - vector lhs; - auto &model_name = vems->model_name; - if (var_model_table.isExistingVarModelName(vems->aux_model_name)) - { - max_lag = var_model_table.getMaxLag(vems->aux_model_name); - lhs = var_model_table.getLhs(vems->aux_model_name); - } - else if (trend_component_model_table.isExistingTrendComponentModelName(vems->aux_model_name)) - { - max_lag = trend_component_model_table.getMaxLag(vems->aux_model_name) + 1; - lhs = dynamic_model.getUndiffLHSForPac(vems->aux_model_name, diff_subst_table); - } - else - { - cerr << "ERROR: var_expectation_model " << model_name - << " refers to nonexistent auxiliary model " << vems->aux_model_name << endl; - exit(EXIT_FAILURE); - } - - /* Substitute unary and diff operators in the 'expression' option, then - match the linear combination in the expression option */ - vems->substituteUnaryOpNodes(unary_ops_nodes, unary_ops_subst_table); - vems->substituteDiff(diff_nodes, diff_subst_table); - vems->matchExpression(); - - /* Create auxiliary parameters and the expression to be substituted into - the var_expectations statement */ - expr_t subst_expr = dynamic_model.Zero; - if (var_model_table.isExistingVarModelName(vems->aux_model_name)) - { - /* If the auxiliary model is a VAR, add a parameter corresponding to - the constant. */ - string constant_param_name = "var_expectation_model_" + model_name + "_constant"; - int constant_param_id = symbol_table.addSymbol(constant_param_name, SymbolType::parameter); - vems->aux_params_ids.push_back(constant_param_id); - subst_expr = dynamic_model.AddPlus(subst_expr, dynamic_model.AddVariable(constant_param_id)); - } - for (int lag = 0; lag < max_lag; lag++) - for (auto variable : lhs) - { - string param_name = "var_expectation_model_" + model_name + '_' + symbol_table.getName(variable) + '_' + to_string(lag); - int new_param_id = symbol_table.addSymbol(param_name, SymbolType::parameter); - vems->aux_params_ids.push_back(new_param_id); - - subst_expr = dynamic_model.AddPlus(subst_expr, - dynamic_model.AddTimes(dynamic_model.AddVariable(new_param_id), - dynamic_model.AddVariable(variable, -lag + vems->time_shift))); - } - - if (var_expectation_subst_table.find(model_name) != var_expectation_subst_table.end()) - { - cerr << "ERROR: model name '" << model_name << "' is used by several var_expectation_model statements" << endl; - exit(EXIT_FAILURE); - } - var_expectation_subst_table[model_name] = subst_expr; - } - // And finally perform the substitutions - dynamic_model.substituteVarExpectation(var_expectation_subst_table); dynamic_model.createVariableMapping(); /* Create auxiliary vars for leads and lags greater than 2, on both endos and @@ -929,6 +863,7 @@ ModFile::writeMOutput(const string &basename, bool clear_all, bool clear_global, var_model_table.writeOutput(basename, mOutputFile); trend_component_model_table.writeOutput(basename, mOutputFile); + var_expectation_model_table.writeOutput(basename, mOutputFile); pac_model_table.writeOutput(basename, mOutputFile); // Initialize M_.Sigma_e, M_.Correlation_matrix, M_.H, and M_.Correlation_matrix_ME @@ -1223,6 +1158,12 @@ ModFile::writeJsonOutputParsingCheck(const string &basename, JsonFileOutputType output << ", "; } + if (!var_expectation_model_table.empty()) + { + var_expectation_model_table.writeJsonOutput(output); + output << ", "; + } + if (!pac_model_table.empty()) { pac_model_table.writeJsonOutput(output); diff --git a/src/ModFile.hh b/src/ModFile.hh index 4c2cd072..1d920a99 100644 --- a/src/ModFile.hh +++ b/src/ModFile.hh @@ -55,6 +55,8 @@ public: VarModelTable var_model_table; //! Trend Component Model Table used for storing info about trend component models TrendComponentModelTable trend_component_model_table; + //! Table for storing the models declared with var_expectation_model + VarExpectationModelTable var_expectation_model_table; //! PAC Model Table used for storing info about pac models PacModelTable pac_model_table; //! Expressions outside model block diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc index 4d4f6414..9b9741df 100644 --- a/src/ParsingDriver.cc +++ b/src/ParsingDriver.cc @@ -3457,10 +3457,9 @@ ParsingDriver::var_expectation_model() if (time_shift > 0) error("The 'time_shift' option must be a non-positive integer"); - mod_file->addStatement(make_unique(model_name, var_expectation_model_expression, - var_model_name, horizon, - var_expectation_model_discount, time_shift, - mod_file->symbol_table)); + mod_file->var_expectation_model_table.addVarExpectationModel(model_name, var_expectation_model_expression, + var_model_name, horizon, + var_expectation_model_discount, time_shift); options_list.clear(); var_expectation_model_discount = nullptr; diff --git a/src/SubModel.cc b/src/SubModel.cc index 6b9f6ff8..bc0a3a48 100644 --- a/src/SubModel.cc +++ b/src/SubModel.cc @@ -754,6 +754,215 @@ VarModelTable::getLhsExprT(const string &name_arg) const return lhs_expr_t.find(name_arg)->second; } + +VarExpectationModelTable::VarExpectationModelTable(SymbolTable &symbol_table_arg) : + symbol_table{symbol_table_arg} +{ +} + +void +VarExpectationModelTable::addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg, string horizon_arg, expr_t discount_arg, int time_shift_arg) +{ + if (isExistingVarExpectationModelName(name_arg)) + { + cerr << "Error: a var_expectation_model already exists with the name " << name_arg << endl; + exit(EXIT_FAILURE); + } + + expression[name_arg] = expression_arg; + aux_model_name[name_arg] = move(aux_model_name_arg); + horizon[name_arg] = move(horizon_arg); + discount[name_arg] = discount_arg; + time_shift[name_arg] = time_shift_arg; + names.insert(move(name_arg)); +} + +bool +VarExpectationModelTable::isExistingVarExpectationModelName(const string &name_arg) const +{ + return names.find(name_arg) != names.end(); +} + +bool +VarExpectationModelTable::empty() const +{ + return names.empty(); +} + +void +VarExpectationModelTable::writeOutput(const string &basename, ostream &output) const +{ + for (const auto &name : names) + { + string mstruct = "M_.var_expectation." + name; + output << mstruct << ".auxiliary_model_name = '" << aux_model_name.at(name) << "';" << endl + << mstruct << ".horizon = " << horizon.at(name) << ';' << endl + << mstruct << ".time_shift = " << time_shift.at(name) << ';' << endl; + + auto &vpc = vars_params_constants.at(name); + if (!vpc.size()) + { + cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl; + exit(EXIT_FAILURE); + } + + ostringstream vars_list, params_list, constants_list; + for (auto it = vpc.begin(); it != vpc.end(); ++it) + { + if (it != vpc.begin()) + { + vars_list << ", "; + params_list << ", "; + constants_list << ", "; + } + vars_list << symbol_table.getTypeSpecificID(get<0>(*it))+1; + if (get<1>(*it) == -1) + params_list << "NaN"; + else + params_list << symbol_table.getTypeSpecificID(get<1>(*it))+1; + constants_list << get<2>(*it); + } + output << mstruct << ".expr.vars = [ " << vars_list.str() << " ];" << endl + << mstruct << ".expr.params = [ " << params_list.str() << " ];" << endl + << mstruct << ".expr.constants = [ " << constants_list.str() << " ];" << endl; + + if (auto disc_var = dynamic_cast(discount.at(name)); + disc_var) + output << mstruct << ".discount_index = " << symbol_table.getTypeSpecificID(disc_var->symb_id) + 1 << ';' << endl; + else + { + output << mstruct << ".discount_value = "; + discount.at(name)->writeOutput(output); + output << ';' << endl; + } + output << mstruct << ".param_indices = [ "; + for (int param_id : aux_param_symb_ids.at(name)) + output << symbol_table.getTypeSpecificID(param_id)+1 << ' '; + output << "];" << endl; + } +} + +void +VarExpectationModelTable::substituteUnaryOpsInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector &neweqs) +{ + for (const auto &name : names) + expression[name] = expression[name]->substituteUnaryOpNodes(nodes, subst_table, neweqs); +} + +void +VarExpectationModelTable::substituteDiffNodesInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector &neweqs) +{ + for (const auto &name : names) + expression[name] = expression[name]->substituteDiff(nodes, subst_table, neweqs); +} + +void +VarExpectationModelTable::transformPass(ExprNode::subst_table_t &diff_subst_table, + DynamicModel &dynamic_model, const VarModelTable &var_model_table, + const TrendComponentModelTable &trend_component_model_table) +{ + map var_expectation_subst_table; + + for (const auto &name : names) + { + // Collect information about the auxiliary model + + int max_lag; + vector lhs; + if (var_model_table.isExistingVarModelName(aux_model_name[name])) + { + max_lag = var_model_table.getMaxLag(aux_model_name[name]); + lhs = var_model_table.getLhs(aux_model_name[name]); + } + else if (trend_component_model_table.isExistingTrendComponentModelName(aux_model_name[name])) + { + max_lag = trend_component_model_table.getMaxLag(aux_model_name[name]) + 1; + lhs = dynamic_model.getUndiffLHSForPac(aux_model_name[name], diff_subst_table); + } + else + { + cerr << "ERROR: var_expectation_model " << name + << " refers to nonexistent auxiliary model " << aux_model_name[name] << endl; + exit(EXIT_FAILURE); + } + + // Match the linear combination in the expression option + try + { + auto vpc = expression[name]->matchLinearCombinationOfVariables(); + for (const auto &[variable_id, lag, param_id, constant] : vpc) + { + if (lag != 0) + throw ExprNode::MatchFailureException{"lead/lags are not allowed"}; + if (symbol_table.getType(variable_id) != SymbolType::endogenous) + throw ExprNode::MatchFailureException{"Variable is not an endogenous"}; + vars_params_constants[name].emplace_back(variable_id, param_id, constant); + } + } + catch (ExprNode::MatchFailureException &e) + { + cerr << "ERROR: expression in var_expectation_model " << name << " is not of the expected form: " << e.message << endl; + exit(EXIT_FAILURE); + } + + /* Create auxiliary parameters and the expression to be substituted into + the var_expectations statement */ + expr_t subst_expr = dynamic_model.Zero; + if (var_model_table.isExistingVarModelName(aux_model_name[name])) + { + /* If the auxiliary model is a VAR, add a parameter corresponding to + the constant. */ + string constant_param_name = "var_expectation_model_" + name + "_constant"; + int constant_param_id = symbol_table.addSymbol(constant_param_name, SymbolType::parameter); + aux_param_symb_ids[name].push_back(constant_param_id); + subst_expr = dynamic_model.AddPlus(subst_expr, dynamic_model.AddVariable(constant_param_id)); + } + for (int lag = 0; lag < max_lag; lag++) + for (auto variable : lhs) + { + string param_name = "var_expectation_model_" + name + '_' + symbol_table.getName(variable) + '_' + to_string(lag); + int new_param_id = symbol_table.addSymbol(param_name, SymbolType::parameter); + aux_param_symb_ids[name].push_back(new_param_id); + + subst_expr = dynamic_model.AddPlus(subst_expr, + dynamic_model.AddTimes(dynamic_model.AddVariable(new_param_id), + dynamic_model.AddVariable(variable, -lag + time_shift[name]))); + } + + if (var_expectation_subst_table.find(name) != var_expectation_subst_table.end()) + { + cerr << "ERROR: model name '" << name << "' is used by several var_expectation_model statements" << endl; + exit(EXIT_FAILURE); + } + var_expectation_subst_table[name] = subst_expr; + } + + // Actually substitute var_expectation statements + dynamic_model.substituteVarExpectation(var_expectation_subst_table); + /* At this point, we know that all var_expectation operators have been + substituted, because of the error check performed in + VarExpectationNode::substituteVarExpectation(). */ +} + +void +VarExpectationModelTable::writeJsonOutput(ostream &output) const +{ + for (const auto &name : names) + { + output << R"({"statementName": "var_expectation_model",)" + << R"("model_name": ")" << name << R"(", )" + << R"("expression": ")"; + expression.at(name)->writeOutput(output); + output << R"(", )" + << R"("auxiliary_model_name": ")" << aux_model_name.at(name) << R"(", )" + << R"("horizon": ")" << horizon.at(name) << R"(", )" + << R"("discount": ")"; + discount.at(name)->writeOutput(output); + output << R"("})"; + } +} + + PacModelTable::PacModelTable(SymbolTable &symbol_table_arg) : symbol_table{symbol_table_arg} { diff --git a/src/SubModel.hh b/src/SubModel.hh index 548d323a..c035fbae 100644 --- a/src/SubModel.hh +++ b/src/SubModel.hh @@ -190,6 +190,36 @@ VarModelTable::empty() const return names.empty(); } +class VarExpectationModelTable +{ +private: + SymbolTable &symbol_table; + set names; + map expression; + map aux_model_name; + map horizon; + map discount; + map time_shift; + // For each model, list of generated auxiliary param ids, in variable-major order + map> aux_param_symb_ids; + // Decomposition of the expression + map>> vars_params_constants; +public: + explicit VarExpectationModelTable(SymbolTable &symbol_table_arg); + void addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg, + string horizon_arg, expr_t discount_arg, int time_shift_arg); + bool isExistingVarExpectationModelName(const string &name_arg) const; + bool empty() const; + void substituteUnaryOpsInExpression(const lag_equivalence_table_t &nodes, ExprNode::subst_table_t &subst_table, vector &neweqs); + // Called by DynamicModel::substituteDiff() + void substituteDiffNodesInExpression(const lag_equivalence_table_t &diff_nodes, ExprNode::subst_table_t &diff_subst_table, vector &neweqs); + void transformPass(ExprNode::subst_table_t &diff_subst_table, + DynamicModel &dynamic_model, const VarModelTable &var_model_table, + const TrendComponentModelTable &trend_component_model_table); + void writeOutput(const string &basename, ostream &output) const; + void writeJsonOutput(ostream &output) const; +}; + class PacModelTable { private: