C++17 modernization: use std::optional in expression matching functions

fix-tolerance-parameters
Sébastien Villemot 2022-05-16 17:42:24 +02:00
parent 3496d77eb4
commit 0eb11d3323
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
5 changed files with 116 additions and 90 deletions

View File

@ -3814,8 +3814,8 @@ DynamicModel::analyzePacEquationStructure(const string &name, map<string, string
= arg2->getPacOptimizingShareAndExprNodes(lhs_symb_id, lhs_orig_symb_id); = arg2->getPacOptimizingShareAndExprNodes(lhs_symb_id, lhs_orig_symb_id);
pair<int, vector<tuple<int, bool, int>>> ec_params_and_vars; pair<int, vector<tuple<int, bool, int>>> ec_params_and_vars;
vector<tuple<int, int, int>> ar_params_and_vars; vector<tuple<optional<int>, optional<int>, int>> ar_params_and_vars;
vector<tuple<int, int, int, double>> non_optim_vars_params_and_constants, optim_additive_vars_params_and_constants, additive_vars_params_and_constants; vector<tuple<int, int, optional<int>, double>> non_optim_vars_params_and_constants, optim_additive_vars_params_and_constants, additive_vars_params_and_constants;
if (!optim_part) if (!optim_part)
{ {
auto bopn = dynamic_cast<BinaryOpNode *>(equation->arg2); auto bopn = dynamic_cast<BinaryOpNode *>(equation->arg2);

View File

@ -299,7 +299,7 @@ ExprNode::fillErrorCorrectionRow(int eqn,
for (const auto &[term, sign] : terms) for (const auto &[term, sign] : terms)
{ {
int speed_of_adjustment_param; int speed_of_adjustment_param;
vector<tuple<int, int, int, double>> error_linear_combination; vector<tuple<int, int, optional<int>, double>> error_linear_combination;
try try
{ {
tie(speed_of_adjustment_param, error_linear_combination) = term->matchParamTimesLinearCombinationOfVariables(); tie(speed_of_adjustment_param, error_linear_combination) = term->matchParamTimesLinearCombinationOfVariables();
@ -326,7 +326,7 @@ ExprNode::fillErrorCorrectionRow(int eqn,
continue; continue;
// Now fill the matrices // Now fill the matrices
for (auto [var_id, lag, param_id, constant] : error_linear_combination) for (const auto &[var_id, lag, param_id, constant] : error_linear_combination)
{ {
auto [orig_vid, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag); auto [orig_vid, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag);
if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end()) if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end())
@ -342,7 +342,7 @@ ExprNode::fillErrorCorrectionRow(int eqn,
cerr << "ERROR in trend component model: LHS variable should not appear with a multiplicative constant in error correction term" << endl; cerr << "ERROR in trend component model: LHS variable should not appear with a multiplicative constant in error correction term" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
if (param_id != -1) if (*param_id)
{ {
cerr << "ERROR in trend component model: spurious parameter in error correction term" << endl; cerr << "ERROR in trend component model: spurious parameter in error correction term" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -362,8 +362,8 @@ ExprNode::fillErrorCorrectionRow(int eqn,
int colidx = static_cast<int>(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid))); int colidx = static_cast<int>(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid)));
expr_t e = datatree.AddTimes(datatree.AddVariable(speed_of_adjustment_param), expr_t e = datatree.AddTimes(datatree.AddVariable(speed_of_adjustment_param),
datatree.AddPossiblyNegativeConstant(-constant)); datatree.AddPossiblyNegativeConstant(-constant));
if (param_id != -1) if (param_id)
e = datatree.AddTimes(e, datatree.AddVariable(param_id)); e = datatree.AddTimes(e, datatree.AddVariable(*param_id));
if (auto coor = make_pair(eqn, colidx); A0star.contains(coor)) if (auto coor = make_pair(eqn, colidx); A0star.contains(coor))
A0star[coor] = datatree.AddPlus(e, A0star[coor]); A0star[coor] = datatree.AddPlus(e, A0star[coor]);
else else
@ -5350,8 +5350,8 @@ BinaryOpNode::findTargetVariable(int lhs_symb_id) const
void void
BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id, BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
pair<int, vector<tuple<int, bool, int>>> &ec_params_and_vars, pair<int, vector<tuple<int, bool, int>>> &ec_params_and_vars,
vector<tuple<int, int, int>> &ar_params_and_vars, vector<tuple<optional<int>, optional<int>, int>> &ar_params_and_vars,
vector<tuple<int, int, int, double>> &additive_vars_params_and_constants) const vector<tuple<int, int, optional<int>, double>> &additive_vars_params_and_constants) const
{ {
ec_params_and_vars.first = -1; ec_params_and_vars.first = -1;
@ -5383,12 +5383,12 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
if (dynamic_cast<PacExpectationNode *>(term)) if (dynamic_cast<PacExpectationNode *>(term))
continue; continue;
int pid; optional<int> pid;
vector<tuple<int, int, int, double>> linear_combination; vector<tuple<int, int, optional<int>, double>> linear_combination;
try try
{ {
pid = -1; auto [vid, lag, pid, constant] = term->matchVariableTimesConstantTimesParam(true);
linear_combination = { term->matchVariableTimesConstantTimesParam() }; linear_combination.emplace_back(vid.value(), lag, move(pid), constant);
} }
catch (MatchFailureException &e) catch (MatchFailureException &e)
{ {
@ -5406,11 +5406,11 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
for (auto &[vid, vlag, pidtmp, constant] : linear_combination) for (auto &[vid, vlag, pidtmp, constant] : linear_combination)
constant *= sign; // Update sign of constants constant *= sign; // Update sign of constants
for (auto [vid, vlag, pidtmp, constant] : linear_combination) for (const auto &[vid, vlag, pidtmp, constant] : linear_combination)
{ {
if (pid == -1) if (!pid)
pid = pidtmp; pid = pidtmp;
else if (pidtmp >= 0) else if (*pidtmp)
{ {
cerr << "unexpected parameter found in PAC equation" << endl; cerr << "unexpected parameter found in PAC equation" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -5420,13 +5420,13 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
vidorig == lhs_symb_id) vidorig == lhs_symb_id)
{ {
// This is an autoregressive term // This is an autoregressive term
if (constant != 1 || pid == -1 || !datatree.symbol_table.isDiffAuxiliaryVariable(vid)) if (constant != 1 || !pid || !datatree.symbol_table.isDiffAuxiliaryVariable(vid))
{ {
cerr << "BinaryOpNode::getPacAREC: autoregressive terms must be of the form 'parameter*diff_lagged_variable" << endl; cerr << "BinaryOpNode::getPacAREC: autoregressive terms must be of the form 'parameter*diff_lagged_variable" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
if (static_cast<int>(ar_params_and_vars.size()) < -vlagorig) if (static_cast<int>(ar_params_and_vars.size()) < -vlagorig)
ar_params_and_vars.resize(-vlagorig, { -1, -1, 0 }); ar_params_and_vars.resize(-vlagorig, { nullopt, nullopt, 0 });
ar_params_and_vars[-vlagorig-1] = { pid, vid, vlag }; ar_params_and_vars[-vlagorig-1] = { pid, vid, vlag };
} }
else else
@ -5604,11 +5604,12 @@ BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<i
decomposeAdditiveTerms(terms, 1); decomposeAdditiveTerms(terms, 1);
for (const auto &it : terms) for (const auto &it : terms)
{ {
int vid, lag, param_id; optional<int> vid, param_id;
int lag;
double constant; double constant;
try try
{ {
tie(vid, lag, param_id, constant) = it.first->matchVariableTimesConstantTimesParam(); tie(vid, lag, param_id, constant) = it.first->matchVariableTimesConstantTimesParam(true);
constant *= it.second; constant *= it.second;
} }
catch (MatchFailureException &e) catch (MatchFailureException &e)
@ -5616,23 +5617,23 @@ BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<i
continue; continue;
} }
tie(vid, lag) = datatree.symbol_table.unrollDiffLeadLagChain(vid, lag); tie(vid, lag) = datatree.symbol_table.unrollDiffLeadLagChain(*vid, lag);
if (find(lhs.begin(), lhs.end(), vid) == lhs.end()) if (find(lhs.begin(), lhs.end(), *vid) == lhs.end())
continue; continue;
if (AR.contains({eqn, -lag, vid})) if (AR.contains({eqn, -lag, *vid}))
{ {
cerr << "BinaryOpNode::fillAutoregressiveRow: Error filling AR matrix: " cerr << "BinaryOpNode::fillAutoregressiveRow: Error filling AR matrix: "
<< "lag/symb_id encountered more than once in equation" << endl; << "lag/symb_id encountered more than once in equation" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
if (constant != 1 || param_id == -1) if (constant != 1 || !param_id)
{ {
cerr << "BinaryOpNode::fillAutoregressiveRow: autoregressive terms must be of the form 'parameter*lagged_variable" << endl; cerr << "BinaryOpNode::fillAutoregressiveRow: autoregressive terms must be of the form 'parameter*lagged_variable" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
AR[{eqn, -lag, vid}] = datatree.AddVariable(param_id); AR[{eqn, -lag, *vid}] = datatree.AddVariable(*param_id);
} }
} }
@ -8811,25 +8812,26 @@ BinaryOpNode::decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors,
ExprNode::decomposeMultiplicativeFactors(factors, current_exponent); ExprNode::decomposeMultiplicativeFactors(factors, current_exponent);
} }
tuple<int, int, int, double> tuple<optional<int>, int, optional<int>, double>
ExprNode::matchVariableTimesConstantTimesParam(bool variable_obligatory) const ExprNode::matchVariableTimesConstantTimesParam(bool variable_obligatory) const
{ {
int variable_id = -1, lag = 0, param_id = -1; optional<int> variable_id, param_id;
int lag = 0;
double constant = 1.0; double constant = 1.0;
matchVTCTPHelper(variable_id, lag, param_id, constant, false); matchVTCTPHelper(variable_id, lag, param_id, constant, false);
if (variable_obligatory && variable_id == -1) if (variable_obligatory && !variable_id)
throw MatchFailureException{"No variable in this expression"}; throw MatchFailureException{"No variable in this expression"};
return {variable_id, lag, param_id, constant}; return { move(variable_id), lag, move(param_id), constant};
} }
void void
ExprNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const ExprNode::matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const
{ {
throw MatchFailureException{"Expression not allowed in linear combination of variables"}; throw MatchFailureException{"Expression not allowed in linear combination of variables"};
} }
void void
NumConstNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const NumConstNode::matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const
{ {
double myvalue = eval({}); double myvalue = eval({});
if (at_denominator) if (at_denominator)
@ -8839,7 +8841,7 @@ NumConstNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &con
} }
void void
VariableNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const VariableNode::matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const
{ {
if (at_denominator) if (at_denominator)
throw MatchFailureException{"A variable or parameter cannot appear at denominator"}; throw MatchFailureException{"A variable or parameter cannot appear at denominator"};
@ -8847,14 +8849,14 @@ VariableNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &con
SymbolType type = get_type(); SymbolType type = get_type();
if (type == SymbolType::endogenous || type == SymbolType::exogenous) if (type == SymbolType::endogenous || type == SymbolType::exogenous)
{ {
if (var_id != -1) if (var_id)
throw MatchFailureException{"More than one variable in this expression"}; throw MatchFailureException{"More than one variable in this expression"};
var_id = symb_id; var_id = symb_id;
lag = this->lag; lag = this->lag;
} }
else if (type == SymbolType::parameter) else if (type == SymbolType::parameter)
{ {
if (param_id != -1) if (param_id)
throw MatchFailureException{"More than one parameter in this expression"}; throw MatchFailureException{"More than one parameter in this expression"};
param_id = symb_id; param_id = symb_id;
} }
@ -8863,7 +8865,7 @@ VariableNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &con
} }
void void
UnaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const UnaryOpNode::matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const
{ {
if (op_code == UnaryOpcode::uminus) if (op_code == UnaryOpcode::uminus)
{ {
@ -8875,7 +8877,7 @@ UnaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &cons
} }
void void
BinaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const BinaryOpNode::matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const
{ {
if (op_code == BinaryOpcode::times || op_code == BinaryOpcode::divide) if (op_code == BinaryOpcode::times || op_code == BinaryOpcode::divide)
{ {
@ -8889,24 +8891,41 @@ BinaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &con
throw MatchFailureException{"Operator not allowed in this expression"}; throw MatchFailureException{"Operator not allowed in this expression"};
} }
vector<tuple<int, int, int, double>> vector<tuple<int, int, optional<int>, double>>
ExprNode::matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term) const ExprNode::matchLinearCombinationOfVariables() const
{ {
vector<pair<expr_t, int>> terms; vector<pair<expr_t, int>> terms;
decomposeAdditiveTerms(terms); decomposeAdditiveTerms(terms);
vector<tuple<int, int, int, double>> result; vector<tuple<int, int, optional<int>, double>> result;
for (auto [term, sign] : terms) for (auto [term, sign] : terms)
{ {
auto m = term->matchVariableTimesConstantTimesParam(variable_obligatory_in_each_term); auto [variable_id, lag, param_id, constant] = term->matchVariableTimesConstantTimesParam(true);
get<3>(m) *= sign; constant *= sign;
result.push_back(m); result.emplace_back(variable_id.value(), lag, move(param_id), constant);
} }
return result; return result;
} }
pair<int, vector<tuple<int, int, int, double>>> vector<tuple<optional<int>, int, optional<int>, double>>
ExprNode::matchLinearCombinationOfVariablesPlusConstant() const
{
vector<pair<expr_t, int>> terms;
decomposeAdditiveTerms(terms);
vector<tuple<optional<int>, int, optional<int>, double>> result;
for (auto [term, sign] : terms)
{
auto m = term->matchVariableTimesConstantTimesParam(false);
get<3>(m) *= sign;
result.push_back(move(m));
}
return result;
}
pair<int, vector<tuple<int, int, optional<int>, double>>>
ExprNode::matchParamTimesLinearCombinationOfVariables() const ExprNode::matchParamTimesLinearCombinationOfVariables() const
{ {
auto bopn = dynamic_cast<const BinaryOpNode *>(this); auto bopn = dynamic_cast<const BinaryOpNode *>(this);

View File

@ -244,7 +244,7 @@ protected:
const temporary_terms_idxs_t &temporary_terms_idxs) const; const temporary_terms_idxs_t &temporary_terms_idxs) const;
// Internal helper for matchVariableTimesConstantTimesParam() // Internal helper for matchVariableTimesConstantTimesParam()
virtual void matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const; virtual void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const;
/* Computes the representative element and the index under the /* Computes the representative element and the index under the
lag-equivalence relationship. See the comment above lag-equivalence relationship. See the comment above
@ -638,21 +638,28 @@ public:
//! Matches a linear combination of variables (endo or exo), where scalars can be constant*parameter //! Matches a linear combination of variables (endo or exo), where scalars can be constant*parameter
/*! Returns a list of (variable_id, lag, param_id, constant) /*! Returns a list of (variable_id, lag, param_id, constant)
corresponding to the terms in the expression. When there is no corresponding to the terms in the expression. When there is no
parameter in a term, param_id == -1. parameter in a term, param_id is nullopt.
Can throw a MatchFailureException. Can throw a MatchFailureException.
if `variable_obligatory_in_each_term` is true, then every part of the linear combination must contain a variable;
otherwise, if `variable_obligatory_in_each_term` is false, then any linear
combination of constant/variable/param is matched (and variable_id == -1
for terms without a variable).
*/ */
vector<tuple<int, int, int, double>> matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term = true) const; vector<tuple<int, int, optional<int>, double>> matchLinearCombinationOfVariables() const;
/* Matches a linear combination of variables (endo or exo), where scalars can
be constant*parameter. In addition, there may be one or more scalar terms
(i.e. without a variable).
Returns a list of (variable_id, lag, param_id, constant)
corresponding to the terms in the expression. When there is no
parameter in a term, param_id is nullopt. When the term is scalar (i.e.
no variable), then variable_id is nullopt.
Can throw a MatchFailureException.
*/
vector<tuple<optional<int>, int, optional<int>, double>> matchLinearCombinationOfVariablesPlusConstant() const;
/* Matches a parameter, times a linear combination of variables (endo or /* Matches a parameter, times a linear combination of variables (endo or
exo), where scalars can be constant*parameters. exo), where scalars can be constant*parameters.
The first output argument is the symbol ID of the parameter. The first output argument is the symbol ID of the parameter.
The second output argument is the linear combination, in the same format The second output argument is the linear combination, in the same format
as the output of matchLinearCombinationOfVariables(). */ as the output of matchLinearCombinationOfVariables(). */
pair<int, vector<tuple<int, int, int, double>>> matchParamTimesLinearCombinationOfVariables() const; pair<int, vector<tuple<int, int, optional<int>, double>>> matchParamTimesLinearCombinationOfVariables() const;
/* Matches a linear combination of endogenous, where scalars can be any /* Matches a linear combination of endogenous, where scalars can be any
constant expression (i.e. containing no endogenous, no exogenous and no constant expression (i.e. containing no endogenous, no exogenous and no
@ -705,17 +712,17 @@ public:
// Matches an expression of the form variable*constant*parameter // Matches an expression of the form variable*constant*parameter
/* Returns a tuple (variable_id, lag, param_id, constant). /* Returns a tuple (variable_id, lag, param_id, constant).
The variable must be an exogenous or an endogenous. If `variable_obligatory` is true, then the expression must contain a variable.
If present, the variable must be an exogenous or an endogenous. If absent,
and `variable_obligatory` is false, then variable_id is nullopt.
The constant is optional (in which case 1 is returned); there can be The constant is optional (in which case 1 is returned); there can be
several multiplicative constants; constants can also appear at the several multiplicative constants; constants can also appear at the
denominator (i.e. after a divide sign). denominator (i.e. after a divide sign).
The parameter is optional (in which case param_id == -1). The parameter is optional (in which case param_id is nullopt).
If the expression is not of the expected form, throws a If the expression is not of the expected form, throws a
MatchFailureException MatchFailureException
if `variable_obligatory` is true, then the linear combination must contain a variable;
otherwise, if `variable_obligatory`, then an expression is matched that has any mix of constant/variable/param
*/ */
tuple<int, int, int, double> matchVariableTimesConstantTimesParam(bool variable_obligatory = true) const; tuple<optional<int>, int, optional<int>, double> matchVariableTimesConstantTimesParam(bool variable_obligatory) const;
/* Matches an expression of the form endogenous*constant where constant is an /* Matches an expression of the form endogenous*constant where constant is an
expression containing no endogenous, no exogenous and no exogenous deterministic. expression containing no endogenous, no exogenous and no exogenous deterministic.
@ -775,7 +782,7 @@ public:
private: private:
expr_t computeDerivative(int deriv_id) override; expr_t computeDerivative(int deriv_id) override;
protected: protected:
void matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
public: public:
NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg); NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg);
void prepareForDerivation() override; void prepareForDerivation() override;
@ -847,7 +854,7 @@ public:
private: private:
expr_t computeDerivative(int deriv_id) override; expr_t computeDerivative(int deriv_id) override;
protected: protected:
void matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
public: public:
VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg); VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg);
void prepareForDerivation() override; void prepareForDerivation() override;
@ -915,7 +922,7 @@ public:
class UnaryOpNode : public ExprNode class UnaryOpNode : public ExprNode
{ {
protected: protected:
void matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
public: public:
const expr_t arg; const expr_t arg;
//! Stores the information set. Only used for expectation operator //! Stores the information set. Only used for expectation operator
@ -1019,7 +1026,7 @@ public:
class BinaryOpNode : public ExprNode class BinaryOpNode : public ExprNode
{ {
protected: protected:
void matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
public: public:
const expr_t arg1, arg2; const expr_t arg1, arg2;
const BinaryOpcode op_code; const BinaryOpcode op_code;
@ -1139,8 +1146,8 @@ public:
*/ */
void getPacAREC(int lhs_symb_id, int lhs_orig_symb_id, void getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
pair<int, vector<tuple<int, bool, int>>> &ec_params_and_vars, pair<int, vector<tuple<int, bool, int>>> &ec_params_and_vars,
vector<tuple<int, int, int>> &ar_params_and_vars, vector<tuple<optional<int>, optional<int>, int>> &ar_params_and_vars,
vector<tuple<int, int, int, double>> &additive_vars_params_and_constants) const; vector<tuple<int, int, optional<int>, double>> &additive_vars_params_and_constants) const;
//! Finds the share of optimizing agents in the PAC equation, //! Finds the share of optimizing agents in the PAC equation,
//! the expr node associated with it, //! the expr node associated with it,

View File

@ -803,10 +803,10 @@ VarExpectationModelTable::writeOutput(const string &basename, ostream &output) c
constants_list << ", "; constants_list << ", ";
} }
vars_list << symbol_table.getTypeSpecificID(get<0>(*it))+1; vars_list << symbol_table.getTypeSpecificID(get<0>(*it))+1;
if (get<1>(*it) == -1) if (get<1>(*it))
params_list << "NaN"; params_list << symbol_table.getTypeSpecificID(*get<1>(*it))+1;
else else
params_list << symbol_table.getTypeSpecificID(get<1>(*it))+1; params_list << "NaN";
constants_list << get<2>(*it); constants_list << get<2>(*it);
} }
output << mstruct << ".expr.vars = [ " << vars_list.str() << " ];" << endl output << mstruct << ".expr.vars = [ " << vars_list.str() << " ];" << endl
@ -1130,7 +1130,7 @@ PacModelTable::transformPass(const lag_equivalence_table_t &unary_ops_nodes,
if (growth[name]) if (growth[name])
try try
{ {
growth_info[name] = growth[name]->matchLinearCombinationOfVariables(false); growth_info[name] = growth[name]->matchLinearCombinationOfVariablesPlusConstant();
} }
catch (ExprNode::MatchFailureException &e) catch (ExprNode::MatchFailureException &e)
{ {
@ -1183,7 +1183,7 @@ PacModelTable::transformPass(const lag_equivalence_table_t &unary_ops_nodes,
if (growth_component) if (growth_component)
try try
{ {
growth_component_info = growth_component->matchLinearCombinationOfVariables(false); growth_component_info = growth_component->matchLinearCombinationOfVariablesPlusConstant();
} }
catch (ExprNode::MatchFailureException &e) catch (ExprNode::MatchFailureException &e)
{ {
@ -1377,10 +1377,10 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
for (auto [growth_symb_id, growth_lag, param_id, constant] : gi) for (auto [growth_symb_id, growth_lag, param_id, constant] : gi)
{ {
string structname = fieldname + "(" + to_string(i++) + ")."; string structname = fieldname + "(" + to_string(i++) + ").";
if (growth_symb_id >= 0) if (growth_symb_id)
{ {
string var_field = "endo_id"; string var_field = "endo_id";
if (symbol_table.getType(growth_symb_id) == SymbolType::exogenous) if (symbol_table.getType(*growth_symb_id) == SymbolType::exogenous)
{ {
var_field = "exo_id"; var_field = "exo_id";
output << structname << "endo_id = 0;" << endl; output << structname << "endo_id = 0;" << endl;
@ -1390,7 +1390,7 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
try try
{ {
// case when this is not the highest lag of the growth variable // case when this is not the highest lag of the growth variable
int aux_symb_id = symbol_table.searchAuxiliaryVars(growth_symb_id, growth_lag); int aux_symb_id = symbol_table.searchAuxiliaryVars(*growth_symb_id, growth_lag);
output << structname << var_field << " = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl output << structname << var_field << " = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl
<< structname << "lag = 0;" << endl; << structname << "lag = 0;" << endl;
} }
@ -1400,14 +1400,14 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
{ {
// case when this is the highest lag of the growth variable // case when this is the highest lag of the growth variable
int tmp_growth_lag = growth_lag + 1; int tmp_growth_lag = growth_lag + 1;
int aux_symb_id = symbol_table.searchAuxiliaryVars(growth_symb_id, tmp_growth_lag); int aux_symb_id = symbol_table.searchAuxiliaryVars(*growth_symb_id, tmp_growth_lag);
output << structname << var_field << " = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl output << structname << var_field << " = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl
<< structname << "lag = -1;" << endl; << structname << "lag = -1;" << endl;
} }
catch (...) catch (...)
{ {
// case when there is no aux var for the variable // case when there is no aux var for the variable
output << structname << var_field << " = "<< symbol_table.getTypeSpecificID(growth_symb_id) + 1 << ";" << endl output << structname << var_field << " = "<< symbol_table.getTypeSpecificID(*growth_symb_id) + 1 << ";" << endl
<< structname << "lag = " << growth_lag << ";" << endl; << structname << "lag = " << growth_lag << ";" << endl;
} }
} }
@ -1417,7 +1417,7 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
<< structname << "exo_id = 0;" << endl << structname << "exo_id = 0;" << endl
<< structname << "lag = 0;" << endl; << structname << "lag = 0;" << endl;
output << structname << "param_id = " output << structname << "param_id = "
<< (param_id == -1 ? 0 : symbol_table.getTypeSpecificID(param_id) + 1) << ";" << endl << (param_id ? symbol_table.getTypeSpecificID(*param_id) + 1 : 0) << ";" << endl
<< structname << "constant = " << constant << ";" << endl; << structname << "constant = " << constant << ";" << endl;
} }
}; };
@ -1468,7 +1468,7 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
for (auto &[name, val] : equation_info) for (auto &[name, val] : equation_info)
{ {
auto [lhs_pac_var, optim_share_index, ar_params_and_vars, ec_params_and_vars, non_optim_vars_params_and_constants, additive_vars_params_and_constants, optim_additive_vars_params_and_constants] = val; auto &[lhs_pac_var, optim_share_index, ar_params_and_vars, ec_params_and_vars, non_optim_vars_params_and_constants, additive_vars_params_and_constants, optim_additive_vars_params_and_constants] = val;
output << "M_.pac." << name << ".lhs_var = " output << "M_.pac." << name << ".lhs_var = "
<< symbol_table.getTypeSpecificID(lhs_pac_var.first) + 1 << ";" << endl; << symbol_table.getTypeSpecificID(lhs_pac_var.first) + 1 << ";" << endl;
@ -1479,19 +1479,19 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
output << "M_.pac." << name << ".ec.params = " output << "M_.pac." << name << ".ec.params = "
<< symbol_table.getTypeSpecificID(ec_params_and_vars.first) + 1 << ";" << endl << symbol_table.getTypeSpecificID(ec_params_and_vars.first) + 1 << ";" << endl
<< "M_.pac." << name << ".ec.vars = ["; << "M_.pac." << name << ".ec.vars = [";
for (auto it : ec_params_and_vars.second) for (auto &it : ec_params_and_vars.second)
output << symbol_table.getTypeSpecificID(get<0>(it)) + 1 << " "; output << symbol_table.getTypeSpecificID(get<0>(it)) + 1 << " ";
output << "];" << endl output << "];" << endl
<< "M_.pac." << name << ".ec.istarget = ["; << "M_.pac." << name << ".ec.istarget = [";
for (auto it : ec_params_and_vars.second) for (auto &it : ec_params_and_vars.second)
output << boolalpha << get<1>(it) << " "; output << boolalpha << get<1>(it) << " ";
output << "];" << endl output << "];" << endl
<< "M_.pac." << name << ".ec.scale = ["; << "M_.pac." << name << ".ec.scale = [";
for (auto it : ec_params_and_vars.second) for (auto &it : ec_params_and_vars.second)
output << get<2>(it) << " "; output << get<2>(it) << " ";
output << "];" << endl output << "];" << endl
<< "M_.pac." << name << ".ec.isendo = ["; << "M_.pac." << name << ".ec.isendo = [";
for (auto it : ec_params_and_vars.second) for (auto &it : ec_params_and_vars.second)
switch (symbol_table.getType(get<0>(it))) switch (symbol_table.getType(get<0>(it)))
{ {
case SymbolType::endogenous: case SymbolType::endogenous:
@ -1507,11 +1507,11 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
output << "];" << endl output << "];" << endl
<< "M_.pac." << name << ".ar.params = ["; << "M_.pac." << name << ".ar.params = [";
for (auto &[pid, vid, vlag] : ar_params_and_vars) for (auto &[pid, vid, vlag] : ar_params_and_vars)
output << (pid != -1 ? symbol_table.getTypeSpecificID(pid) + 1 : -1) << " "; output << (pid ? symbol_table.getTypeSpecificID(*pid) + 1 : -1) << " ";
output << "];" << endl output << "];" << endl
<< "M_.pac." << name << ".ar.vars = ["; << "M_.pac." << name << ".ar.vars = [";
for (auto &[pid, vid, vlag] : ar_params_and_vars) for (auto &[pid, vid, vlag] : ar_params_and_vars)
output << (vid != -1 ? symbol_table.getTypeSpecificID(vid) + 1 : -1) << " "; output << (vid ? symbol_table.getTypeSpecificID(*vid) + 1 : -1) << " ";
output << "];" << endl output << "];" << endl
<< "M_.pac." << name << ".ar.lags = ["; << "M_.pac." << name << ".ar.lags = [";
for (auto &[pid, vid, vlag] : ar_params_and_vars) for (auto &[pid, vid, vlag] : ar_params_and_vars)
@ -1522,8 +1522,8 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
{ {
output << "M_.pac." << name << ".non_optimizing_behaviour.params = ["; output << "M_.pac." << name << ".non_optimizing_behaviour.params = [";
for (auto &it : non_optim_vars_params_and_constants) for (auto &it : non_optim_vars_params_and_constants)
if (get<2>(it) >= 0) if (get<2>(it))
output << symbol_table.getTypeSpecificID(get<2>(it)) + 1 << " "; output << symbol_table.getTypeSpecificID(*get<2>(it)) + 1 << " ";
else else
output << "NaN "; output << "NaN ";
output << "];" << endl output << "];" << endl
@ -1559,8 +1559,8 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
{ {
output << "M_.pac." << name << ".additive.params = ["; output << "M_.pac." << name << ".additive.params = [";
for (auto &it : additive_vars_params_and_constants) for (auto &it : additive_vars_params_and_constants)
if (get<2>(it) >= 0) if (get<2>(it))
output << symbol_table.getTypeSpecificID(get<2>(it)) + 1 << " "; output << symbol_table.getTypeSpecificID(*get<2>(it)) + 1 << " ";
else else
output << "NaN "; output << "NaN ";
output << "];" << endl output << "];" << endl
@ -1596,8 +1596,8 @@ PacModelTable::writeOutput(const string &basename, ostream &output) const
{ {
output << "M_.pac." << name << ".optim_additive.params = ["; output << "M_.pac." << name << ".optim_additive.params = [";
for (auto &it : optim_additive_vars_params_and_constants) for (auto &it : optim_additive_vars_params_and_constants)
if (get<2>(it) >= 0) if (get<2>(it))
output << symbol_table.getTypeSpecificID(get<2>(it)) + 1 << " "; output << symbol_table.getTypeSpecificID(*get<2>(it)) + 1 << " ";
else else
output << "NaN "; output << "NaN ";
output << "];" << endl output << "];" << endl

View File

@ -205,7 +205,7 @@ private:
// For each model, list of generated auxiliary param ids, in variable-major order // For each model, list of generated auxiliary param ids, in variable-major order
map<string, vector<int>> aux_param_symb_ids; map<string, vector<int>> aux_param_symb_ids;
// Decomposition of the expression // Decomposition of the expression
map<string, vector<tuple<int, int, double>>> vars_params_constants; map<string, vector<tuple<int, optional<int>, double>>> vars_params_constants;
public: public:
explicit VarExpectationModelTable(SymbolTable &symbol_table_arg); explicit VarExpectationModelTable(SymbolTable &symbol_table_arg);
void addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg, void addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg,
@ -234,9 +234,9 @@ private:
pac_target_info block. */ pac_target_info block. */
map<string, expr_t> growth, original_growth; map<string, expr_t> growth, original_growth;
/* Information about the structure of growth expressions (which must be a /* Information about the structure of growth expressions (which must be a
linear combination of variables). linear combination of variables, possibly with additional constants).
Each tuple represents a term: (endo_id, lag, param_id, constant) */ Each tuple represents a term: (endo_id, lag, param_id, constant) */
using growth_info_t = vector<tuple<int, int, int, double>>; using growth_info_t = vector<tuple<optional<int>, int, optional<int>, double>>;
map<string, growth_info_t> growth_info; map<string, growth_info_t> growth_info;
// The “auxname” option of pac_model (empty if not passed) // The “auxname” option of pac_model (empty if not passed)
map<string, string> auxname; map<string, string> auxname;
@ -283,7 +283,7 @@ public:
(lhs, optim_share_index, ar_params_and_vars, ec_params_and_vars, non_optim_vars_params_and_constants, additive_vars_params_and_constants, optim_additive_vars_params_and_constants) (lhs, optim_share_index, ar_params_and_vars, ec_params_and_vars, non_optim_vars_params_and_constants, additive_vars_params_and_constants, optim_additive_vars_params_and_constants)
*/ */
using equation_info_t = map<string, using equation_info_t = map<string,
tuple<pair<int, int>, int, vector<tuple<int, int, int>>, pair<int, vector<tuple<int, bool, int>>>, vector<tuple<int, int, int, double>>, vector<tuple<int, int, int, double>>, vector<tuple<int, int, int, double>>>>; tuple<pair<int, int>, int, vector<tuple<optional<int>, optional<int>, int>>, pair<int, vector<tuple<int, bool, int>>>, vector<tuple<int, int, optional<int>, double>>, vector<tuple<int, int, optional<int>, double>>, vector<tuple<int, int, optional<int>, double>>>>;
private: private:
equation_info_t equation_info; equation_info_t equation_info;