C++17 modernization: use std::optional for trend variables in TCM

fix-tolerance-parameters
Sébastien Villemot 2022-05-05 18:39:04 +02:00
parent fb3b1c301f
commit 0b51294994
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
5 changed files with 49 additions and 46 deletions

View File

@ -3119,12 +3119,13 @@ DynamicModel::updateVarAndTrendModel() const
{
for (bool var : { true, false })
{
map<string, vector<int>> trend_varr;
map<string, vector<optional<int>>> trend_varr;
map<string, vector<set<pair<int, int>>>> rhsr;
for (const auto &[model_name, eqns] : (var ? var_model_table.getEqNums()
: trend_component_model_table.getEqNums()))
{
vector<int> lhs, trend_var, trend_lhs;
vector<int> lhs, trend_lhs;
vector<optional<int>> trend_var;
vector<set<pair<int, int>>> rhs;
if (!var)
@ -3160,25 +3161,25 @@ DynamicModel::updateVarAndTrendModel() const
catch (...)
{
}
int trend_var_symb_id = equations[eqn]->arg2->findTargetVariable(lhs_symb_id);
if (trend_var_symb_id >= 0)
optional<int> trend_var_symb_id = equations[eqn]->arg2->findTargetVariable(lhs_symb_id);
if (trend_var_symb_id)
{
if (symbol_table.isDiffAuxiliaryVariable(trend_var_symb_id))
if (symbol_table.isDiffAuxiliaryVariable(*trend_var_symb_id))
try
{
trend_var_symb_id = symbol_table.getOrigSymbIdForAuxVar(trend_var_symb_id);
trend_var_symb_id = symbol_table.getOrigSymbIdForAuxVar(*trend_var_symb_id);
}
catch (...)
{
}
if (find(trend_lhs.begin(), trend_lhs.end(), trend_var_symb_id) == trend_lhs.end())
if (find(trend_lhs.begin(), trend_lhs.end(), *trend_var_symb_id) == trend_lhs.end())
{
cerr << "ERROR: trend found in trend_component equation #" << eqn << " ("
<< symbol_table.getName(trend_var_symb_id) << ") does not correspond to a trend equation" << endl;
<< symbol_table.getName(*trend_var_symb_id) << ") does not correspond to a trend equation" << endl;
exit(EXIT_FAILURE);
}
}
trend_var.push_back(trend_var_symb_id);
trend_var.push_back(move(trend_var_symb_id));
}
}

View File

@ -643,10 +643,10 @@ NumConstNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes)
{
}
int
optional<int>
NumConstNode::findTargetVariable(int lhs_symb_id) const
{
return -1;
return nullopt;
}
expr_t
@ -1593,13 +1593,13 @@ VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes)
datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes);
}
int
optional<int>
VariableNode::findTargetVariable(int lhs_symb_id) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id);
return -1;
return nullopt;
}
expr_t
@ -3393,7 +3393,7 @@ UnaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
nodes[lag_equiv_repr][index] = const_cast<UnaryOpNode *>(this);
}
int
optional<int>
UnaryOpNode::findTargetVariable(int lhs_symb_id) const
{
return arg->findTargetVariable(lhs_symb_id);
@ -5310,14 +5310,14 @@ BinaryOpNode::findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const
return false;
}
int
optional<int>
BinaryOpNode::findTargetVariableHelper(const expr_t arg1, const expr_t arg2,
int lhs_symb_id) const
{
set<int> params;
arg1->collectVariables(SymbolType::parameter, params);
if (params.size() != 1)
return -1;
return nullopt;
set<pair<int, int>> endogs;
arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
@ -5331,18 +5331,18 @@ BinaryOpNode::findTargetVariableHelper(const expr_t arg1, const expr_t arg2,
else if (findTargetVariableHelper1(lhs_symb_id, endogs.rbegin()->first))
return endogs.begin()->first;
}
return -1;
return nullopt;
}
int
optional<int>
BinaryOpNode::findTargetVariable(int lhs_symb_id) const
{
int retval = findTargetVariableHelper(arg1, arg2, lhs_symb_id);
if (retval < 0)
optional<int> retval = findTargetVariableHelper(arg1, arg2, lhs_symb_id);
if (!retval)
retval = findTargetVariableHelper(arg2, arg1, lhs_symb_id);
if (retval < 0)
if (!retval)
retval = arg1->findTargetVariable(lhs_symb_id);
if (retval < 0)
if (!retval)
retval = arg2->findTargetVariable(lhs_symb_id);
return retval;
}
@ -6447,13 +6447,13 @@ TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes)
arg3->findUnaryOpNodesForAuxVarCreation(nodes);
}
int
optional<int>
TrinaryOpNode::findTargetVariable(int lhs_symb_id) const
{
int retval = arg1->findTargetVariable(lhs_symb_id);
if (retval < 0)
optional<int> retval = arg1->findTargetVariable(lhs_symb_id);
if (!retval)
retval = arg2->findTargetVariable(lhs_symb_id);
if (retval < 0)
if (!retval)
retval = arg3->findTargetVariable(lhs_symb_id);
return retval;
}
@ -6871,14 +6871,14 @@ AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_
argument->findUnaryOpNodesForAuxVarCreation(nodes);
}
int
optional<int>
AbstractExternalFunctionNode::findTargetVariable(int lhs_symb_id) const
{
for (auto argument : arguments)
if (int retval = argument->findTargetVariable(lhs_symb_id);
retval >= 0)
if (optional<int> retval = argument->findTargetVariable(lhs_symb_id);
retval)
return retval;
return -1;
return nullopt;
}
expr_t
@ -8391,10 +8391,10 @@ SubModelNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes)
{
}
int
optional<int>
SubModelNode::findTargetVariable(int lhs_symb_id) const
{
return -1;
return nullopt;
}
expr_t

