From 8dad93964da8ed27d165073a39361901ed670e1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= Date: Thu, 21 Mar 2019 18:13:34 +0100 Subject: [PATCH] Allow linear combination of targets in error correction term of trend component models --- src/ExprNode.cc | 325 ++++++++++++++++++------------------------------ src/ExprNode.hh | 20 +-- 2 files changed, 128 insertions(+), 217 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 7bbcc901..3fdf8326 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -314,6 +314,104 @@ ExprNode::getEndosAndMaxLags(map &model_endos_and_lags) const { } +void +ExprNode::fillErrorCorrectionRow(int eqn, + const vector &nontarget_lhs, + const vector &target_lhs, + map, expr_t> &A0, + map, expr_t> &A0star) const +{ + vector> terms; + decomposeAdditiveTerms(terms, 1); + + for (const auto &it : terms) + { + pair>> m; + try + { + m = it.first->matchParamTimesLinearCombinationOfVariables(); + for (auto &t : m.second) + get<3>(t) *= it.second; // Update sign of constants + } + catch (MatchFailureException &e) + { + /* FIXME: we should not just skip them, but rather verify that they are + autoregressive terms or residuals (probably by merging the two "fill" procedures) */ + continue; + } + + // Helper function + auto one_step_orig = [this](int symb_id) { + return datatree.symbol_table.isAuxiliaryVariable(symb_id) ? + datatree.symbol_table.getOrigSymbIdForDiffAuxVar(symb_id) : symb_id; + }; + + /* Verify that all variables belong to the error-correction term. + FIXME: same remark as above about skipping terms. */ + bool not_ec = false; + for (const auto &t : m.second) + { + int vid = one_step_orig(get<0>(t)); + not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), vid) == target_lhs.end() + && find(nontarget_lhs.begin(), nontarget_lhs.end(), vid) == nontarget_lhs.end()); + } + if (not_ec) + continue; + + // Now fill the matrices + for (const auto &t : m.second) + { + int var_id, lag, param_id; + double constant; + tie(var_id, lag, param_id, constant) = t; + /* + if (lag != -1) + { + cerr << "ERROR in trend component model: variables should appear with a lag of 1 in error correction term" << endl; + exit(EXIT_FAILURE); + } + */ + int orig_vid = one_step_orig(var_id); + int orig_lag = datatree.symbol_table.isAuxiliaryVariable(var_id) ? -datatree.symbol_table.getOrigLeadLagForDiffAuxVar(var_id) : lag; + if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end()) + { + // This an LHS variable, so fill A0 + if (constant != 1) + { + cerr << "ERROR in trend component model: LHS variable should not appear with a multiplicative constant in error correction term" << endl; + exit(EXIT_FAILURE); + } + if (param_id != -1) + { + cerr << "ERROR in trend component model: spurious parameter in error correction term" << endl; + exit(EXIT_FAILURE); + } + int colidx = static_cast(distance(nontarget_lhs.begin(), find(nontarget_lhs.begin(), nontarget_lhs.end(), orig_vid))); + if (A0.find({eqn, -orig_lag, colidx}) != A0.end()) + { + cerr << "ExprNode::fillErrorCorrection: Error filling A0 matrix: " + << "lag/symb_id encountered more than once in equation" << endl; + exit(EXIT_FAILURE); + } + A0[{eqn, -orig_lag, colidx}] = datatree.AddVariable(m.first); + } + else + { + // This is a target, so fill A0star + int colidx = static_cast(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid))); + expr_t e = datatree.AddTimes(datatree.AddVariable(m.first), datatree.AddPossiblyNegativeConstant(-constant)); + if (param_id != -1) + e = datatree.AddTimes(e, datatree.AddVariable(param_id)); + auto coor = make_tuple(eqn, -orig_lag, colidx); + if (A0star.find(coor) == A0star.end()) + A0star[coor] = e; + else + A0star[coor] = datatree.AddPlus(e, A0star[coor]); + } + } + } +} + NumConstNode::NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg) : ExprNode{datatree_arg, idx_arg}, id{id_arg} @@ -686,11 +784,6 @@ NumConstNode::substituteStaticAuxiliaryVariable() const return const_cast(this); } -void -NumConstNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const -{ -} - void NumConstNode::findConstantEquations(map &table) const { @@ -1974,11 +2067,6 @@ VariableNode::getEndosAndMaxLags(map &model_endos_and_lags) const model_endos_and_lags[varname] = lag; } -void -VariableNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const -{ -} - void VariableNode::findConstantEquations(map &table) const { @@ -3784,12 +3872,6 @@ UnaryOpNode::substituteStaticAuxiliaryVariable() const return buildSimilarUnaryOpNode(argsubst, datatree); } -void -UnaryOpNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const -{ - arg->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, A0, A0star); -} - void UnaryOpNode::findConstantEquations(map &table) const { @@ -5835,153 +5917,6 @@ BinaryOpNode::fillAutoregressiveRow(int eqn, const vector &lhs, map &nontarget_lhs, - const vector &target_lhs, - map, expr_t> &A0, - map, expr_t> &A0star) const -{ - if (op_code != BinaryOpcode::times) - return; - - set> endogs, tmp; - arg1->collectDynamicVariables(SymbolType::endogenous, tmp); - arg1->collectDynamicVariables(SymbolType::exogenous, tmp); - if (tmp.size() != 0) - return; - - arg1->collectDynamicVariables(SymbolType::parameter, tmp); - if (tmp.size() != 1) - return; - - auto *vn = dynamic_cast(arg2); - auto *bopn = dynamic_cast(arg2); - if ((bopn == nullptr || bopn->op_code != BinaryOpcode::minus) - && vn == nullptr) - return; - - if (bopn != nullptr) - { - arg2->collectDynamicVariables(SymbolType::endogenous, endogs); - if (endogs.size() != 2) - return; - - arg2->collectDynamicVariables(SymbolType::exogenous, endogs); - arg2->collectDynamicVariables(SymbolType::parameter, endogs); - if (endogs.size() != 2) - { - cerr << "ERROR in model; expecting param*endog or param*(endog-endog)" << endl; - exit(EXIT_FAILURE); - } - - auto *vn1 = dynamic_cast(bopn->arg1); - auto *vn2 = dynamic_cast(bopn->arg2); - if (vn1 == nullptr || vn2 == nullptr) - { - cerr << "ERROR in model; expecting param*endog or param*(endog-endog)" << endl; - exit(EXIT_FAILURE); - } - - int endog1 = vn1->symb_id; - int endog2 = vn2->symb_id; - int orig_endog1 = endog1; - int orig_endog2 = endog2; - - bool isauxvar1 = datatree.symbol_table.isAuxiliaryVariable(endog1); - endog1 = isauxvar1 ? - datatree.symbol_table.getOrigSymbIdForDiffAuxVar(endog1) : endog1; - - bool isauxvar2 = datatree.symbol_table.isAuxiliaryVariable(endog2); - endog2 = isauxvar2 ? - datatree.symbol_table.getOrigSymbIdForDiffAuxVar(endog2) : endog2; - - int A0_max_lag = vn1->lag; - int A0star_max_lag = vn2->lag; - int A0_colidx = -1; - int A0star_colidx = -1; - if (find(nontarget_lhs.begin(), nontarget_lhs.end(), endog1) != nontarget_lhs.end() - && find(target_lhs.begin(), target_lhs.end(), endog2) != target_lhs.end()) - { - A0_colidx = (int) distance(nontarget_lhs.begin(), find(nontarget_lhs.begin(), nontarget_lhs.end(), endog1)); - int tmp_lag = vn1->lag; - if (isauxvar1) - tmp_lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(orig_endog1); - if (tmp_lag < A0_max_lag) - A0_max_lag = tmp_lag; - - A0star_colidx = (int) distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), endog2)); - tmp_lag = vn2->lag; - if (isauxvar2) - tmp_lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(orig_endog2); - if (tmp_lag < A0star_max_lag) - A0star_max_lag = tmp_lag; - } - else - return; - - if (A0.find({eqn, -A0_max_lag, A0_colidx}) != A0.end()) - { - cerr << "BinaryOpNode::fillErrorCorrectionRowHelper: Error filling A0 matrix: " - << "lag/symb_id encountered more than once in equtaion" << endl; - exit(EXIT_FAILURE); - } - - if (A0star.find({eqn, -A0star_max_lag, A0star_colidx}) != A0star.end()) - { - cerr << "BinaryOpNode::fillErrorCorrectionRowHelper: Error filling A0star matrix: " - << "lag/symb_id encountered more than once in equtaion" << endl; - exit(EXIT_FAILURE); - } - A0[{eqn, -A0_max_lag, A0_colidx}] = arg1; - A0star[{eqn, -A0star_max_lag, A0star_colidx}] = arg1; - } - else - { - arg2->collectDynamicVariables(SymbolType::endogenous, endogs); - if (endogs.size() != 1) - return; - - arg2->collectDynamicVariables(SymbolType::exogenous, endogs); - arg2->collectDynamicVariables(SymbolType::parameter, endogs); - if (endogs.size() != 1) - { - cerr << "ERROR in model; expecting param*endog or param*(endog-endog)" << endl; - exit(EXIT_FAILURE); - } - - int endog1 = vn->symb_id; - int orig_endog1 = endog1; - - bool isauxvar1 = datatree.symbol_table.isAuxiliaryVariable(endog1); - endog1 = isauxvar1 ? - datatree.symbol_table.getOrigSymbIdForDiffAuxVar(endog1) : endog1; - - int max_lag = vn->lag; - int colidx = -1; - if (find(nontarget_lhs.begin(), nontarget_lhs.end(), endog1) != nontarget_lhs.end()) - { - colidx = (int) distance(nontarget_lhs.begin(), find(nontarget_lhs.begin(), nontarget_lhs.end(), endog1)); - int tmp_lag = vn->lag; - if (isauxvar1) - tmp_lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(orig_endog1); - if (tmp_lag < max_lag) - max_lag = tmp_lag; - } - else - return; - - if (A0.find({eqn, -max_lag, colidx}) != A0.end()) - { - cerr << "BinaryOpNode::fillErrorCorrectionRowHelper: Error filling A0 matrix: " - << "lag/symb_id encountered more than once in equtaion" << endl; - exit(EXIT_FAILURE); - } - A0[{eqn, -max_lag, colidx}] = arg1; - } -} - void BinaryOpNode::findConstantEquations(map &table) const { @@ -6009,17 +5944,6 @@ BinaryOpNode::replaceVarsInEquation(map &table) return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree); } -void -BinaryOpNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const -{ - size_t A0size = A0.size(); - fillErrorCorrectionRowHelper(arg1, arg2, eqn, nontrend_lhs, trend_lhs, A0, A0star); - if (A0size == A0.size()) - fillErrorCorrectionRowHelper(arg2, arg1, eqn, nontrend_lhs, trend_lhs, A0, A0star); - arg1->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, A0, A0star); - arg2->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, A0, A0star); -} - bool BinaryOpNode::isVarModelReferenced(const string &model_info_name) const { @@ -6949,14 +6873,6 @@ TrinaryOpNode::substituteStaticAuxiliaryVariable() const return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree); } -void -TrinaryOpNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const -{ - arg1->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, A0, A0star); - arg2->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, A0, A0star); - arg3->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, A0, A0star); -} - void TrinaryOpNode::findConstantEquations(map &table) const { @@ -7577,13 +7493,6 @@ AbstractExternalFunctionNode::substituteStaticAuxiliaryVariable() const return buildSimilarExternalFunctionNode(arguments_subst, datatree); } -void -AbstractExternalFunctionNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const -{ - cerr << "External functions not supported in Trend Component Models" << endl; - exit(EXIT_FAILURE); -} - void AbstractExternalFunctionNode::findConstantEquations(map &table) const { @@ -9088,13 +8997,6 @@ VarExpectationNode::substituteStaticAuxiliaryVariable() const return const_cast(this); } -void -VarExpectationNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const -{ - cerr << "Var Expectation not supported in Trend Component Models" << endl; - exit(EXIT_FAILURE); -} - void VarExpectationNode::findConstantEquations(map &table) const { @@ -9499,13 +9401,6 @@ PacExpectationNode::substituteStaticAuxiliaryVariable() const return const_cast(this); } -void -PacExpectationNode::fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const -{ - cerr << "Pac Expectation not supported in Trend Component Models" << endl; - exit(EXIT_FAILURE); -} - void PacExpectationNode::findConstantEquations(map &table) const { @@ -9676,3 +9571,27 @@ ExprNode::matchLinearCombinationOfVariables() const } return result; } + +pair>> +ExprNode::matchParamTimesLinearCombinationOfVariables() const +{ + auto bopn = dynamic_cast(this); + if (!bopn || bopn->op_code != BinaryOpcode::times) + throw MatchFailureException{"Not a multiplicative expression"}; + + expr_t param = bopn->arg1, lincomb = bopn->arg2; + + auto is_param = [](expr_t e) { + auto vn = dynamic_cast(e); + return vn && vn->get_type() == SymbolType::parameter; + }; + + if (!is_param(param)) + { + swap(param, lincomb); + if (!is_param(param)) + throw MatchFailureException{"No parameter on either side of the multiplication"}; + } + + return { dynamic_cast(param)->symb_id, lincomb->matchLinearCombinationOfVariables() }; +} diff --git a/src/ExprNode.hh b/src/ExprNode.hh index af4ed5e6..67c92a37 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -589,14 +589,17 @@ class ExprNode parameter in a term, param_id == -1. Can throw a MatchFailureException. */ vector> matchLinearCombinationOfVariables() const; + + pair>> matchParamTimesLinearCombinationOfVariables() const; + //! Returns true if expression is of the form: //! param * (endog op endog op ...) + param * (endog op endog op ...) + ... virtual bool isParamTimesEndogExpr() const = 0; //! Fills the EC matrix structure - virtual void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, - map, expr_t> &A0, - map, expr_t> &A0star) const = 0; + void fillErrorCorrectionRow(int eqn, const vector &nontarget_lhs, const vector &target_lhs, + map, expr_t> &A0, + map, expr_t> &A0star) const; //! Finds equations where a variable is equal to a constant virtual void findConstantEquations(map &table) const = 0; @@ -715,7 +718,6 @@ public: expr_t clone(DataTree &datatree) const override; expr_t removeTrendLeadLag(map trend_symbols_map) const override; bool isInStaticForm() const override; - void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const override; void findConstantEquations(map &table) const override; expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; @@ -799,7 +801,6 @@ public: expr_t clone(DataTree &datatree) const override; expr_t removeTrendLeadLag(map trend_symbols_map) const override; bool isInStaticForm() const override; - void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const override; void findConstantEquations(map &table) const override; expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; @@ -911,7 +912,6 @@ public: expr_t clone(DataTree &datatree) const override; expr_t removeTrendLeadLag(map trend_symbols_map) const override; bool isInStaticForm() const override; - void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const override; void findConstantEquations(map &table) const override; expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; @@ -1033,10 +1033,6 @@ public: expr_t getNonZeroPartofEquation() const; bool isInStaticForm() const override; void fillAutoregressiveRow(int eqn, const vector &lhs, map, expr_t> &AR) const; - void fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2, - int eqn, const vector &nontrend_lhs, const vector &trend_lhs, - map, expr_t> &A0, map, expr_t> &A0star) const; - void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const override; void findConstantEquations(map &table) const override; expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; @@ -1158,7 +1154,6 @@ public: expr_t clone(DataTree &datatree) const override; expr_t removeTrendLeadLag(map trend_symbols_map) const override; bool isInStaticForm() const override; - void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const override; void findConstantEquations(map &table) const override; expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; @@ -1279,7 +1274,6 @@ public: expr_t clone(DataTree &datatree) const override = 0; expr_t removeTrendLeadLag(map trend_symbols_map) const override; bool isInStaticForm() const override; - void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const override; void findConstantEquations(map &table) const override; expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; @@ -1488,7 +1482,6 @@ public: expr_t detrend(int symb_id, bool log_trend, expr_t trend) const override; expr_t removeTrendLeadLag(map trend_symbols_map) const override; bool isInStaticForm() const override; - void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const override; void findConstantEquations(map &table) const override; expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; @@ -1570,7 +1563,6 @@ public: expr_t detrend(int symb_id, bool log_trend, expr_t trend) const override; expr_t removeTrendLeadLag(map trend_symbols_map) const override; bool isInStaticForm() const override; - void fillErrorCorrectionRow(int eqn, const vector &nontrend_lhs, const vector &trend_lhs, map, expr_t> &A0, map, expr_t> &A0star) const override; void findConstantEquations(map &table) const override; expr_t replaceVarsInEquation(map &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override;