Comment improvement + cosmetics

fix-tolerance-parameters
Sébastien Villemot 2022-01-28 17:24:48 +01:00
parent 01bea3f5e7
commit adab6c7f93
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
2 changed files with 31 additions and 23 deletions

View File

@ -296,14 +296,15 @@ ExprNode::fillErrorCorrectionRow(int eqn,
vector<pair<expr_t, int>> terms;
decomposeAdditiveTerms(terms, 1);
for (const auto &it : terms)
for (const auto &[term, sign] : terms)
{
pair<int, vector<tuple<int, int, int, double>>> m;
int speed_of_adjustment_param;
vector<tuple<int, int, int, double>> error_linear_combination;
try
{
m = it.first->matchParamTimesLinearCombinationOfVariables();
for (auto &t : m.second)
get<3>(t) *= it.second; // Update sign of constants
tie(speed_of_adjustment_param, error_linear_combination) = term->matchParamTimesLinearCombinationOfVariables();
for (auto &[var_id, lag, param_id, constant] : error_linear_combination)
constant *= sign; // Update sign of constants
}
catch (MatchFailureException &e)
{
@ -315,17 +316,17 @@ ExprNode::fillErrorCorrectionRow(int eqn,
/* Verify that all variables belong to the error-correction term.
FIXME: same remark as above about skipping terms. */
bool not_ec = false;
for (const auto &t : m.second)
for (const auto &[var_id, lag, param_id, constant] : error_linear_combination)
{
auto [vid, vlag] = datatree.symbol_table.unrollDiffLeadLagChain(get<0>(t), get<1>(t));
not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), vid) == target_lhs.end()
&& find(nontarget_lhs.begin(), nontarget_lhs.end(), vid) == nontarget_lhs.end());
auto [orig_var_id, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag);
not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), orig_var_id) == target_lhs.end()
&& find(nontarget_lhs.begin(), nontarget_lhs.end(), orig_var_id) == nontarget_lhs.end());
}
if (not_ec)
continue;
// Now fill the matrices
for (auto [var_id, lag, param_id, constant] : m.second)
for (auto [var_id, lag, param_id, constant] : error_linear_combination)
{
auto [orig_vid, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag);
if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end())
@ -353,13 +354,14 @@ ExprNode::fillErrorCorrectionRow(int eqn,
<< "symb_id encountered more than once in equation" << endl;
exit(EXIT_FAILURE);
}
A0[{eqn, colidx}] = datatree.AddVariable(m.first);
A0[{eqn, colidx}] = datatree.AddVariable(speed_of_adjustment_param);
}
else
{
// This is a target, so fill A0star
int colidx = static_cast<int>(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid)));
expr_t e = datatree.AddTimes(datatree.AddVariable(m.first), datatree.AddPossiblyNegativeConstant(-constant));
expr_t e = datatree.AddTimes(datatree.AddVariable(speed_of_adjustment_param),
datatree.AddPossiblyNegativeConstant(-constant));
if (param_id != -1)
e = datatree.AddTimes(e, datatree.AddVariable(param_id));
if (auto coor = make_pair(eqn, colidx); A0star.find(coor) == A0star.end())
@ -5293,21 +5295,23 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
exit(EXIT_FAILURE);
}
for (const auto &it : terms)
for (const auto &[term, sign] : terms)
{
if (dynamic_cast<PacExpectationNode *>(it.first))
if (dynamic_cast<PacExpectationNode *>(term))
continue;
pair<int, vector<tuple<int, int, int, double>>> m;
int pid;
vector<tuple<int, int, int, double>> linear_combination;
try
{
m = {-1, {it.first->matchVariableTimesConstantTimesParam()}};
pid = -1;
linear_combination = { term->matchVariableTimesConstantTimesParam() };
}
catch (MatchFailureException &e)
{
try
{
m = it.first->matchParamTimesLinearCombinationOfVariables();
tie(pid, linear_combination) = term->matchParamTimesLinearCombinationOfVariables();
}
catch (MatchFailureException &e)
{
@ -5316,11 +5320,10 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
}
}
for (auto &t : m.second)
get<3>(t) *= it.second; // Update sign of constants
for (auto &[vid, vlag, pidtmp, constant] : linear_combination)
constant *= sign; // Update sign of constants
int pid = get<0>(m);
for (auto [vid, vlag, pidtmp, constant] : m.second)
for (auto [vid, vlag, pidtmp, constant] : linear_combination)
{
if (pid == -1)
pid = pidtmp;

View File

@ -644,6 +644,13 @@ public:
*/
vector<tuple<int, int, int, double>> matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term = true) const;
/* Matches a parameter, times a linear combination of variables (endo or
exo), where scalars can be constant*parameters.
The first output argument is the symbol ID of the parameter.
The second output argument is the linear combination, in the same format
as the output of matchLinearCombinationOfVariables(). */
pair<int, vector<tuple<int, int, int, double>>> matchParamTimesLinearCombinationOfVariables() const;
/* Matches a linear combination of endogenous, where scalars can be any
constant expression (i.e. containing no endogenous, no exogenous and no
exogenous deterministic). The linear combination can contain constant
@ -653,8 +660,6 @@ public:
the sum of all constant (intercept) terms */
pair<vector<pair<int, expr_t>>, expr_t> matchLinearCombinationOfEndogenousWithConstant() const;
pair<int, vector<tuple<int, int, int, double>>> matchParamTimesLinearCombinationOfVariables() const;
/* Matches an expression of the form parameter*(var1-endo2).
endo2 must correspond to symb_id. var1 must be an endogenous or an
exogenous; it must be of the form X(-1) or log(X(-1)) or log(X)(-1) (unary ops aux var),