diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 48654d7e..1c11d7ac 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -296,14 +296,15 @@ ExprNode::fillErrorCorrectionRow(int eqn, vector> terms; decomposeAdditiveTerms(terms, 1); - for (const auto &it : terms) + for (const auto &[term, sign] : terms) { - pair>> m; + int speed_of_adjustment_param; + vector> error_linear_combination; try { - m = it.first->matchParamTimesLinearCombinationOfVariables(); - for (auto &t : m.second) - get<3>(t) *= it.second; // Update sign of constants + tie(speed_of_adjustment_param, error_linear_combination) = term->matchParamTimesLinearCombinationOfVariables(); + for (auto &[var_id, lag, param_id, constant] : error_linear_combination) + constant *= sign; // Update sign of constants } catch (MatchFailureException &e) { @@ -315,17 +316,17 @@ ExprNode::fillErrorCorrectionRow(int eqn, /* Verify that all variables belong to the error-correction term. FIXME: same remark as above about skipping terms. */ bool not_ec = false; - for (const auto &t : m.second) + for (const auto &[var_id, lag, param_id, constant] : error_linear_combination) { - auto [vid, vlag] = datatree.symbol_table.unrollDiffLeadLagChain(get<0>(t), get<1>(t)); - not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), vid) == target_lhs.end() - && find(nontarget_lhs.begin(), nontarget_lhs.end(), vid) == nontarget_lhs.end()); + auto [orig_var_id, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag); + not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), orig_var_id) == target_lhs.end() + && find(nontarget_lhs.begin(), nontarget_lhs.end(), orig_var_id) == nontarget_lhs.end()); } if (not_ec) continue; // Now fill the matrices - for (auto [var_id, lag, param_id, constant] : m.second) + for (auto [var_id, lag, param_id, constant] : error_linear_combination) { auto [orig_vid, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag); if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end()) @@ -353,13 +354,14 @@ ExprNode::fillErrorCorrectionRow(int eqn, << "symb_id encountered more than once in equation" << endl; exit(EXIT_FAILURE); } - A0[{eqn, colidx}] = datatree.AddVariable(m.first); + A0[{eqn, colidx}] = datatree.AddVariable(speed_of_adjustment_param); } else { // This is a target, so fill A0star int colidx = static_cast(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid))); - expr_t e = datatree.AddTimes(datatree.AddVariable(m.first), datatree.AddPossiblyNegativeConstant(-constant)); + expr_t e = datatree.AddTimes(datatree.AddVariable(speed_of_adjustment_param), + datatree.AddPossiblyNegativeConstant(-constant)); if (param_id != -1) e = datatree.AddTimes(e, datatree.AddVariable(param_id)); if (auto coor = make_pair(eqn, colidx); A0star.find(coor) == A0star.end()) @@ -5293,21 +5295,23 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id, exit(EXIT_FAILURE); } - for (const auto &it : terms) + for (const auto &[term, sign] : terms) { - if (dynamic_cast(it.first)) + if (dynamic_cast(term)) continue; - pair>> m; + int pid; + vector> linear_combination; try { - m = {-1, {it.first->matchVariableTimesConstantTimesParam()}}; + pid = -1; + linear_combination = { term->matchVariableTimesConstantTimesParam() }; } catch (MatchFailureException &e) { try { - m = it.first->matchParamTimesLinearCombinationOfVariables(); + tie(pid, linear_combination) = term->matchParamTimesLinearCombinationOfVariables(); } catch (MatchFailureException &e) { @@ -5316,11 +5320,10 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id, } } - for (auto &t : m.second) - get<3>(t) *= it.second; // Update sign of constants + for (auto &[vid, vlag, pidtmp, constant] : linear_combination) + constant *= sign; // Update sign of constants - int pid = get<0>(m); - for (auto [vid, vlag, pidtmp, constant] : m.second) + for (auto [vid, vlag, pidtmp, constant] : linear_combination) { if (pid == -1) pid = pidtmp; diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 6b9c0051..b484318b 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -644,6 +644,13 @@ public: */ vector> matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term = true) const; + /* Matches a parameter, times a linear combination of variables (endo or + exo), where scalars can be constant*parameters. + The first output argument is the symbol ID of the parameter. + The second output argument is the linear combination, in the same format + as the output of matchLinearCombinationOfVariables(). */ + pair>> matchParamTimesLinearCombinationOfVariables() const; + /* Matches a linear combination of endogenous, where scalars can be any constant expression (i.e. containing no endogenous, no exogenous and no exogenous deterministic). The linear combination can contain constant @@ -653,8 +660,6 @@ public: – the sum of all constant (intercept) terms */ pair>, expr_t> matchLinearCombinationOfEndogenousWithConstant() const; - pair>> matchParamTimesLinearCombinationOfVariables() const; - /* Matches an expression of the form parameter*(var1-endo2). endo2 must correspond to symb_id. var1 must be an endogenous or an exogenous; it must be of the form X(-1) or log(X(-1)) or log(X)(-1) (unary ops aux var),