Misc simplifications using STL algorithms

master
Sébastien Villemot 2023-02-28 15:33:24 +01:00
parent 008a80910e
commit 62c455ff56
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
7 changed files with 61 additions and 109 deletions

View File

@ -20,6 +20,7 @@
#include <iostream>
#include <ios>
#include <cstdlib>
#include <algorithm>
#include "Bytecode.hh"
@ -96,10 +97,8 @@ operator<<(BytecodeWriter &code_file, const FBEGINBLOCK_ &instr)
write_member(instr.det_exo_size);
write_member(instr.exo_size);
for (int i{0}; i < instr.det_exo_size; i++)
write_member(instr.det_exogenous[i]);
for (int i{0}; i < instr.exo_size; i++)
write_member(instr.exogenous[i]);
for_each_n(instr.det_exogenous.begin(), instr.det_exo_size, write_member);
for_each_n(instr.exogenous.begin(), instr.exo_size, write_member);
return code_file;
}

View File

@ -840,11 +840,8 @@ DataTree::addAllParamDerivId([[maybe_unused]] set<int> &deriv_id_set)
bool
DataTree::isUnaryOpUsed(UnaryOpcode opcode) const
{
for (const auto &it : unary_op_node_map)
if (get<1>(it.first) == opcode)
return true;
return false;
return any_of(unary_op_node_map.begin(), unary_op_node_map.end(),
[=](const auto &it) { return get<1>(it.first) == opcode; });
}
bool
@ -864,11 +861,8 @@ DataTree::isUnaryOpUsedOnType(SymbolType type, UnaryOpcode opcode) const
bool
DataTree::isBinaryOpUsed(BinaryOpcode opcode) const
{
for (const auto &it : binary_op_node_map)
if (get<2>(it.first) == opcode)
return true;
return false;
return any_of(binary_op_node_map.begin(), binary_op_node_map.end(),
[=](const auto &it) { return get<2>(it.first) == opcode; });
}
bool

View File

