diff --git a/src/ComputingTasks.cc b/src/ComputingTasks.cc index cb94d749..686e4194 100644 --- a/src/ComputingTasks.cc +++ b/src/ComputingTasks.cc @@ -4924,14 +4924,51 @@ VarExpectationModelStatement::VarExpectationModelStatement(string model_name_arg aux_model_name{move(aux_model_name_arg)}, horizon{move(horizon_arg)}, discount{discount_arg}, symbol_table{symbol_table_arg} { - auto vpc = expression->matchLinearCombinationOfVariables(); - for (const auto &it : vpc) +} + +void +VarExpectationModelStatement::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, ExprNode::subst_table_t &subst_table) +{ + vector neweqs; + expression = expression->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs); + if (neweqs.size() > 0) { - if (get<1>(it) != 0) - throw ExprNode::MatchFailureException{"lead/lags are not allowed"}; - if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous) - throw ExprNode::MatchFailureException{"Variable is not an endogenous"}; - vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it)); + cerr << "ERROR: the 'expression' option of var_expectation_model contains a variable with a unary operator that is not present in the VAR model" << endl; + exit(EXIT_FAILURE); + } +} + +void +VarExpectationModelStatement::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, ExprNode::subst_table_t &subst_table) +{ + vector neweqs; + expression = expression->substituteDiff(static_datatree, diff_table, subst_table, neweqs); + if (neweqs.size() > 0) + { + cerr << "ERROR: the 'expression' option of var_expectation_model contains a diff'd variable that is not present in the VAR model" << endl; + exit(EXIT_FAILURE); + } +} + +void +VarExpectationModelStatement::matchExpression() +{ + try + { + auto vpc = expression->matchLinearCombinationOfVariables(); + for (const auto &it : vpc) + { + if (get<1>(it) != 0) + throw ExprNode::MatchFailureException{"lead/lags are not allowed"}; + if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous) + throw ExprNode::MatchFailureException{"Variable is not an endogenous"}; + vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it)); + } + } + catch (ExprNode::MatchFailureException &e) + { + cerr << "ERROR: expression in var_expectation_model is not of the expected form: " << e.message << endl; + exit(EXIT_FAILURE); } } @@ -4942,6 +4979,12 @@ VarExpectationModelStatement::writeOutput(ostream &output, const string &basenam output << mstruct << ".auxiliary_model_name = '" << aux_model_name << "';" << endl << mstruct << ".horizon = " << horizon << ';' << endl; + if (!vars_params_constants.size()) + { + cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl; + exit(EXIT_FAILURE); + } + ostringstream vars_list, params_list, constants_list; for (auto it = vars_params_constants.begin(); it != vars_params_constants.end(); ++it) { diff --git a/src/ComputingTasks.hh b/src/ComputingTasks.hh index 2afcbf76..8b882ab1 100644 --- a/src/ComputingTasks.hh +++ b/src/ComputingTasks.hh @@ -1185,7 +1185,9 @@ class VarExpectationModelStatement : public Statement { public: const string model_name; - const expr_t expression; +private: + expr_t expression; +public: const string aux_model_name, horizon; const expr_t discount; const SymbolTable &symbol_table; @@ -1196,6 +1198,12 @@ private: public: VarExpectationModelStatement(string model_name_arg, expr_t expression_arg, string aux_model_name_arg, string horizon_arg, expr_t discount_arg, const SymbolTable &symbol_table_arg); + void substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, ExprNode::subst_table_t &subst_table); + void substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, ExprNode::subst_table_t &subst_table); + // Analyzes the linear combination contained in the 'expression' option + /* Must be called after substituteUnaryOpNodes() and substituteDiff() (in + that order) */ + void matchExpression(); void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override; void writeJsonOutput(ostream &output) const override; }; diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index ad9e6909..41398450 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -5802,27 +5802,25 @@ DynamicModel::findPacExpectationEquationNumbers(vector &eqnumbers) const } void -DynamicModel::substituteUnaryOps(StaticModel &static_model, bool nopreprocessoroutput) +DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, bool nopreprocessoroutput) { vector eqnumbers(equations.size()); iota(eqnumbers.begin(), eqnumbers.end(), 0); - substituteUnaryOps(static_model, eqnumbers, nopreprocessoroutput); + substituteUnaryOps(static_model, nodes, subst_table, eqnumbers, nopreprocessoroutput); } void -DynamicModel::substituteUnaryOps(StaticModel &static_model, set &var_model_eqtags, bool nopreprocessoroutput) +DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, set &var_model_eqtags, bool nopreprocessoroutput) { vector eqnumbers; getEquationNumbersFromTags(eqnumbers, var_model_eqtags); findPacExpectationEquationNumbers(eqnumbers); - substituteUnaryOps(static_model, eqnumbers, nopreprocessoroutput); + substituteUnaryOps(static_model, nodes, subst_table, eqnumbers, nopreprocessoroutput); } void -DynamicModel::substituteUnaryOps(StaticModel &static_model, vector &eqnumbers, bool nopreprocessoroutput) +DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, vector &eqnumbers, bool nopreprocessoroutput) { - diff_table_t nodes; - // Find matching unary ops that may be outside of diffs (i.e., those with different lags) set used_local_vars; for (int eqnumber : eqnumbers) @@ -5837,7 +5835,6 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, vector &eqnumbe equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes); // Substitute in model local variables - ExprNode::subst_table_t subst_table; vector neweqs; for (auto & it : local_variables_table) it.second = it.second->substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs); @@ -5862,14 +5859,13 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, vector &eqnumbe } void -DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput) +DynamicModel::substituteDiff(StaticModel &static_model, diff_table_t &diff_table, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput) { set used_local_vars; for (const auto & equation : equations) equation->collectVariables(SymbolType::modelLocalVariable, used_local_vars); // Only substitute diffs in model local variables that appear in VAR equations - diff_table_t diff_table; for (auto & it : local_variables_table) if (used_local_vars.find(it.first) != used_local_vars.end()) it.second->findDiffNodes(static_model, diff_table); diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index cf161507..2fd42318 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -437,16 +437,16 @@ public: void substituteAdl(); //! Creates aux vars for all unary operators - void substituteUnaryOps(StaticModel &static_model, bool nopreprocessoroutput); + void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, bool nopreprocessoroutput); //! Creates aux vars for certain unary operators: originally implemented for support of VARs - void substituteUnaryOps(StaticModel &static_model, set &eq_tags, bool nopreprocessoroutput); + void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, set &eq_tags, bool nopreprocessoroutput); //! Creates aux vars for certain unary operators: originally implemented for support of VARs - void substituteUnaryOps(StaticModel &static_model, vector &eqnumbers, bool nopreprocessoroutput); + void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, vector &eqnumbers, bool nopreprocessoroutput); //! Substitutes diff operator - void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput); + void substituteDiff(StaticModel &static_model, diff_table_t &diff_table, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput); //! Substitute VarExpectation operators void substituteVarExpectation(const map &subst_table); diff --git a/src/DynareBison.yy b/src/DynareBison.yy index fac0a1b1..e76c4c8b 100644 --- a/src/DynareBison.yy +++ b/src/DynareBison.yy @@ -421,8 +421,11 @@ var_expectation_model_options_list : var_expectation_model_option var_expectation_model_option : VARIABLE EQUAL symbol { driver.option_str("variable", $3); } - | EXPRESSION EQUAL expression - { driver.var_expectation_model_expression = $3; } + | EXPRESSION EQUAL { driver.begin_model(); } hand_side + { + driver.var_expectation_model_expression = $4; + driver.reset_data_tree(); + } | AUXILIARY_MODEL_NAME EQUAL symbol { driver.option_str("auxiliary_model_name", $3); } | HORIZON EQUAL INT_NUMBER diff --git a/src/DynareFlex.ll b/src/DynareFlex.ll index 186889af..c703738c 100644 --- a/src/DynareFlex.ll +++ b/src/DynareFlex.ll @@ -388,7 +388,7 @@ DATE -?[0-9]+([YyAa]|[Mm]([1-9]|1[0-2])|[Qq][1-4]|[Ww]([1-9]{1}|[1-4][0-9]|5[0-2 crossequations {return token::CROSSEQUATIONS;} covariance {return token::COVARIANCE;} adl {return token::ADL;} -diff {return token::DIFF;} +diff {return token::DIFF;} cross_restrictions {return token::CROSS_RESTRICTIONS;} contemp_reduced_form {return token::CONTEMP_REDUCED_FORM;} real_pseudo_forecast {return token::REAL_PSEUDO_FORECAST;} diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 3927c6d4..d3f281be 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -411,7 +411,7 @@ class ExprNode */ virtual expr_t decreaseLeadsLags(int n) const = 0; - //! Type for the substitution map used in the process of creating auxiliary vars for leads >= 2 + //! Type for the substitution map used in the process of creating auxiliary vars using subst_table_t = map; //! Type for the substitution map used in the process of substituting adl expressions diff --git a/src/ModFile.cc b/src/ModFile.cc index 63593f8e..43d89595 100644 --- a/src/ModFile.cc +++ b/src/ModFile.cc @@ -391,15 +391,18 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const for (auto & it1 : it.second) eqtags.insert(it1); + diff_table_t unary_ops_nodes; + ExprNode::subst_table_t unary_ops_subst_table; if (transform_unary_ops) - dynamic_model.substituteUnaryOps(diff_static_model, nopreprocessoroutput); + dynamic_model.substituteUnaryOps(diff_static_model, unary_ops_nodes, unary_ops_subst_table, nopreprocessoroutput); else // substitute only those unary ops that appear in auxiliary model equations - dynamic_model.substituteUnaryOps(diff_static_model, eqtags, nopreprocessoroutput); + dynamic_model.substituteUnaryOps(diff_static_model, unary_ops_nodes, unary_ops_subst_table, eqtags, nopreprocessoroutput); // Create auxiliary variable and equations for Diff operators that appear in VAR equations + diff_table_t diff_table; ExprNode::subst_table_t diff_subst_table; - dynamic_model.substituteDiff(diff_static_model, diff_subst_table, nopreprocessoroutput); + dynamic_model.substituteDiff(diff_static_model, diff_table, diff_subst_table, nopreprocessoroutput); // Fill Trend Component Model Table dynamic_model.fillTrendComponentModelTable(); @@ -544,6 +547,12 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const exit(EXIT_FAILURE); } + /* Substitute unary and diff operators in the 'expression' option, then + match the linear combination in the expression option */ + vems->substituteUnaryOpNodes(diff_static_model, unary_ops_nodes, unary_ops_subst_table); + vems->substituteDiff(diff_static_model, diff_table, diff_subst_table); + vems->matchExpression(); + /* Create auxiliary parameters and the expression to be substituted into the var_expectations statement */ auto subst_expr = dynamic_model.Zero; diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc index 482e5f90..7a1122b0 100644 --- a/src/ParsingDriver.cc +++ b/src/ParsingDriver.cc @@ -3383,16 +3383,9 @@ ParsingDriver::var_expectation_model() else var_expectation_model_discount = data_tree->One; - try - { - mod_file->addStatement(make_unique(model_name, var_expectation_model_expression, - var_model_name, horizon, - var_expectation_model_discount, mod_file->symbol_table)); - } - catch (ExprNode::MatchFailureException &e) - { - error("expression in var_expectation_model is not of the expected form: " + e.message); - } + mod_file->addStatement(make_unique(model_name, var_expectation_model_expression, + var_model_name, horizon, + var_expectation_model_discount, mod_file->symbol_table)); options_list.clear(); var_expectation_model_discount = nullptr;