View File

@ -25,6 +25,7 @@
#include <vector>
#include <ostream>
#include <functional>
#include <optional>
using namespace std;
@ -620,7 +621,7 @@ public:
//! Substitute pac_target_nonstationary operator
virtual expr_t substitutePacTargetNonstationary(const string &name, expr_t subexpr) = 0;
virtual int findTargetVariable(int lhs_symb_id) const = 0;
virtual optional<int> findTargetVariable(int lhs_symb_id) const = 0;
//! Add ExprNodes to the provided datatree
virtual expr_t clone(DataTree &datatree) const = 0;
@ -811,7 +812,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
optional<int> findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string &name, expr_t subexpr) override;
@ -884,7 +885,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
optional<int> findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string &name, expr_t subexpr) override;
@ -989,7 +990,7 @@ public:
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
bool createAuxVarForUnaryOpNode() const;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
optional<int> findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string &name, expr_t subexpr) override;
@ -1096,8 +1097,8 @@ public:
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
bool findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const;
int findTargetVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const;
int findTargetVariable(int lhs_symb_id) const override;
optional<int> findTargetVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const;
optional<int> findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string &name, expr_t subexpr) override;
@ -1232,7 +1233,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
optional<int> findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string &name, expr_t subexpr) override;
@ -1342,7 +1343,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
optional<int> findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(const string &name, expr_t subexpr) override;
@ -1520,7 +1521,7 @@ public:
expr_t substituteModelLocalVariables() const override;
void findDiffNodes(lag_equivalence_table_t &nodes) const override;
void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
int findTargetVariable(int lhs_symb_id) const override;
optional<int> findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;

View File

@ -83,7 +83,7 @@ TrendComponentModelTable::setRhs(map<string, vector<set<pair<int, int>>>> rhs_ar
}
void
TrendComponentModelTable::setTargetVar(map<string, vector<int>> target_vars_arg)
TrendComponentModelTable::setTargetVar(map<string, vector<optional<int>>> target_vars_arg)
{
target_vars = move(target_vars_arg);
}
@ -319,8 +319,8 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c
i++;
}
output << "M_.trend_component." << name << ".target_vars = [";
for (auto it : target_vars.at(name))
output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " ";
for (const optional<int> &it : target_vars.at(name))
output << (it ? symbol_table.getTypeSpecificID(*it) + 1 : -1) << " ";
output << "];" << endl;
vector<string> target_eqtags_vec = target_eqtags.at(name);

View File

@ -24,6 +24,7 @@
#include <map>
#include <vector>
#include <iostream>
#include <optional>
#include "ExprNode.hh"
#include "SymbolTable.hh"
@ -49,7 +50,7 @@ private:
map<string, vector<set<pair<int, int>>>> rhs;
map<string, vector<bool>> diff;
map<string, vector<expr_t>> lhs_expr_t;
map<string, vector<int>> target_vars;
map<string, vector<optional<int>>> target_vars;
map<string, map<tuple<int, int, int>, expr_t>> AR; // name -> (eqn, lag, lhs_symb_id) -> expr_t
/* Note that A0 in the trend-component model context is not the same thing as
in the structural VAR context. */
@ -89,7 +90,7 @@ public:
void setMaxLags(map<string, vector<int>> max_lags_arg);
void setDiff(map<string, vector<bool>> diff_arg);
void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg);
void setTargetVar(map<string, vector<int>> target_vars_arg);
void setTargetVar(map<string, vector<optional<int>>> target_vars_arg);
void setAR(map<string, map<tuple<int, int, int>, expr_t>> AR_arg);
void setA0(map<string, map<tuple<int, int>, expr_t>> A0_arg,
map<string, map<tuple<int, int>, expr_t>> A0star_arg);