@ -1,5 +1,5 @@
/*
* Copyright © 2007-2022 Dynare Team
* Copyright © 2007-2023 Dynare Team
*
* This file is part of Dynare.
*
@ -23,6 +23,7 @@
#include <cmath>
#include <utility>
#include <limits>
#include <numeric>
#include "ExprNode.hh"
#include "DataTree.hh"
@ -3430,19 +3431,13 @@ UnaryOpNode::substituteAdl() const
}
expr_t arg1subst = arg->substituteAdl();
expr_t retval = nullptr;
for (bool first_term{true};
int lag : adl_lags)
{
expr_t e = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + to_string(lag)), 0),
arg1subst->decreaseLeadsLags(lag));
if (exchange(first_term, false))
retval = e;
else
retval = datatree.AddPlus(retval, e);
}
return retval;
return transform_reduce(adl_lags.begin(), adl_lags.end(), static_cast<expr_t>(datatree.Zero),
[&](expr_t e1, expr_t e2) { return datatree.AddPlus(e1, e2); },
[&](int lag) {
return datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + to_string(lag)), 0),
arg1subst->decreaseLeadsLags(lag));
});
}
expr_t
@ -6821,67 +6816,53 @@ AbstractExternalFunctionNode::eval([[maybe_unused]] const eval_context_t &eval_c
throw EvalExternalFunctionException();
}
int
AbstractExternalFunctionNode::maxHelper(const function<int (expr_t)> &f) const
{
return transform_reduce(arguments.begin(), arguments.end(), 0,
[](int a, int b) { return max(a, b); }, f);
}
int
AbstractExternalFunctionNode::maxEndoLead() const
{
int val = 0;
for (auto argument : arguments)
val = max(val, argument->maxEndoLead());
return val;
return maxHelper([](expr_t e) { return e->maxEndoLead(); });
}
int
AbstractExternalFunctionNode::maxExoLead() const
{
int val = 0;
for (auto argument : arguments)
val = max(val, argument->maxExoLead());
return val;
return maxHelper([](expr_t e) { return e->maxExoLead(); });
}
int
AbstractExternalFunctionNode::maxEndoLag() const
{
int val = 0;
for (auto argument : arguments)
val = max(val, argument->maxEndoLag());
return val;
return maxHelper([](expr_t e) { return e->maxEndoLag(); });
}
int
AbstractExternalFunctionNode::maxExoLag() const
{
int val = 0;
for (auto argument : arguments)
val = max(val, argument->maxExoLag());
return val;
return maxHelper([](expr_t e) { return e->maxExoLag(); });
}
int
AbstractExternalFunctionNode::maxLead() const
{
int val = 0;
for (auto argument : arguments)
val = max(val, argument->maxLead());
return val;
return maxHelper([](expr_t e) { return e->maxLead(); });
}
int
AbstractExternalFunctionNode::maxLag() const
{
int val = 0;
for (auto argument : arguments)
val = max(val, argument->maxLag());
return val;
return maxHelper([](expr_t e) { return e->maxLag(); });
}
int
AbstractExternalFunctionNode::maxLagWithDiffsExpanded() const
{
int val = 0;
for (auto argument : arguments)
val = max(val, argument->maxLagWithDiffsExpanded());
return val;
return maxHelper([](expr_t e) { return e->maxLagWithDiffsExpanded(); });
}
expr_t
@ -6896,10 +6877,7 @@ AbstractExternalFunctionNode::undiff() const
int
AbstractExternalFunctionNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
{
int max_lag = 0;
for (auto argument : arguments)
max_lag = max(max_lag, argument->VarMaxLag(lhs_lag_equiv));
return max_lag;
return maxHelper([&](expr_t e) { return e->VarMaxLag(lhs_lag_equiv); });
}
expr_t
@ -7038,10 +7016,7 @@ AbstractExternalFunctionNode::substituteUnaryOpNodes(const lag_equivalence_table
int
AbstractExternalFunctionNode::countDiffs() const
{
int ndiffs = 0;
for (auto argument : arguments)
ndiffs = max(ndiffs, argument->countDiffs());
return ndiffs;
return maxHelper([](expr_t e) { return e->countDiffs(); });
}
expr_t
@ -7148,19 +7123,15 @@ AbstractExternalFunctionNode::isVariableNodeEqualTo([[maybe_unused]] SymbolType
bool
AbstractExternalFunctionNode::containsPacExpectation(const string &pac_model_name) const
{
for (auto argument : arguments)
if (argument->containsPacExpectation(pac_model_name))
return true;
return false;
return any_of(arguments.begin(), arguments.end(),
[&](expr_t e) { return e->containsPacExpectation(pac_model_name); });
}
bool
AbstractExternalFunctionNode::containsPacTargetNonstationary(const string &pac_model_name) const
{
for (auto argument : arguments)
if (argument->containsPacTargetNonstationary(pac_model_name))
return true;
return false;
return any_of(arguments.begin(), arguments.end(),
[&](expr_t e) { return e->containsPacTargetNonstationary(pac_model_name); });
}
expr_t
@ -7193,10 +7164,8 @@ AbstractExternalFunctionNode::removeTrendLeadLag(const map<int, expr_t> &trend_s
bool
AbstractExternalFunctionNode::isInStaticForm() const
{
for (auto argument : arguments)
if (!argument->isInStaticForm())
return false;
return true;
return all_of(arguments.begin(), arguments.end(),
[](expr_t e) { return e->isInStaticForm(); });
}
bool
@ -7337,10 +7306,8 @@ ExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs)
dNodes.push_back(datatree.AddTimes(dargs.at(i),
datatree.AddFirstDerivExternalFunction(symb_id, arguments, i+1)));
expr_t theDeriv = datatree.Zero;
for (auto &dNode : dNodes)
theDeriv = datatree.AddPlus(theDeriv, dNode);
return theDeriv;
return accumulate(dNodes.begin(), dNodes.end(), static_cast<expr_t>(datatree.Zero),
[&](expr_t e1, expr_t e2) { return datatree.AddPlus(e1, e2); });
}
void
@ -7644,10 +7611,8 @@ FirstDerivExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs)
for (int i = 0; i < static_cast<int>(dargs.size()); i++)
dNodes.push_back(datatree.AddTimes(dargs.at(i),
datatree.AddSecondDerivExternalFunction(symb_id, arguments, inputIndex, i+1)));
expr_t theDeriv = datatree.Zero;
for (auto &dNode : dNodes)
theDeriv = datatree.AddPlus(theDeriv, dNode);
return theDeriv;
return accumulate(dNodes.begin(), dNodes.end(), static_cast<expr_t>(datatree.Zero),
[&](expr_t e1, expr_t e2) { return datatree.AddPlus(e1, e2); });
}
void

View File

@ -1,5 +1,5 @@
/*
* Copyright © 2007-2022 Dynare Team
* Copyright © 2007-2023 Dynare Team
*
* This file is part of Dynare.
*
@ -1340,6 +1340,8 @@ private:
expr_t computeDerivative(int deriv_id) override;
expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
// Computes the maximum of f applied to all arguments (result will always be non-negative)
int maxHelper(const function<int (expr_t)> &f) const;
protected:
//! Thrown when trying to access an unknown entry in external_function_node_map
class UnknownFunctionNameAndArgs

View File

@ -22,6 +22,7 @@
#include <cassert>
#include <sstream>
#include <cmath>
#include <numeric>
#include "ParsingDriver.hh"
#include "Statement.hh"
@ -2819,10 +2820,8 @@ ParsingDriver::add_diff(expr_t arg1)
expr_t
ParsingDriver::add_adl(expr_t arg1, const string &name, const string &lag)
{
vector<int> lags;
for (int i = 1; i <= stoi(lag); i++)
lags.push_back(i);
vector<int> lags(stoi(lag));
iota(lags.begin(), lags.end(), 1);
return add_adl(arg1, name, lags);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright © 2018-2022 Dynare Team
* Copyright © 2018-2023 Dynare Team
*
* This file is part of Dynare.
*
@ -19,6 +19,7 @@
#include <algorithm>
#include <cassert>
#include <numeric>
#include "SubModel.hh"
#include "DynamicModel.hh"
@ -699,10 +700,8 @@ VarModelTable::getMaxLags(const string &name_arg) const
int
VarModelTable::getMaxLag(const string &name_arg) const
{
int max_lag_int = 0;
for (auto it : getMaxLags(name_arg))
max_lag_int = max(max_lag_int, it);
return max_lag_int;
vector<int> maxlags { getMaxLags(name_arg) };
return reduce(maxlags.begin(), maxlags.end(), 0, [](int a, int b) { return max(a, b); });
}
const vector<int> &

View File

@ -1,5 +1,5 @@
/*
* Copyright © 2003-2022 Dynare Team
* Copyright © 2003-2023 Dynare Team
*
* This file is part of Dynare.
*
@ -859,31 +859,25 @@ SymbolTable::getEndogenous() const
bool
SymbolTable::isAuxiliaryVariable(int symb_id) const
{
for (const auto &aux_var : aux_vars)
if (aux_var.symb_id == symb_id)
return true;
return false;
return any_of(aux_vars.begin(), aux_vars.end(), [=](const auto &av) { return av.symb_id == symb_id; });
}
bool
SymbolTable::isAuxiliaryVariableButNotMultiplier(int symb_id) const
{
for (const auto &aux_var : aux_vars)
if (aux_var.symb_id == symb_id && aux_var.type != AuxVarType::multiplier)
return true;
return false;
return any_of(aux_vars.begin(), aux_vars.end(),
[=](const auto &av)
{ return av.symb_id == symb_id && av.type != AuxVarType::multiplier; });
}
bool
SymbolTable::isDiffAuxiliaryVariable(int symb_id) const
{
for (const auto &aux_var : aux_vars)
if (aux_var.symb_id == symb_id
&& (aux_var.type == AuxVarType::diff
|| aux_var.type == AuxVarType::diffLag
|| aux_var.type == AuxVarType::diffLead))
return true;
return false;
return any_of(aux_vars.begin(), aux_vars.end(),
[=](const auto &av) { return av.symb_id == symb_id
&& (av.type == AuxVarType::diff
|| av.type == AuxVarType::diffLag
|| av.type == AuxVarType::diffLead); });
}
set<int>