/*
* Copyright © 2007-2022 Dynare Team
*
* This file is part of Dynare.
*
* Dynare is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dynare is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see .
*/
#include
#include
#include
#include
#include
#include
#include "ExprNode.hh"
#include "DataTree.hh"
#include "ModFile.hh"
/* Base constructor shared by all expression nodes: binds the node to the
   DataTree that owns it and records the node's index in that tree. */
ExprNode::ExprNode(DataTree &datatree_arg, int idx_arg) :
  datatree{datatree_arg},
  idx{idx_arg}
{
}
/* Returns the derivative of this node with respect to deriv_id.
   Results are memoized in the `derivatives` map; the set of possibly
   non-null derivatives (filled by prepareForDerivation()) lets us answer
   zero without any computation. */
expr_t
ExprNode::getDerivative(int deriv_id)
{
  if (!preparedForDerivation)
    prepareForDerivation();

  // Symbolic a priori: deriv_id is not among the possibly non-null derivatives
  if (!non_null_derivatives.contains(deriv_id))
    return datatree.Zero;

  auto cached = derivatives.find(deriv_id);
  if (cached != derivatives.end())
    return cached->second;

  // Not computed yet: compute it and remember it for later calls
  expr_t result = computeDerivative(deriv_id);
  derivatives[deriv_id] = result;
  return result;
}
/* Default operator precedence, used by constants, variables and unary
   operators: these bind tighter than anything else, hence the maximal
   level (100). */
int
ExprNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const
{
  constexpr int maximal_precedence = 100;
  return maximal_precedence;
}

// JSON counterpart of precedence(); same maximal level for leaf-like nodes
int
ExprNode::precedenceJson(const temporary_terms_t &temporary_terms) const
{
  constexpr int maximal_precedence = 100;
  return maximal_precedence;
}
/* Default evaluation-cost estimates. All three overloads return 0, the
   value used for terminal nodes; node classes with children presumably
   override them (TODO confirm in the header). */
int
ExprNode::cost(int cost, bool is_matlab) const
{
// For a terminal node, the cost is null
return 0;
}
// Same default, for the overload taking per-block temporary terms
int
ExprNode::cost(const vector> &blocks_temporary_terms, bool is_matlab) const
{
// For a terminal node, the cost is null
return 0;
}
// Same default, for the overload taking a map of temporary terms
int
ExprNode::cost(const map, temporary_terms_t> &temp_terms_map, bool is_matlab) const
{
// For a terminal node, the cost is null
return 0;
}
/* If this node belongs to `temporary_terms`, writes a reference to it
   (T[idx] with the subscript syntax of output_type) and returns true;
   otherwise writes nothing and returns false. */
bool
ExprNode::checkIfTemporaryTermThenWrite(ostream &output, ExprNodeOutputType output_type,
                                        const temporary_terms_t &temporary_terms,
                                        const temporary_terms_idxs_t &temporary_terms_idxs) const
{
  auto self = const_cast<ExprNode *>(this);
  if (!temporary_terms.contains(self))
    return false;

  // Callers are responsible for giving every temporary term an index
  auto idx_it = temporary_terms_idxs.find(self);
  assert(idx_it != temporary_terms_idxs.end());

  auto position = idx_it->second + ARRAY_SUBSCRIPT_OFFSET(output_type);
  output << "T" << LEFT_ARRAY_SUBSCRIPT(output_type) << position << RIGHT_ARRAY_SUBSCRIPT(output_type);
  return true;
}
/* Returns the representative of this expression's lag equivalence class
   (the expression shifted so that its maximum lead is zero) together with
   the shift that was applied. */
pair<expr_t, int>
ExprNode::getLagEquivalenceClass() const
{
  int shift = maxLead();
  // Without any variable there is nothing to shift: the class has size 1
  if (shift == numeric_limits<int>::min())
    shift = 0;
  return { decreaseLeadsLags(shift), shift };
}
void
ExprNode::collectVariables(SymbolType type, set &result) const
{
set> symbs_lags;
collectDynamicVariables(type, symbs_lags);
transform(symbs_lags.begin(), symbs_lags.end(), inserter(result, result.begin()),
[](auto x) { return x.first; });
}
void
ExprNode::collectEndogenous(set> &result) const
{
set> symb_ids_and_lags;
collectDynamicVariables(SymbolType::endogenous, symb_ids_and_lags);
for (const auto &[symb_id, lag] : symb_ids_and_lags)
result.emplace(datatree.symbol_table.getTypeSpecificID(symb_id), lag);
}
/* Default temporary-term detection: a terminal node never becomes a
   temporary term, so neither temp_terms_map nor reference_count is touched. */
void
ExprNode::computeTemporaryTerms(const pair &derivOrder,
map, temporary_terms_t> &temp_terms_map,
map>> &reference_count,
bool is_matlab) const
{
// Nothing to do for a terminal node
}
/* Block-decomposed variant of the above; likewise a no-op for terminal
   nodes (blocks_temporary_terms and reference_count are left unchanged). */
void
ExprNode::computeBlockTemporaryTerms(int blk, int eq, vector> &blocks_temporary_terms,
map> &reference_count) const
{
// Nothing to do for a terminal node
}
// Shorthand: write as MATLAB code outside of the model, with no temporary terms
void
ExprNode::writeOutput(ostream &output) const
{
  writeOutput(output, ExprNodeOutputType::matlabOutsideModel, temporary_terms_t{}, temporary_terms_idxs_t{});
}

// Shorthand: write in the requested output type, with no temporary terms
void
ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type) const
{
  writeOutput(output, output_type, temporary_terms_t{}, temporary_terms_idxs_t{});
}

// Forwards to the full overload with an empty deriv_node_temp_terms_t (tef_terms)
void
ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const
{
  writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, deriv_node_temp_terms_t{});
}

// Forwards to the full bytecode overload with an empty deriv_node_temp_terms_t (tef_terms)
void
ExprNode::compile(ostream &CompileCode, unsigned int &instruction_number,
                  bool lhs_rhs, const temporary_terms_t &temporary_terms,
                  const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic) const
{
  compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs,
          dynamic, steady_dynamic, deriv_node_temp_terms_t{});
}
/* Default for nodes that contain no external function call: nothing is
   written to `output` and tef_terms is left unchanged. */
void
ExprNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_t &temporary_terms,
const temporary_terms_idxs_t &temporary_terms_idxs,
deriv_node_temp_terms_t &tef_terms) const
{
// Nothing to do
}
// JSON counterpart of the above: efout and tef_terms are left unchanged
void
ExprNode::writeJsonExternalFunctionOutput(vector &efout,
const temporary_terms_t &temporary_terms,
deriv_node_temp_terms_t &tef_terms,
bool isdynamic) const
{
// Nothing to do
}
// Bytecode counterpart of the above: no instruction is emitted
void
ExprNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
bool lhs_rhs, const temporary_terms_t &temporary_terms,
const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic,
deriv_node_temp_terms_t &tef_terms) const
{
// Nothing to do
}
/* Replaces this expression, whose maximum endogenous lead n is at least 2,
   by a chain of auxiliary variables so that only leads of one period
   remain. For each shift lag = n-2, ..., 0, an auxiliary variable AUX is
   created with equation AUX = substexpr, where substexpr stands for the
   expression shifted back by lag+1 periods; AUX(+1) then stands for the
   expression shifted back by lag periods and is recorded in subst_table.
   Each new equation is appended to `neweqs`. Returns the node standing for
   the unshifted expression (the last AUX(+1)). Existing subst_table entries
   are reused instead of creating duplicate auxiliary variables. */
VariableNode *
ExprNode::createEndoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector &neweqs) const
{
int n = maxEndoLead();
assert(n >= 2);
// Already substituted earlier: reuse the existing auxiliary variable
if (auto it = subst_table.find(this);
it != subst_table.end())
return const_cast(it->second);
// Shift the expression back so that its maximum endogenous lead becomes 1
expr_t substexpr = decreaseLeadsLags(n-1);
int lag = n-2;
// Each iteration tries to create an auxvar such that auxvar(+1)=expr(-lag)
// At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to expr(-lag-1) (resp. expr(-lag))
while (lag >= 0)
{
expr_t orig_expr = decreaseLeadsLags(lag);
if (auto it = subst_table.find(orig_expr); it == subst_table.end())
{
int symb_id = datatree.symbol_table.addEndoLeadAuxiliaryVar(orig_expr->idx, substexpr);
neweqs.push_back(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr));
substexpr = datatree.AddVariable(symb_id, +1);
assert(dynamic_cast(substexpr));
subst_table[orig_expr] = dynamic_cast(substexpr);
}
else
substexpr = const_cast(it->second);
lag--;
}
// After the last iteration (lag == 0), substexpr stands for the unshifted expression
return dynamic_cast(substexpr);
}
/* Exogenous counterpart of createEndoLeadAuxiliaryVarForMyself(): the
   expression's maximum exogenous lead n must be at least 1, and the chain
   of auxiliary variables starts from the expression shifted back by n
   periods (so that it contains no exogenous lead at all). For each shift
   lag = n-1, ..., 0, an auxiliary variable AUX with equation AUX = substexpr
   is created (appended to `neweqs`), and AUX(+1) is recorded in subst_table
   as standing for the expression shifted back by lag periods. Returns the
   node standing for the unshifted expression. */
VariableNode *
ExprNode::createExoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector &neweqs) const
{
int n = maxExoLead();
assert(n >= 1);
// Already substituted earlier: reuse the existing auxiliary variable
if (auto it = subst_table.find(this);
it != subst_table.end())
return const_cast(it->second);
// Shift the expression back so that it no longer has any exogenous lead
expr_t substexpr = decreaseLeadsLags(n);
int lag = n-1;
// Each iteration tries to create an auxvar such that auxvar(+1)=expr(-lag)
// At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to expr(-lag-1) (resp. expr(-lag))
while (lag >= 0)
{
expr_t orig_expr = decreaseLeadsLags(lag);
if (auto it = subst_table.find(orig_expr); it == subst_table.end())
{
int symb_id = datatree.symbol_table.addExoLeadAuxiliaryVar(orig_expr->idx, substexpr);
neweqs.push_back(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr));
substexpr = datatree.AddVariable(symb_id, +1);
assert(dynamic_cast(substexpr));
subst_table[orig_expr] = dynamic_cast(substexpr);
}
else
substexpr = const_cast(it->second);
lag--;
}
// After the last iteration (lag == 0), substexpr stands for the unshifted expression
return dynamic_cast(substexpr);
}
/* Default answer for nodes that are not numerical constants
   (NumConstNode provides its own implementation). */
bool
ExprNode::isNumConstNodeEqualTo(double value) const
{
  constexpr bool is_matching_constant = false;
  return is_matching_constant;
}

/* Default answer for nodes that are not variables: they never match a
   given (type, variable ID, lag) triplet. */
bool
ExprNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{
  constexpr bool is_matching_variable = false;
  return is_matching_variable;
}
/* Fills row `eqn` of the error-correction matrices of a trend component
   model from the additive terms of this expression.
   - Each additive term must match "parameter × linear combination of
     variables"; terms that do not match are skipped (see FIXMEs).
   - Terms with a variable that is in neither target_lhs nor nontarget_lhs
     (after unrolling diff lead/lag chains) are skipped.
   - Non-target variables go into A0: they must appear at lag -1, with a
     constant of exactly 1 and no extra parameter, and at most once per
     equation; the entry is the speed-of-adjustment parameter.
   - Target variables go into A0star: the entry is
     speed_of_adjustment × (-constant) [× param], summed over repeated terms.
   Any violation prints an error on cerr and exits with EXIT_FAILURE. */
void
ExprNode::fillErrorCorrectionRow(int eqn,
const vector &nontarget_lhs,
const vector &target_lhs,
map, expr_t> &A0,
map, expr_t> &A0star) const
{
vector> terms;
decomposeAdditiveTerms(terms, 1);
for (const auto &[term, sign] : terms)
{
int speed_of_adjustment_param;
vector> error_linear_combination;
try
{
tie(speed_of_adjustment_param, error_linear_combination) = term->matchParamTimesLinearCombinationOfVariables();
for (auto &[var_id, lag, param_id, constant] : error_linear_combination)
constant *= sign; // Update sign of constants
}
catch (MatchFailureException &e)
{
/* FIXME: we should not just skip them, but rather verify that they are
autoregressive terms or residuals (probably by merging the two "fill" procedures) */
continue;
}
/* Verify that all variables belong to the error-correction term.
FIXME: same remark as above about skipping terms. */
bool not_ec = false;
for (const auto &[var_id, lag, param_id, constant] : error_linear_combination)
{
auto [orig_var_id, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag);
not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), orig_var_id) == target_lhs.end()
&& find(nontarget_lhs.begin(), nontarget_lhs.end(), orig_var_id) == nontarget_lhs.end());
}
if (not_ec)
continue;
// Now fill the matrices
for (auto [var_id, lag, param_id, constant] : error_linear_combination)
{
auto [orig_vid, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag);
if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end())
{
if (orig_lag != -1)
{
cerr << "ERROR in trend component model: variables in the error correction term should appear with a lag of -1" << endl;
exit(EXIT_FAILURE);
}
// This is a non-target LHS variable, so fill A0
if (constant != 1)
{
cerr << "ERROR in trend component model: LHS variable should not appear with a multiplicative constant in error correction term" << endl;
exit(EXIT_FAILURE);
}
if (param_id != -1)
{
cerr << "ERROR in trend component model: spurious parameter in error correction term" << endl;
exit(EXIT_FAILURE);
}
// The not_ec check above guarantees that orig_vid is present in nontarget_lhs
int colidx = static_cast(distance(nontarget_lhs.begin(), find(nontarget_lhs.begin(), nontarget_lhs.end(), orig_vid)));
if (A0.contains({eqn, colidx}))
{
cerr << "ExprNode::fillErrorCorrection: Error filling A0 matrix: "
<< "symb_id encountered more than once in equation" << endl;
exit(EXIT_FAILURE);
}
A0[{eqn, colidx}] = datatree.AddVariable(speed_of_adjustment_param);
}
else
{
// This is a target, so fill A0star
int colidx = static_cast(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid)));
expr_t e = datatree.AddTimes(datatree.AddVariable(speed_of_adjustment_param),
datatree.AddPossiblyNegativeConstant(-constant));
if (param_id != -1)
e = datatree.AddTimes(e, datatree.AddVariable(param_id));
// Several terms may contribute to the same A0star entry: accumulate them
if (auto coor = make_pair(eqn, colidx); A0star.contains(coor))
A0star[coor] = datatree.AddPlus(e, A0star[coor]);
else
A0star[coor] = e;
}
}
}
}
/* Default for node types that cannot appear in a matched moment
   expression: always throws MatchFailureException, leaving the output
   vectors untouched. */
void
ExprNode::matchMatchedMoment(vector &symb_ids, vector &lags, vector &powers) const
{
throw MatchFailureException{"Unsupported expression"};
}
bool
ExprNode::isConstant() const
{
set> symbs_lags;
collectDynamicVariables(SymbolType::endogenous, symbs_lags);
collectDynamicVariables(SymbolType::exogenous, symbs_lags);
collectDynamicVariables(SymbolType::exogenousDet, symbs_lags);
return symbs_lags.empty();
}
bool
ExprNode::hasExogenous() const
{
set> symbs_lags;
collectDynamicVariables(SymbolType::exogenous, symbs_lags);
collectDynamicVariables(SymbolType::exogenousDet, symbs_lags);
return !symbs_lags.empty();
}
// `id_arg` is the index of the constant in the datatree's num_constants table
NumConstNode::NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg) :
  ExprNode{datatree_arg, idx_arg},
  id{id_arg}
{
}

// A constant contains no diff() operator
int
NumConstNode::countDiffs() const
{
  constexpr int no_diffs = 0;
  return no_diffs;
}

/* The derivative of a constant vanishes with respect to anything, so
   non_null_derivatives is deliberately left empty; only the "prepared"
   flag is set. */
void
NumConstNode::prepareForDerivation()
{
  preparedForDerivation = true;
}

// d(constant)/dx = 0, whatever x
expr_t
NumConstNode::computeDerivative(int deriv_id)
{
  expr_t zero = datatree.Zero;
  return zero;
}
/* Writes the constant: either as a reference to a temporary term (when it
   is one) or as the textual value stored in the datatree. */
void
NumConstNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                          const temporary_terms_t &temporary_terms,
                          const temporary_terms_idxs_t &temporary_terms_idxs,
                          const deriv_node_temp_terms_t &tef_terms) const
{
  if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
    return;
  output << datatree.num_constants.get(id);
}
/* Writes the constant as a JSON AST node: {"node_type" : "NumConstNode", "value" : <number>}.
   The value comes from the double stored in the datatree (not through
   std::stof, which truncated it to single precision, printed only 6
   significant digits, and threw std::out_of_range outside the float range).
   It is printed with numeric_limits<double>::digits10 significant digits,
   so a decimal literal of up to 15 significant digits reappears unchanged.
   The caller's stream precision is restored afterwards. */
void
NumConstNode::writeJsonAST(ostream &output) const
{
  auto saved_precision = output.precision(numeric_limits<double>::digits10);
  output << R"({"node_type" : "NumConstNode", "value" : )";
  output << datatree.num_constants.getDouble(id) << "}";
  output.precision(saved_precision);
}
/* JSON output of a constant: the temporary terms are not consulted, the
   textual value stored in the datatree is printed as is. */
void
NumConstNode::writeJsonOutput(ostream &output,
                              const temporary_terms_t &temporary_terms,
                              const deriv_node_temp_terms_t &tef_terms,
                              bool isdynamic) const
{
  const auto &text = datatree.num_constants.get(id);
  output << text;
}

// A numerical constant is a leaf, hence cannot contain an external function call
bool
NumConstNode::containsExternalFunction() const
{
  constexpr bool has_external_function = false;
  return has_external_function;
}

// The evaluation context is irrelevant: return the stored numerical value
double
NumConstNode::eval(const eval_context_t &eval_context) const noexcept(false)
{
  const double value = datatree.num_constants.getDouble(id);
  return value;
}

// Bytecode: a single FLDC_ instruction carrying the numerical value of the constant
void
NumConstNode::compile(ostream &CompileCode, unsigned int &instruction_number,
                      bool lhs_rhs, const temporary_terms_t &temporary_terms,
                      const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic,
                      const deriv_node_temp_terms_t &tef_terms) const
{
  const double value = datatree.num_constants.getDouble(id);
  FLDC_ load_constant(value);
  load_constant.write(CompileCode, instruction_number);
}
// A constant cannot be the LHS of a VAR equation: print an error and exit
void
NumConstNode::collectVARLHSVariable(set &result) const
{
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
}
// A constant contains no variable: `result` is left unchanged
void
NumConstNode::collectDynamicVariables(SymbolType type_arg, set> &result) const
{
}
// A constant never contains the given variable: `contain_var` is left unchanged
void
NumConstNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const
{
}
// Normalization never descends into a constant: reaching this point is an internal error
BinaryOpNode *
NumConstNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const
{
cerr << "NumConstNode::normalizeEquation: this should not happen" << endl;
exit(EXIT_FAILURE);
}
// Chain-rule derivative of a constant is zero, whatever the recursive variables
expr_t
NumConstNode::getChainRuleDerivative(int deriv_id, const map &recursive_variables)
{
return datatree.Zero;
}
// Recreates the same constant (looked up in this node's datatree) inside static_datatree
expr_t
NumConstNode::toStatic(DataTree &static_datatree) const
{
return static_datatree.AddNonNegativeConstant(datatree.num_constants.get(id));
}
// A constant references no symbol: the cross-reference info `ei` is left unchanged
void
NumConstNode::computeXrefs(EquationInfo &ei) const
{
}
/* Recreates this constant inside another DataTree.
   `id` indexes the constant table of *this* node's datatree, so the textual
   value must be fetched through the member `datatree` (as toStatic() does),
   not through the target tree. The parameter is therefore named
   `alt_datatree`, so that it no longer shadows the member. */
expr_t
NumConstNode::clone(DataTree &alt_datatree) const
{
  return alt_datatree.AddNonNegativeConstant(datatree.num_constants.get(id));
}
int
NumConstNode::maxEndoLead() const
{
return 0;
}
int
NumConstNode::maxExoLead() const
{
return 0;
}
int
NumConstNode::maxEndoLag() const
{
return 0;
}
int
NumConstNode::maxExoLag() const
{
return 0;
}
int
NumConstNode::maxLead() const
{
return numeric_limits::min();
}
int
NumConstNode::maxLag() const
{
return numeric_limits::min();
}
int
NumConstNode::maxLagWithDiffsExpanded() const
{
return numeric_limits::min();
}
/* A numerical constant is invariant under all the transformations below
   (no variable, no diff, no lead/lag, no expectation operator): each
   returns the node itself and leaves any substitution table or list of new
   equations untouched. */
// No diff() operator to remove
expr_t
NumConstNode::undiff() const
{
return const_cast(this);
}
// No variable, hence a maximum lag of 0 whatever lhs_lag_equiv contains
int
NumConstNode::VarMaxLag(const set &lhs_lag_equiv) const
{
return 0;
}
// Shifting leads/lags by n periods leaves a constant unchanged
expr_t
NumConstNode::decreaseLeadsLags(int n) const
{
return const_cast(this);
}
// No predetermined variable to shift
expr_t
NumConstNode::decreaseLeadsLagsPredeterminedVariables() const
{
return const_cast(this);
}
// No endogenous lead to substitute; neweqs is not modified
expr_t
NumConstNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs, bool deterministic_model) const
{
return const_cast(this);
}
// No endogenous lag to substitute; neweqs is not modified
expr_t
NumConstNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const
{
return const_cast(this);
}
// No exogenous lead to substitute; neweqs is not modified
expr_t
NumConstNode::substituteExoLead(subst_table_t &subst_table, vector &neweqs, bool deterministic_model) const
{
return const_cast(this);
}
// No exogenous lag to substitute; neweqs is not modified
expr_t
NumConstNode::substituteExoLag(subst_table_t &subst_table, vector &neweqs) const
{
return const_cast(this);
}
// No expectation operator to substitute; neweqs is not modified
expr_t
NumConstNode::substituteExpectation(subst_table_t &subst_table, vector &neweqs, bool partial_information_model) const
{
return const_cast(this);
}
// No adl() operator to expand
expr_t
NumConstNode::substituteAdl() const
{
return const_cast(this);
}
// No model-local variable to inline
expr_t
NumConstNode::substituteModelLocalVariables() const
{
return const_cast(this);
}
// No var_expectation operator to substitute
expr_t
NumConstNode::substituteVarExpectation(const map &subst_table) const
{
return const_cast(this);
}
// A constant contains no diff node: `nodes` is left unchanged
void
NumConstNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
}
// A constant contains no unary op node needing an auxiliary variable: `nodes` is left unchanged
void
NumConstNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
}
// A constant cannot designate a target variable
optional
NumConstNode::findTargetVariable(int lhs_symb_id) const
{
return nullopt;
}
// Nothing to substitute: returns the node itself, subst_table and neweqs untouched
expr_t
NumConstNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const
{
return const_cast(this);
}
// Nothing to substitute: returns the node itself, subst_table and neweqs untouched
expr_t
NumConstNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const
{
return const_cast(this);
}
// No pac_expectation operator inside a constant
expr_t
NumConstNode::substitutePacExpectation(const string &name, expr_t subexpr)
{
return const_cast(this);
}
// No pac_target_nonstationary operator inside a constant
expr_t
NumConstNode::substitutePacTargetNonstationary(const string &name, expr_t subexpr)
{
return const_cast(this);
}
// No forward variable to differentiate: returns the node itself
expr_t
NumConstNode::differentiateForwardVars(const vector &subset, subst_table_t &subst_table, vector &neweqs) const
{
return const_cast(this);
}
/* Tells whether this constant's stored numerical value compares equal
   (exact floating-point ==) to `value`. */
bool
NumConstNode::isNumConstNodeEqualTo(double value) const
{
  const double stored = datatree.num_constants.getDouble(id);
  return stored == value;
}
// A constant is never equal to a variable node
bool
NumConstNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{
return false;
}
// A constant never contains a pac_expectation operator
bool
NumConstNode::containsPacExpectation(const string &pac_model_name) const
{
return false;
}
// A constant never contains a pac_target_nonstationary operator
bool
NumConstNode::containsPacTargetNonstationary(const string &pac_model_name) const
{
return false;
}
// No trend variable to replace: returns the node itself
expr_t
NumConstNode::replaceTrendVar() const
{
return const_cast(this);
}
// Detrending leaves a constant unchanged
expr_t
NumConstNode::detrend(int symb_id, bool log_trend, expr_t trend) const
{
return const_cast(this);
}
// No trend lead/lag to remove
expr_t
NumConstNode::removeTrendLeadLag(const map &trend_symbols_map) const
{
return const_cast(this);
}
// A constant has no lead/lag, so it is always in static form
bool
NumConstNode::isInStaticForm() const
{
return true;
}
// A constant alone is not a "parameter × endogenous" expression
bool
NumConstNode::isParamTimesEndogExpr() const
{
return false;
}
// No auxiliary variable to substitute
expr_t
NumConstNode::substituteStaticAuxiliaryVariable() const
{
return const_cast(this);
}
// No variable to replace; `table` is not consulted
expr_t
NumConstNode::replaceVarsInEquation(map &table) const
{
return const_cast(this);
}
// No variable to log-transform
expr_t
NumConstNode::substituteLogTransform(int orig_symb_id, int aux_symb_id) const
{
return const_cast(this);
}
/* Builds a node standing for symbol `symb_id_arg` at lead/lag `lag_arg`.
   Sanity checks (debug builds only): the symbol must not be an external
   function, and local variables (model or mod-file) cannot carry a
   lead/lag. Leads/lags are accepted on parameters on purpose: during
   steady-state calibration, endogenous variables and parameters can be
   swapped. */
VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) :
  ExprNode{datatree_arg, idx_arg},
  symb_id{symb_id_arg},
  lag{lag_arg}
{
  assert(get_type() != SymbolType::externalFunction
         && (lag == 0 || (get_type() != SymbolType::modelLocalVariable && get_type() != SymbolType::modFileLocalVariable)));
}
/* Computes (once) the set of derivation IDs with respect to which this
   variable may have a non-zero derivative. */
void
VariableNode::prepareForDerivation()
{
  if (preparedForDerivation)
    return;
  preparedForDerivation = true;

  switch (get_type())
    {
    case SymbolType::endogenous:
    case SymbolType::exogenous:
    case SymbolType::exogenousDet:
    case SymbolType::parameter:
    case SymbolType::trend:
    case SymbolType::logTrend:
      // Only the derivative with respect to the variable itself (at this lag) can be non-zero
      non_null_derivatives.insert(datatree.getDerivID(symb_id, lag));
      return;
    case SymbolType::modelLocalVariable:
      {
        // Inherit the non-null derivatives of the local variable's defining expression
        auto definition = datatree.getLocalVariable(symb_id);
        definition->prepareForDerivation();
        non_null_derivatives = definition->non_null_derivatives;
      }
      return;
    case SymbolType::modFileLocalVariable:
    case SymbolType::statementDeclaredVariable:
    case SymbolType::unusedEndogenous:
      // Never differentiated: leave the set empty
      return;
    case SymbolType::excludedVariable:
      cerr << "VariableNode::prepareForDerivation: impossible case: "
           << "You are trying to derive a variable that has been excluded via model_remove/var_remove/include_eqs/exclude_eqs: "
           << datatree.symbol_table.getName(symb_id) << endl;
      exit(EXIT_FAILURE);
    case SymbolType::externalFunction:
    case SymbolType::epilogue:
      cerr << "VariableNode::prepareForDerivation: impossible case" << endl;
      exit(EXIT_FAILURE);
    }
}
/* Derivative of the variable with respect to deriv_id: 1 for itself,
   0 otherwise; model local variables delegate to their defining expression.
   Symbol types that cannot be differentiated abort with an error. */
expr_t
VariableNode::computeDerivative(int deriv_id)
{
  switch (get_type())
    {
    case SymbolType::endogenous:
    case SymbolType::exogenous:
    case SymbolType::exogenousDet:
    case SymbolType::parameter:
    case SymbolType::trend:
    case SymbolType::logTrend:
      return deriv_id == datatree.getDerivID(symb_id, lag) ? datatree.One : datatree.Zero;
    case SymbolType::modelLocalVariable:
      return datatree.getLocalVariable(symb_id)->getDerivative(deriv_id);
    case SymbolType::modFileLocalVariable:
      cerr << "modFileLocalVariable is not derivable" << endl;
      exit(EXIT_FAILURE);
    case SymbolType::statementDeclaredVariable:
      cerr << "statementDeclaredVariable is not derivable" << endl;
      exit(EXIT_FAILURE);
    case SymbolType::unusedEndogenous:
      cerr << "unusedEndogenous is not derivable" << endl;
      exit(EXIT_FAILURE);
    case SymbolType::externalFunction:
    case SymbolType::epilogue:
    case SymbolType::excludedVariable:
      cerr << "VariableNode::computeDerivative: Impossible case!" << endl;
      exit(EXIT_FAILURE);
    }
  // Unreachable; keeps GCC from warning about a missing return
  exit(EXIT_FAILURE);
}
/* A variable node contains an external function call only if it is a model
   local variable whose defining expression contains one. */
bool
VariableNode::containsExternalFunction() const
{
  return get_type() == SymbolType::modelLocalVariable
         && datatree.getLocalVariable(symb_id)->containsExternalFunction();
}
/* Writes the variable as a JSON AST node:
   {"node_type" : "VariableNode", "name" : "<name>", "type" : "<symbol type>", "lag" : <lag>}
   The "type" string is the SymbolType enumerator name. */
void
VariableNode::writeJsonAST(ostream &output) const
{
  output << R"({"node_type" : "VariableNode", )"
         << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "type" : ")";
  switch (get_type())
    {
    case SymbolType::endogenous:
      output << "endogenous";
      break;
    case SymbolType::exogenous:
      output << "exogenous";
      break;
    case SymbolType::exogenousDet:
      output << "exogenousDet";
      break;
    case SymbolType::parameter:
      output << "parameter";
      break;
    case SymbolType::modelLocalVariable:
      output << "modelLocalVariable";
      break;
    case SymbolType::modFileLocalVariable:
      output << "modFileLocalVariable";
      break;
    case SymbolType::externalFunction:
      output << "externalFunction";
      break;
    case SymbolType::trend:
      output << "trend";
      break;
    case SymbolType::statementDeclaredVariable:
      output << "statementDeclaredVariable";
      break;
    case SymbolType::logTrend:
      // Plain enumerator name, like every other case (no trailing colon)
      output << "logTrend";
      break;
    case SymbolType::unusedEndogenous:
      output << "unusedEndogenous";
      break;
    case SymbolType::epilogue:
      output << "epilogue";
      break;
    case SymbolType::excludedVariable:
      cerr << "VariableNode::writeJsonAST: Impossible case!" << endl;
      exit(EXIT_FAILURE);
    }
  output << R"(", "lag" : )" << lag << "}";
}
/* JSON output of a variable: a temporary term is referred to as T<idx>;
   otherwise the symbol name is written, followed in dynamic context by
   "(lag)" when the lag is non-zero. */
void
VariableNode::writeJsonOutput(ostream &output,
                              const temporary_terms_t &temporary_terms,
                              const deriv_node_temp_terms_t &tef_terms,
                              bool isdynamic) const
{
  if (temporary_terms.contains(const_cast<VariableNode *>(this)))
    {
      output << "T" << idx;
      return;
    }

  output << datatree.symbol_table.getName(symb_id);
  if (!isdynamic || lag == 0)
    return;
  output << "(" << lag << ")";
}
/* Writes the variable in the syntax of `output_type`.
   - A temporary term is written as a reference to it.
   - LaTeX outputs write the TeX name, with a _{t±lag} subscript for
     time-indexed symbols in latexDynamicModel.
   - Otherwise the representation depends on the symbol type and output type
     (parameter vector, y/x arrays, steady-state arrays, dseries, Julia time
     data frames, epilogue files, ...). Unsupported combinations print an
     error and exit with EXIT_FAILURE.
   In epilogue files, the time index is written as t, t-k or t+k. Before
   this fix a positive lead was written as "t1" instead of "t+1"; the three
   dseries-like writers are now shared helper lambdas. */
void
VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                          const temporary_terms_t &temporary_terms,
                          const temporary_terms_idxs_t &temporary_terms_idxs,
                          const deriv_node_temp_terms_t &tef_terms) const
{
  auto type = get_type();
  if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
    return;

  if (isLatexOutput(output_type))
    {
      if (output_type == ExprNodeOutputType::latexDynamicSteadyStateOperator)
        output << R"(\bar)";
      output << "{" << datatree.symbol_table.getTeXName(symb_id) << "}";
      if (output_type == ExprNodeOutputType::latexDynamicModel
          && (type == SymbolType::endogenous || type == SymbolType::exogenous || type == SymbolType::exogenousDet || type == SymbolType::trend || type == SymbolType::logTrend))
        {
          output << "_{t";
          if (lag != 0)
            {
              if (lag > 0)
                output << "+";
              output << lag;
            }
          output << "}";
        }
      return;
    }

  // Julia time data frame: ds.name, wrapped in lag(…) when the lag is non-zero
  auto juliaTimeDataFrameHelper = [&]()
  {
    if (lag != 0)
      output << "lag(";
    output << "ds." << datatree.symbol_table.getName(symb_id);
    if (lag != 0)
      {
        if (lag != -1)
          output << "," << -lag;
        output << ")";
      }
  };

  // MATLAB dseries: ds.name, followed by the lag in subscript syntax when non-zero
  auto matlabDseriesHelper = [&]()
  {
    output << "ds." << datatree.symbol_table.getName(symb_id);
    if (lag != 0)
      output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
  };

  // Epilogue file: ds.name indexed by t, t-k or t+k
  auto epilogueFileHelper = [&]()
  {
    output << "ds." << datatree.symbol_table.getName(symb_id);
    output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
    if (lag > 0)
      output << "+"; // A negative lag already carries its own minus sign
    if (lag != 0)
      output << lag;
    output << RIGHT_ARRAY_SUBSCRIPT(output_type);
  };

  int i;
  switch (type)
    {
    case SymbolType::parameter:
      if (int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
          output_type == ExprNodeOutputType::matlabOutsideModel)
        output << "M_.params" << "(" << tsid + 1 << ")";
      else
        output << "params" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + ARRAY_SUBSCRIPT_OFFSET(output_type) << RIGHT_ARRAY_SUBSCRIPT(output_type);
      break;

    case SymbolType::modelLocalVariable:
      if (output_type == ExprNodeOutputType::matlabDynamicSteadyStateOperator
          || output_type == ExprNodeOutputType::CDynamicSteadyStateOperator)
        {
          output << "(";
          datatree.getLocalVariable(symb_id)->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << ")";
        }
      else
        /* We append underscores to avoid name clashes with "g1" or "oo_".
           But we probably never arrive here because MLV are temporary terms… */
        output << datatree.symbol_table.getName(symb_id) << "__";
      break;

    case SymbolType::modFileLocalVariable:
      output << datatree.symbol_table.getName(symb_id);
      break;

    case SymbolType::endogenous:
      switch (int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
              output_type)
        {
        case ExprNodeOutputType::juliaDynamicModel:
        case ExprNodeOutputType::matlabDynamicModel:
        case ExprNodeOutputType::CDynamicModel:
          i = datatree.getDynJacobianCol(datatree.getDerivID(symb_id, lag)) + ARRAY_SUBSCRIPT_OFFSET(output_type);
          output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::CStaticModel:
        case ExprNodeOutputType::juliaStaticModel:
        case ExprNodeOutputType::matlabStaticModel:
          i = tsid + ARRAY_SUBSCRIPT_OFFSET(output_type);
          output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabOutsideModel:
          output << "oo_.steady_state(" << tsid + 1 << ")";
          break;
        case ExprNodeOutputType::juliaDynamicSteadyStateOperator:
        case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
          output << "steady_state" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + 1 << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::CDynamicSteadyStateOperator:
          output << "steady_state[" << tsid << "]";
          break;
        case ExprNodeOutputType::juliaSteadyStateFile:
        case ExprNodeOutputType::steadyStateFile:
          output << "ys_" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + 1 << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabDseries:
          matlabDseriesHelper();
          break;
        case ExprNodeOutputType::juliaTimeDataFrame:
          juliaTimeDataFrameHelper();
          break;
        case ExprNodeOutputType::epilogueFile:
          epilogueFileHelper();
          break;
        case ExprNodeOutputType::occbinDifferenceFile:
          output << "zdatalinear(:," << tsid + 1 << ")";
          break;
        default:
          cerr << "VariableNode::writeOutput: should not reach this point" << endl;
          exit(EXIT_FAILURE);
        }
      break;

    case SymbolType::exogenous:
      i = datatree.symbol_table.getTypeSpecificID(symb_id) + ARRAY_SUBSCRIPT_OFFSET(output_type);
      switch (output_type)
        {
        case ExprNodeOutputType::juliaDynamicModel:
        case ExprNodeOutputType::matlabDynamicModel:
          if (lag > 0)
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else if (lag < 0)
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::CDynamicModel:
          if (lag == 0)
            output << "x[it_+" << i << "*nb_row_x]";
          else if (lag > 0)
            output << "x[it_+" << lag << "+" << i << "*nb_row_x]";
          else
            output << "x[it_" << lag << "+" << i << "*nb_row_x]";
          break;
        case ExprNodeOutputType::CStaticModel:
        case ExprNodeOutputType::juliaStaticModel:
        case ExprNodeOutputType::matlabStaticModel:
          output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabOutsideModel:
          assert(lag == 0);
          output << "oo_.exo_steady_state(" << i << ")";
          break;
        case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
          output << "oo_.exo_steady_state(" << i << ")";
          break;
        case ExprNodeOutputType::juliaSteadyStateFile:
        case ExprNodeOutputType::steadyStateFile:
          output << "exo_" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabDseries:
          matlabDseriesHelper();
          break;
        case ExprNodeOutputType::juliaTimeDataFrame:
          juliaTimeDataFrameHelper();
          break;
        case ExprNodeOutputType::epilogueFile:
          epilogueFileHelper();
          break;
        default:
          cerr << "VariableNode::writeOutput: should not reach this point" << endl;
          exit(EXIT_FAILURE);
        }
      break;

    case SymbolType::exogenousDet:
      // Deterministic exogenous come after the exo_nbr() stochastic ones in the x array
      i = datatree.symbol_table.getTypeSpecificID(symb_id) + datatree.symbol_table.exo_nbr() + ARRAY_SUBSCRIPT_OFFSET(output_type);
      switch (output_type)
        {
        case ExprNodeOutputType::juliaDynamicModel:
        case ExprNodeOutputType::matlabDynamicModel:
          if (lag > 0)
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else if (lag < 0)
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::CDynamicModel:
          if (lag == 0)
            output << "x[it_+" << i << "*nb_row_x]";
          else if (lag > 0)
            output << "x[it_+" << lag << "+" << i << "*nb_row_x]";
          else
            output << "x[it_" << lag << "+" << i << "*nb_row_x]";
          break;
        case ExprNodeOutputType::CStaticModel:
        case ExprNodeOutputType::juliaStaticModel:
        case ExprNodeOutputType::matlabStaticModel:
          output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabOutsideModel:
          assert(lag == 0);
          output << "oo_.exo_det_steady_state(" << datatree.symbol_table.getTypeSpecificID(symb_id) + 1 << ")";
          break;
        case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
          output << "oo_.exo_det_steady_state(" << datatree.symbol_table.getTypeSpecificID(symb_id) + 1 << ")";
          break;
        case ExprNodeOutputType::juliaSteadyStateFile:
        case ExprNodeOutputType::steadyStateFile:
          output << "exo_" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabDseries:
          matlabDseriesHelper();
          break;
        case ExprNodeOutputType::juliaTimeDataFrame:
          juliaTimeDataFrameHelper();
          break;
        case ExprNodeOutputType::epilogueFile:
          epilogueFileHelper();
          break;
        default:
          cerr << "VariableNode::writeOutput: should not reach this point" << endl;
          exit(EXIT_FAILURE);
        }
      break;

    case SymbolType::epilogue:
      if (output_type == ExprNodeOutputType::epilogueFile)
        epilogueFileHelper();
      else if (output_type == ExprNodeOutputType::matlabDseries
               || output_type == ExprNodeOutputType::juliaTimeDataFrame)
        // Only writing dseries for epilogue_static, hence no need to check lag
        output << "ds." << datatree.symbol_table.getName(symb_id);
      else
        {
          cerr << "VariableNode::writeOutput: Impossible case" << endl;
          exit(EXIT_FAILURE);
        }
      break;

    case SymbolType::unusedEndogenous:
      cerr << "ERROR: You cannot use an endogenous variable in an expression if that variable has not been used in the model block." << endl;
      exit(EXIT_FAILURE);

    case SymbolType::externalFunction:
    case SymbolType::trend:
    case SymbolType::logTrend:
    case SymbolType::statementDeclaredVariable:
    case SymbolType::excludedVariable:
      cerr << "VariableNode::writeOutput: Impossible case" << endl;
      exit(EXIT_FAILURE);
    }
}
expr_t
VariableNode::substituteStaticAuxiliaryVariable() const
{
  // Only endogenous symbols are looked up as auxiliary variables
  if (get_type() != SymbolType::endogenous)
    return const_cast(this);

  try
    {
      // Replace by the stored expression, itself recursively substituted
      return datatree.symbol_table.getAuxiliaryVarsExprNode(symb_id)->substituteStaticAuxiliaryVariable();
    }
  catch (SymbolTable::SearchFailedException &)
    {
      /* Lookup failed (presumably not an auxiliary variable with a stored
         expression): deliberately fall through and keep the variable */
    }
  return const_cast(this);
}
double
VariableNode::eval(const eval_context_t &eval_context) const noexcept(false)
{
  // Model-local variables are evaluated through their defining expression
  if (get_type() == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->eval(eval_context);

  if (auto found = eval_context.find(symb_id); found != eval_context.end())
    return found->second;

  // No value for this symbol in the evaluation context
  throw EvalException();
}
void
VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number,
                      bool lhs_rhs, const temporary_terms_t &temporary_terms,
                      const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic,
                      const deriv_node_temp_terms_t &tef_terms) const
{
  auto type = get_type();

  // Local variables are compiled by compiling their definition in place
  if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
    {
      datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms);
      return;
    }

  int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
  // Deterministic exogenous are numbered after the exogenous
  if (type == SymbolType::exogenousDet)
    tsid += datatree.symbol_table.exo_nbr();

  // Static model: FSTPSV_ when lhs_rhs is set, FLDSV_ otherwise (no lag)
  if (!dynamic)
    {
      if (lhs_rhs)
        {
          FSTPSV_ fstpsv{static_cast(type), static_cast(tsid)};
          fstpsv.write(CompileCode, instruction_number);
        }
      else
        {
          FLDSV_ fldsv{static_cast(type), static_cast(tsid)};
          fldsv.write(CompileCode, instruction_number);
        }
      return;
    }

  // Steady-state value inside a dynamic model: only FLDVS_ is supported
  if (steady_dynamic)
    {
      if (lhs_rhs)
        {
          cerr << "Impossible case: steady_state in rhs of equation" << endl;
          exit(EXIT_FAILURE);
        }
      FLDVS_ fldvs{static_cast(type), static_cast(tsid)};
      fldvs.write(CompileCode, instruction_number);
      return;
    }

  // Dynamic model: parameters are emitted without lag, other symbols with it
  if (lhs_rhs)
    {
      if (type == SymbolType::parameter)
        {
          FSTPV_ fstpv{static_cast(type), static_cast(tsid)};
          fstpv.write(CompileCode, instruction_number);
        }
      else
        {
          FSTPV_ fstpv{static_cast(type), static_cast(tsid), lag};
          fstpv.write(CompileCode, instruction_number);
        }
    }
  else
    {
      if (type == SymbolType::parameter)
        {
          FLDV_ fldv{static_cast(type), static_cast(tsid)};
          fldv.write(CompileCode, instruction_number);
        }
      else
        {
          FLDV_ fldv{static_cast(type), static_cast(tsid), lag};
          fldv.write(CompileCode, instruction_number);
        }
    }
}
void
VariableNode::collectVARLHSVariable(set &result) const
{
  // Only contemporaneous endogenous variables are accepted on a VAR LHS
  if (get_type() != SymbolType::endogenous || lag != 0)
    {
      cerr << "ERROR: you can only have endogenous variables or unary ops on LHS of VAR" << endl;
      exit(EXIT_FAILURE);
    }
  result.insert(const_cast(this));
}
void
VariableNode::collectDynamicVariables(SymbolType type_arg, set> &result) const
{
  SymbolType my_type = get_type();
  // Record (symbol, lag) if this variable has the requested type
  if (my_type == type_arg)
    result.emplace(symb_id, lag);
  // Also collect from the definition of a model-local variable
  if (my_type == SymbolType::modelLocalVariable)
    datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
}
void
VariableNode::computeSubExprContainingVariable(int symb_id_arg, int lag_arg, set &contain_var) const
{
  // This node is the searched variable iff symbol and lag both match
  bool is_target = symb_id_arg == symb_id && lag_arg == lag;
  if (is_target)
    contain_var.insert(const_cast(this));
  // Model-local variable: search inside its definition as well
  if (get_type() == SymbolType::modelLocalVariable)
    datatree.getLocalVariable(symb_id)->computeSubExprContainingVariable(symb_id_arg, lag_arg, contain_var);
}
BinaryOpNode *
VariableNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const
{
  // Caller guarantees this node contains the variable being normalized
  assert(contain_var.contains(const_cast(this)));
  if (get_type() != SymbolType::modelLocalVariable)
    // Reached the normalized variable itself: build "variable = rhs"
    return datatree.AddEqual(const_cast(this), rhs);
  // Continue the normalization inside the model-local variable definition
  return datatree.getLocalVariable(symb_id)->normalizeEquationHelper(contain_var, rhs);
}
expr_t
VariableNode::getChainRuleDerivative(int deriv_id, const map &recursive_variables)
{
  switch (get_type())
    {
    case SymbolType::endogenous:
    case SymbolType::exogenous:
    case SymbolType::exogenousDet:
    case SymbolType::parameter:
    case SymbolType::trend:
    case SymbolType::logTrend:
      {
        // Derivation ID of this variable (computed once, used twice below)
        int own_deriv_id = datatree.getDerivID(symb_id, lag);
        if (own_deriv_id == deriv_id)
          return datatree.One;
        /* Variable defined recursively: chain rule through arg2 of the stored
           equation (presumably its right-hand side) */
        if (auto rec = recursive_variables.find(own_deriv_id); rec != recursive_variables.end())
          return rec->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables);
        return datatree.Zero;
      }
    case SymbolType::modelLocalVariable:
      return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
    case SymbolType::modFileLocalVariable:
      cerr << "modFileLocalVariable is not derivable" << endl;
      exit(EXIT_FAILURE);
    case SymbolType::statementDeclaredVariable:
      cerr << "statementDeclaredVariable is not derivable" << endl;
      exit(EXIT_FAILURE);
    case SymbolType::unusedEndogenous:
      cerr << "unusedEndogenous is not derivable" << endl;
      exit(EXIT_FAILURE);
    case SymbolType::externalFunction:
    case SymbolType::epilogue:
    case SymbolType::excludedVariable:
      cerr << "VariableNode::getChainRuleDerivative: Impossible case" << endl;
      exit(EXIT_FAILURE);
    }
  // Unreachable; present to suppress a GCC "control reaches end" warning
  exit(EXIT_FAILURE);
}
expr_t
VariableNode::toStatic(DataTree &static_datatree) const
{
  // Static counterpart: the same symbol added to the static tree, lag not passed
  expr_t static_var = static_datatree.AddVariable(symb_id);
  return static_var;
}
void
VariableNode::computeXrefs(EquationInfo &ei) const
{
  switch (get_type())
    {
    case SymbolType::parameter:
      // Parameters are recorded with a lag of 0
      ei.param.emplace(symb_id, 0);
      break;
    case SymbolType::endogenous:
      ei.endo.emplace(symb_id, lag);
      break;
    case SymbolType::exogenousDet:
      ei.exo_det.emplace(symb_id, lag);
      break;
    case SymbolType::exogenous:
      ei.exo.emplace(symb_id, lag);
      break;
    case SymbolType::modFileLocalVariable:
      /* NOTE(review): modFileLocalVariable is recursed into here, whereas
         modelLocalVariable (grouped with the ignored types below) is not.
         Confirm that this asymmetry is intended. */
      datatree.getLocalVariable(symb_id)->computeXrefs(ei);
      break;
    case SymbolType::modelLocalVariable:
    case SymbolType::trend:
    case SymbolType::logTrend:
    case SymbolType::statementDeclaredVariable:
    case SymbolType::unusedEndogenous:
    case SymbolType::externalFunction:
    case SymbolType::epilogue:
    case SymbolType::excludedVariable:
      // Not recorded in cross-references
      break;
    }
}
SymbolType
VariableNode::get_type() const
{
  // Not stored in the node: looked up in the symbol table on every call
  const SymbolType type = datatree.symbol_table.getType(symb_id);
  return type;
}
expr_t
VariableNode::clone(DataTree &datatree) const
{
  /* Same symbol and lag, added to the target tree. NB: the parameter
     "datatree" shadows the member of the same name. */
  expr_t copy = datatree.AddVariable(symb_id, lag);
  return copy;
}
int
VariableNode::maxEndoLead() const
{
  auto type = get_type();
  // An endogenous variable contributes its lead, if positive
  if (type == SymbolType::endogenous)
    return lag > 0 ? lag : 0;
  if (type == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->maxEndoLead();
  return 0;
}
int
VariableNode::maxExoLead() const
{
  auto type = get_type();
  // An exogenous variable contributes its lead, if positive
  if (type == SymbolType::exogenous)
    return lag > 0 ? lag : 0;
  if (type == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->maxExoLead();
  return 0;
}
int
VariableNode::maxEndoLag() const
{
  auto type = get_type();
  // An endogenous variable contributes its lag (as a positive number), if any
  if (type == SymbolType::endogenous)
    return lag < 0 ? -lag : 0;
  if (type == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->maxEndoLag();
  return 0;
}
int
VariableNode::maxExoLag() const
{
  auto type = get_type();
  // An exogenous variable contributes its lag (as a positive number), if any
  if (type == SymbolType::exogenous)
    return lag < 0 ? -lag : 0;
  if (type == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->maxExoLag();
  return 0;
}
int
VariableNode::maxLead() const
{
  auto type = get_type();
  if (type == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->maxLead();
  // Model variables report their signed lag (not clamped at zero)
  bool is_model_var = type == SymbolType::endogenous
    || type == SymbolType::exogenous
    || type == SymbolType::exogenousDet;
  return is_model_var ? lag : 0;
}
int
VariableNode::maxLag() const
{
  auto type = get_type();
  if (type == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->maxLag();
  // Model variables report their signed lag, negated (not clamped at zero)
  bool is_model_var = type == SymbolType::endogenous
    || type == SymbolType::exogenous
    || type == SymbolType::exogenousDet;
  return is_model_var ? -lag : 0;
}
int
VariableNode::maxLagWithDiffsExpanded() const
{
  auto type = get_type();
  if (type == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->maxLagWithDiffsExpanded();
  // Unlike maxLag(), epilogue variables are counted here too
  bool counts = type == SymbolType::endogenous
    || type == SymbolType::exogenous
    || type == SymbolType::exogenousDet
    || type == SymbolType::epilogue;
  return counts ? -lag : 0;
}
expr_t
VariableNode::undiff() const
{
  // Identity on variable nodes
  auto self = const_cast(this);
  return self;
}
int
VariableNode::VarMaxLag(const set &lhs_lag_equiv) const
{
  // Only variables lag-equivalent to one of the given LHS representatives count
  const auto [repr, lag_index] = getLagEquivalenceClass();
  if (!lhs_lag_equiv.contains(repr))
    return 0;
  return maxLag();
}
expr_t
VariableNode::substituteAdl() const
{
  /* Intentionally no recursion into model-local variable definitions: that is
     done at the DynamicModel method level (see the comment there) */
  auto self = const_cast(this);
  return self;
}
expr_t
VariableNode::substituteModelLocalVariables() const
{
  // Any other variable stays as is
  if (get_type() != SymbolType::modelLocalVariable)
    return const_cast(this);
  // Replace a model-local variable by its defining expression (not recursively)
  return datatree.getLocalVariable(symb_id);
}
expr_t
VariableNode::substituteVarExpectation(const map &subst_table) const
{
  // Only a model-local variable definition can hide a var_expectation
  if (get_type() != SymbolType::modelLocalVariable)
    return const_cast(this);
  return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table);
}
void
VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
  // A plain variable has no diff node; only search model-local variable definitions
  if (get_type() != SymbolType::modelLocalVariable)
    return;
  datatree.getLocalVariable(symb_id)->findDiffNodes(nodes);
}
void
VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
  // Only model-local variable definitions can contain unary op nodes
  if (get_type() != SymbolType::modelLocalVariable)
    return;
  datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes);
}
optional
VariableNode::findTargetVariable(int lhs_symb_id) const
{
  // A plain variable yields no target; only look inside model-local variable definitions
  if (get_type() != SymbolType::modelLocalVariable)
    return nullopt;
  return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id);
}
expr_t
VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
                             vector &neweqs) const
{
  // Non-MLV variables contain no diff() to substitute
  if (get_type() != SymbolType::modelLocalVariable)
    return const_cast(this);
  return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs);
}
expr_t
VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const
{
  // Non-MLV variables contain no unary op node to substitute
  if (get_type() != SymbolType::modelLocalVariable)
    return const_cast(this);
  return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs);
}
expr_t
VariableNode::substitutePacExpectation(const string &name, expr_t subexpr)
{
  // Only a model-local variable definition can contain a pac_expectation
  if (get_type() != SymbolType::modelLocalVariable)
    return const_cast(this);
  return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr);
}
expr_t
VariableNode::substitutePacTargetNonstationary(const string &name, expr_t subexpr)
{
  // Only a model-local variable definition can contain the node being replaced
  if (get_type() != SymbolType::modelLocalVariable)
    return const_cast(this);
  return datatree.getLocalVariable(symb_id)->substitutePacTargetNonstationary(name, subexpr);
}
expr_t
VariableNode::decreaseLeadsLags(int n) const
{
  const SymbolType type = get_type();
  if (type == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->decreaseLeadsLags(n);

  // Only time-indexed symbols are shifted (by -n periods)
  bool shiftable = type == SymbolType::endogenous
    || type == SymbolType::exogenous
    || type == SymbolType::exogenousDet
    || type == SymbolType::trend
    || type == SymbolType::logTrend;
  if (shiftable)
    return datatree.AddVariable(symb_id, lag-n);
  return const_cast(this);
}
expr_t
VariableNode::decreaseLeadsLagsPredeterminedVariables() const
{
  /* No recursion into model-local variable definitions: MLVs are already
     handled by DynamicModel::transformPredeterminedVariables().
     This is also necessary because of #65. */
  if (!datatree.symbol_table.isPredetermined(symb_id))
    return const_cast(this);
  // Predetermined variable: shift back by one period
  return decreaseLeadsLags(1);
}
expr_t
VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs, bool deterministic_model) const
{
  const SymbolType type = get_type();

  if (type == SymbolType::endogenous)
    {
      // Leads of at most one period are kept; longer leads get an auxiliary variable
      if (lag > 1)
        return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
      return const_cast(this);
    }

  if (type == SymbolType::modelLocalVariable)
    {
      expr_t value = datatree.getLocalVariable(symb_id);
      // Rewrite the definition only if it actually contains a lead greater than one
      if (value->maxEndoLead() > 1)
        return value->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
    }

  return const_cast(this);
}
expr_t
VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const
{
  const SymbolType type = get_type();

  if (type == SymbolType::modelLocalVariable)
    {
      expr_t value = datatree.getLocalVariable(symb_id);
      if (value->maxEndoLag() > 1)
        return value->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
      return const_cast(this);
    }

  // Only endogenous variables lagged by two periods or more are substituted
  if (type != SymbolType::endogenous || lag >= -1)
    return const_cast(this);

  // Substituted earlier: reuse the existing replacement
  if (auto known = subst_table.find(this); known != subst_table.end())
    return const_cast(known->second);

  /* Build a chain of auxiliary variables, one period at a time.
     At the start of each iteration, prev is equivalent to var(cur_lag+1);
     at the end, it is equivalent to var(cur_lag). */
  VariableNode *prev = datatree.AddVariable(symb_id, -1);
  for (int cur_lag = -2; cur_lag >= lag; cur_lag--)
    {
      VariableNode *orig_expr = datatree.AddVariable(symb_id, cur_lag);
      if (auto known = subst_table.find(orig_expr); known != subst_table.end())
        prev = const_cast(known->second);
      else
        {
          // New auxvar defined by aux = prev, so that aux(-1) = var(cur_lag)
          int aux_symb_id = datatree.symbol_table.addEndoLagAuxiliaryVar(symb_id, cur_lag+1, prev);
          neweqs.push_back(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), prev));
          prev = datatree.AddVariable(aux_symb_id, -1);
          subst_table[orig_expr] = prev;
        }
    }
  return prev;
}
expr_t
VariableNode::substituteExoLead(subst_table_t &subst_table, vector &neweqs, bool deterministic_model) const
{
  const SymbolType type = get_type();

  if (type == SymbolType::exogenous)
    {
      // Contemporaneous and lagged exogenous are left untouched
      if (lag > 0)
        return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
      return const_cast(this);
    }

  if (type == SymbolType::modelLocalVariable)
    {
      expr_t value = datatree.getLocalVariable(symb_id);
      if (value->maxExoLead() != 0)
        return value->substituteExoLead(subst_table, neweqs, deterministic_model);
    }

  return const_cast(this);
}
expr_t
VariableNode::substituteExoLag(subst_table_t &subst_table, vector &neweqs) const
{
  const SymbolType type = get_type();

  if (type == SymbolType::modelLocalVariable)
    {
      expr_t value = datatree.getLocalVariable(symb_id);
      if (value->maxExoLag() != 0)
        return value->substituteExoLag(subst_table, neweqs);
      return const_cast(this);
    }

  // Only lagged exogenous variables are substituted
  if (type != SymbolType::exogenous || lag >= 0)
    return const_cast(this);

  // Substituted earlier: reuse the existing replacement
  if (auto known = subst_table.find(this); known != subst_table.end())
    return const_cast(known->second);

  /* Build a chain of auxiliary variables, one period at a time.
     At the start of each iteration, prev is equivalent to var(cur_lag+1);
     at the end, it is equivalent to var(cur_lag). */
  VariableNode *prev = datatree.AddVariable(symb_id, 0);
  for (int cur_lag = -1; cur_lag >= lag; cur_lag--)
    {
      VariableNode *orig_expr = datatree.AddVariable(symb_id, cur_lag);
      if (auto known = subst_table.find(orig_expr); known != subst_table.end())
        prev = const_cast(known->second);
      else
        {
          // New auxvar defined by aux = prev, so that aux(-1) = var(cur_lag)
          int aux_symb_id = datatree.symbol_table.addExoLagAuxiliaryVar(symb_id, cur_lag+1, prev);
          neweqs.push_back(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), prev));
          prev = datatree.AddVariable(aux_symb_id, -1);
          subst_table[orig_expr] = prev;
        }
    }
  return prev;
}
expr_t
VariableNode::substituteExpectation(subst_table_t &subst_table, vector &neweqs, bool partial_information_model) const
{
  // Only a model-local variable definition can contain an expectation operator
  if (get_type() != SymbolType::modelLocalVariable)
    return const_cast(this);
  return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, partial_information_model);
}
expr_t
VariableNode::differentiateForwardVars(const vector &subset, subst_table_t &subst_table, vector &neweqs) const
{
  const SymbolType type = get_type();

  if (type == SymbolType::modelLocalVariable)
    {
      expr_t value = datatree.getLocalVariable(symb_id);
      if (value->maxEndoLead() > 0)
        return value->differentiateForwardVars(subset, subst_table, neweqs);
      return const_cast(this);
    }

  if (type != SymbolType::endogenous)
    return const_cast(this);

  assert(lag <= 1);
  // Only variables with a lead are rewritten...
  if (lag <= 0)
    return const_cast(this);
  // ...and, when a subset is given, only those whose name is listed in it
  if (!subset.empty()
      && find(subset.begin(), subset.end(), datatree.symbol_table.getName(symb_id)) == subset.end())
    return const_cast(this);

  /* Rewrite x(+1) as x + aux(+1), where aux = x - x(-1). The auxiliary
     variable is created once and then cached in subst_table. */
  VariableNode *diffvar;
  if (auto known = subst_table.find(this); known != subst_table.end())
    diffvar = const_cast(known->second);
  else
    {
      expr_t substexpr = datatree.AddMinus(datatree.AddVariable(symb_id, 0),
                                           datatree.AddVariable(symb_id, -1));
      int aux_symb_id = datatree.symbol_table.addDiffForwardAuxiliaryVar(symb_id, 0, substexpr);
      neweqs.push_back(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr));
      diffvar = datatree.AddVariable(aux_symb_id, 1);
      subst_table[this] = diffvar;
    }
  return datatree.AddPlus(datatree.AddVariable(symb_id, 0), diffvar);
}
bool
VariableNode::isNumConstNodeEqualTo(double value) const
{
  // A variable node is never a numerical constant, whatever the value
  constexpr bool is_constant = false;
  return is_constant;
}
bool
VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{
  // Same type, same type-specific ID and same lag (checked in this order, short-circuiting)
  return get_type() == type_arg
    && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id
    && lag == lag_arg;
}
bool
VariableNode::containsPacExpectation(const string &pac_model_name) const
{
  // Only a model-local variable definition can contain a pac_expectation
  return get_type() == SymbolType::modelLocalVariable
    && datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name);
}
bool
VariableNode::containsPacTargetNonstationary(const string &pac_model_name) const
{
  // Only a model-local variable definition can contain the searched node
  return get_type() == SymbolType::modelLocalVariable
    && datatree.getLocalVariable(symb_id)->containsPacTargetNonstationary(pac_model_name);
}
expr_t
VariableNode::replaceTrendVar() const
{
  switch (get_type())
    {
    case SymbolType::modelLocalVariable:
      return datatree.getLocalVariable(symb_id)->replaceTrendVar();
    case SymbolType::trend:
      // Multiplicative trend: neutral element is one
      return datatree.One;
    case SymbolType::logTrend:
      // Additive (log) trend: neutral element is zero
      return datatree.Zero;
    default:
      return const_cast(this);
    }
}
/* Replace occurrences of variable symb_id (the parameter) by the variable times
   the trend (or plus the trend, when log_trend is true). The trend is shifted
   to the same date as the variable.

   NB: the parameter "symb_id" shadows the data member of the same name. The
   member (this node's own symbol) must be spelled "this->symb_id". The previous
   version looked up getLocalVariable(symb_id) with the *parameter*, so it
   recursed into the wrong definition for model-local variables. */
expr_t
VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
{
  if (get_type() == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(this->symb_id)->detrend(symb_id, log_trend, trend);

  // Not the variable being detrended: leave it untouched
  if (this->symb_id != symb_id)
    return const_cast(this);

  // Evaluate the trend at the same date as this variable
  expr_t shifted_trend = lag == 0 ? trend : trend->decreaseLeadsLags(-lag);

  if (log_trend)
    return datatree.AddPlus(const_cast(this), shifted_trend);
  else
    return datatree.AddTimes(const_cast(this), shifted_trend);
}
int
VariableNode::countDiffs() const
{
  // A plain variable has no diff(); count inside model-local variable definitions only
  if (get_type() != SymbolType::modelLocalVariable)
    return 0;
  return datatree.getLocalVariable(symb_id)->countDiffs();
}
/* Rewrite a leaded/lagged trend variable in terms of its contemporaneous value
   and the trend expressions stored in trend_symbols_map:
     lead k > 0: x * trend(+1) * ... * trend(+k)
     lag  k > 0: x / (trend * trend(-1) * ... * trend(-(k-1)))
   For log trends, products become sums and the division becomes a subtraction.
   Exits with an error if the trend symbol is missing from trend_symbols_map;
   the previous version dereferenced the end iterator (undefined behaviour). */
expr_t
VariableNode::removeTrendLeadLag(const map &trend_symbols_map) const
{
  if (get_type() == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map);

  // Only leaded/lagged trend variables need rewriting
  if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0)
    return const_cast(this);

  auto it = trend_symbols_map.find(symb_id);
  if (it == trend_symbols_map.end())
    {
      cerr << "VariableNode::removeTrendLeadLag: trend variable "
           << datatree.symbol_table.getName(symb_id)
           << " not found in trend symbols map" << endl;
      exit(EXIT_FAILURE);
    }

  expr_t noTrendLeadLagNode = datatree.AddVariable(it->first);
  bool log_trend = get_type() == SymbolType::logTrend;
  expr_t trend = it->second;

  if (lag > 0)
    {
      // Accumulate trend(+1), ..., trend(+lag)
      expr_t growthFactorSequence = trend->decreaseLeadsLags(-1);
      if (log_trend)
        {
          for (int i = 1; i < lag; i++)
            growthFactorSequence = datatree.AddPlus(growthFactorSequence, trend->decreaseLeadsLags(-(i+1)));
          return datatree.AddPlus(noTrendLeadLagNode, growthFactorSequence);
        }
      else
        {
          for (int i = 1; i < lag; i++)
            growthFactorSequence = datatree.AddTimes(growthFactorSequence, trend->decreaseLeadsLags(-(i+1)));
          return datatree.AddTimes(noTrendLeadLagNode, growthFactorSequence);
        }
    }
  else // lag < 0
    {
      // Accumulate trend, trend(-1), ..., trend(lag+1)
      expr_t growthFactorSequence = trend;
      if (log_trend)
        {
          for (int i = 1; i < -lag; i++)
            growthFactorSequence = datatree.AddPlus(growthFactorSequence, trend->decreaseLeadsLags(i));
          return datatree.AddMinus(noTrendLeadLagNode, growthFactorSequence);
        }
      else
        {
          for (int i = 1; i < -lag; i++)
            growthFactorSequence = datatree.AddTimes(growthFactorSequence, trend->decreaseLeadsLags(i));
          return datatree.AddDivide(noTrendLeadLagNode, growthFactorSequence);
        }
    }
}
bool
VariableNode::isInStaticForm() const
{
  // A variable is in static form when it carries no lead or lag
  if (get_type() != SymbolType::modelLocalVariable)
    return lag == 0;
  return datatree.getLocalVariable(symb_id)->isInStaticForm();
}
bool
VariableNode::isParamTimesEndogExpr() const
{
  // A lone variable is not such a product; only an MLV definition could be
  return get_type() == SymbolType::modelLocalVariable
    && datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr();
}
expr_t
VariableNode::replaceVarsInEquation(map &table) const
{
  /* No recursion into model-local variable definitions: MLVs are already
     handled by DynamicModel::simplifyEquations().
     This is also necessary because of #65. */
  // Matching is on the symbol only, so every lead/lag of it is replaced
  for (const auto &[var, replacement] : table)
    if (var->symb_id == symb_id)
      return replacement;
  return const_cast(this);
}
void
VariableNode::matchMatchedMoment(vector &symb_ids, vector &lags, vector &powers) const
{
  /* Used for simple expression outside model block, so no need to special-case
     model local variables */
  if (get_type() == SymbolType::endogenous)
    {
      // A bare variable is the variable raised to the power one
      symb_ids.push_back(symb_id);
      lags.push_back(lag);
      powers.push_back(1);
      return;
    }
  throw MatchFailureException{"Variable " + datatree.symbol_table.getName(symb_id) + " is not an endogenous"};
}
expr_t
VariableNode::substituteLogTransform(int orig_symb_id, int aux_symb_id) const
{
  if (get_type() == SymbolType::modelLocalVariable)
    return datatree.getLocalVariable(symb_id)->substituteLogTransform(orig_symb_id, aux_symb_id);
  if (symb_id != orig_symb_id)
    return const_cast(this);
  // Replace x(lag) by exp(aux(lag)), keeping the same lag
  return datatree.AddExp(datatree.AddVariable(aux_symb_id, lag));
}
/* Build a unary operator node applying op_code_arg to arg_arg.
   The remaining arguments are stored as-is in the corresponding members:
   expectation_information_set, param1_symb_id, param2_symb_id, adl_param_name
   and adl_lags. Their meaning depends on the opcode (e.g. param1_symb_id is
   used by the steadyStateParamDeriv case of composeDerivatives). */
UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector adl_lags_arg) :
  ExprNode{datatree_arg, idx_arg},
  arg{arg_arg},
  expectation_information_set{expectation_information_set_arg},
  param1_symb_id{param1_symb_id_arg},
  param2_symb_id{param2_symb_id_arg},
  op_code{op_code_arg},
  // By-value sink parameters are moved into place
  adl_param_name{move(adl_param_name_arg)},
  adl_lags{move(adl_lags_arg)}
{
}
void
UnaryOpNode::prepareForDerivation()
{
  // Do the work only once per node
  if (preparedForDerivation)
    return;
  preparedForDerivation = true;

  // Non-null derivatives are inherited from the argument...
  arg->prepareForDerivation();
  non_null_derivatives = arg->non_null_derivatives;

  // ...plus, for steady-state operators, the IDs added by addAllParamDerivId
  switch (op_code)
    {
    case UnaryOpcode::steadyState:
    case UnaryOpcode::steadyStateParamDeriv:
    case UnaryOpcode::steadyStateParam2ndDeriv:
      datatree.addAllParamDerivId(non_null_derivatives);
      break;
    default:
      break;
    }
}
/* Build the derivative of this node w.r.t. deriv_id, given darg, the derivative
   of the argument w.r.t. the same deriv_id (chain rule: f'(arg) * darg).
   Exits with an error for opcodes whose derivative is not implemented
   (expectation, diff, adl, and third derivative of STEADY_STATE). */
expr_t
UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id)
{
  expr_t t11, t12, t13, t14, t15;
  switch (op_code)
    {
    case UnaryOpcode::uminus:
      return datatree.AddUMinus(darg);
    case UnaryOpcode::exp:
      // d exp(x) = exp(x) dx; "this" is exp(x)
      return datatree.AddTimes(darg, this);
    case UnaryOpcode::log:
      return datatree.AddDivide(darg, arg);
    case UnaryOpcode::log10:
      // d log10(x) = log10(e) * dx / x
      t11 = datatree.AddExp(datatree.One);
      t12 = datatree.AddLog10(t11);
      t13 = datatree.AddDivide(darg, arg);
      return datatree.AddTimes(t12, t13);
    case UnaryOpcode::cos:
      t11 = datatree.AddSin(arg);
      t12 = datatree.AddUMinus(t11);
      return datatree.AddTimes(darg, t12);
    case UnaryOpcode::sin:
      t11 = datatree.AddCos(arg);
      return datatree.AddTimes(darg, t11);
    case UnaryOpcode::tan:
      // d tan(x) = (1 + tan(x)^2) dx
      t11 = datatree.AddTimes(this, this);
      t12 = datatree.AddPlus(t11, datatree.One);
      return datatree.AddTimes(darg, t12);
    case UnaryOpcode::acos:
      // d acos(x) = -dx / sin(acos(x))
      t11 = datatree.AddSin(this);
      t12 = datatree.AddDivide(darg, t11);
      return datatree.AddUMinus(t12);
    case UnaryOpcode::asin:
      // d asin(x) = dx / cos(asin(x))
      t11 = datatree.AddCos(this);
      return datatree.AddDivide(darg, t11);
    case UnaryOpcode::atan:
      // d atan(x) = dx / (1 + x^2)
      t11 = datatree.AddTimes(arg, arg);
      t12 = datatree.AddPlus(datatree.One, t11);
      return datatree.AddDivide(darg, t12);
    case UnaryOpcode::cosh:
      t11 = datatree.AddSinh(arg);
      return datatree.AddTimes(darg, t11);
    case UnaryOpcode::sinh:
      t11 = datatree.AddCosh(arg);
      return datatree.AddTimes(darg, t11);
    case UnaryOpcode::tanh:
      // d tanh(x) = (1 - tanh(x)^2) dx
      t11 = datatree.AddTimes(this, this);
      t12 = datatree.AddMinus(datatree.One, t11);
      return datatree.AddTimes(darg, t12);
    case UnaryOpcode::acosh:
      // d acosh(x) = dx / sinh(acosh(x))
      t11 = datatree.AddSinh(this);
      return datatree.AddDivide(darg, t11);
    case UnaryOpcode::asinh:
      // d asinh(x) = dx / cosh(asinh(x))
      t11 = datatree.AddCosh(this);
      return datatree.AddDivide(darg, t11);
    case UnaryOpcode::atanh:
      // d atanh(x) = dx / (1 - x^2)
      t11 = datatree.AddTimes(arg, arg);
      t12 = datatree.AddMinus(datatree.One, t11);
      return datatree.AddDivide(darg, t12);
    case UnaryOpcode::sqrt:
      // d sqrt(x) = dx / (2 sqrt(x))
      t11 = datatree.AddPlus(this, this);
      return datatree.AddDivide(darg, t11);
    case UnaryOpcode::cbrt:
      // d cbrt(x) = dx / (3 x^(2/3))
      t11 = datatree.AddPower(arg, datatree.AddDivide(datatree.Two, datatree.Three));
      t12 = datatree.AddTimes(datatree.Three, t11);
      return datatree.AddDivide(darg, t12);
    case UnaryOpcode::abs:
      t11 = datatree.AddSign(arg);
      return datatree.AddTimes(t11, darg);
    case UnaryOpcode::sign:
      // Piecewise constant: derivative taken to be zero
      return datatree.Zero;
    case UnaryOpcode::steadyState:
      if (datatree.isDynamic())
        {
          /* In a dynamic tree, only derivatives w.r.t. parameters may be
             non-zero; they are represented by a steadyStateParamDeriv node,
             which is only built for endogenous variables */
          if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
            {
              auto varg = dynamic_cast(arg);
              if (!varg)
                {
                  cerr << "UnaryOpNode::composeDerivatives: STEADY_STATE() should only be used on "
                       << "standalone variables (like STEADY_STATE(y)) to be derivable w.r.t. parameters" << endl;
                  exit(EXIT_FAILURE);
                }
              if (datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous)
                return datatree.AddSteadyStateParamDeriv(arg, datatree.getSymbIDByDerivID(deriv_id));
              else
                return datatree.Zero;
            }
          else
            return datatree.Zero;
        }
      else
        // In a non-dynamic tree, the steady-state operator is differentiated transparently
        return darg;
    case UnaryOpcode::steadyStateParamDeriv:
      assert(datatree.isDynamic());
      if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
        {
          // Second derivative w.r.t. (param1_symb_id, current parameter)
          auto varg = dynamic_cast(arg);
          assert(varg);
          assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
          return datatree.AddSteadyStateParam2ndDeriv(arg, param1_symb_id, datatree.getSymbIDByDerivID(deriv_id));
        }
      else
        return datatree.Zero;
    case UnaryOpcode::steadyStateParam2ndDeriv:
      assert(datatree.isDynamic());
      if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
        {
          cerr << "3rd derivative of STEADY_STATE node w.r.t. three parameters not implemented" << endl;
          exit(EXIT_FAILURE);
        }
      else
        return datatree.Zero;
    case UnaryOpcode::expectation:
      cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::expectation" << endl;
      exit(EXIT_FAILURE);
    case UnaryOpcode::erf:
    case UnaryOpcode::erfc:
      // d erf(x) = 2/(sqrt(pi) exp(x^2)) dx, and d erfc(x) = -d erf(x)
      // x^2
      t11 = datatree.AddPower(arg, datatree.Two);
      // exp(x^2)
      t12 = datatree.AddExp(t11);
      // sqrt(pi)
      t11 = datatree.AddSqrt(datatree.Pi);
      // sqrt(pi)*exp(x^2)
      t13 = datatree.AddTimes(t11, t12);
      // 2/(sqrt(pi)*exp(x^2));
      t14 = datatree.AddDivide(datatree.Two, t13);
      // (2/(sqrt(pi)*exp(x^2)))*dx;
      t15 = datatree.AddTimes(t14, darg);
      if (op_code == UnaryOpcode::erf)
        return t15;
      else // erfc
        return datatree.AddUMinus(t15);
    case UnaryOpcode::diff:
      cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::diff" << endl;
      exit(EXIT_FAILURE);
    case UnaryOpcode::adl:
      cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::adl" << endl;
      exit(EXIT_FAILURE);
    }
  // Suppress GCC warning
  exit(EXIT_FAILURE);
}
expr_t
UnaryOpNode::computeDerivative(int deriv_id)
{
  // Chain rule: combine d(arg)/d(deriv_id) with the operator's own derivative
  return composeDerivatives(arg->getDerivative(deriv_id), deriv_id);
}
int
UnaryOpNode::cost(const map, temporary_terms_t> &temp_terms_map, bool is_matlab) const
{
// For a temporary term, the cost is null
for (const auto &it : temp_terms_map)
if (it.second.contains(const_cast(this)))
return 0;
return cost(arg->cost(temp_terms_map, is_matlab), is_matlab);
}
int
UnaryOpNode::cost(const vector> &blocks_temporary_terms, bool is_matlab) const
{
// For a temporary term, the cost is null
for (const auto &blk_tt : blocks_temporary_terms)
for (const auto &eq_tt : blk_tt)
if (eq_tt.contains(const_cast(this)))
return 0;
return cost(arg->cost(blocks_temporary_terms, is_matlab), is_matlab);
}
/* Return the argument cost plus a fixed cost for this operator. Separate cost
   tables are used for MATLAB (is_matlab) and C output. Steady-state and
   expectation operators add nothing; diff and adl exit with an error.
   NB: the parameter "cost" shadows the member function of the same name. */
int
UnaryOpNode::cost(int cost, bool is_matlab) const
{
  if (op_code == UnaryOpcode::uminus && dynamic_cast(arg))
    return 0; // Cost is zero for a negative constant, as for a positive one
  if (is_matlab)
    // Cost for Matlab files
    switch (op_code)
      {
      case UnaryOpcode::uminus:
      case UnaryOpcode::sign:
        return cost + 70;
      case UnaryOpcode::exp:
        return cost + 160;
      case UnaryOpcode::log:
        return cost + 300;
      case UnaryOpcode::log10:
      case UnaryOpcode::erf:
      case UnaryOpcode::erfc:
        return cost + 16000;
      case UnaryOpcode::cos:
      case UnaryOpcode::sin:
      case UnaryOpcode::cosh:
        return cost + 210;
      case UnaryOpcode::tan:
        return cost + 230;
      case UnaryOpcode::acos:
        return cost + 300;
      case UnaryOpcode::asin:
        return cost + 310;
      case UnaryOpcode::atan:
        return cost + 140;
      case UnaryOpcode::sinh:
        return cost + 240;
      case UnaryOpcode::tanh:
        return cost + 190;
      case UnaryOpcode::acosh:
        return cost + 770;
      case UnaryOpcode::asinh:
        return cost + 460;
      case UnaryOpcode::atanh:
        return cost + 350;
      case UnaryOpcode::sqrt:
      case UnaryOpcode::cbrt:
      case UnaryOpcode::abs:
        return cost + 570;
      case UnaryOpcode::steadyState:
      case UnaryOpcode::steadyStateParamDeriv:
      case UnaryOpcode::steadyStateParam2ndDeriv:
      case UnaryOpcode::expectation:
        return cost;
      case UnaryOpcode::diff:
        cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::diff" << endl;
        exit(EXIT_FAILURE);
      case UnaryOpcode::adl:
        cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::adl" << endl;
        exit(EXIT_FAILURE);
      }
  else
    // Cost for C files
    switch (op_code)
      {
      case UnaryOpcode::uminus:
      case UnaryOpcode::sign:
        return cost + 3;
      case UnaryOpcode::exp:
      case UnaryOpcode::acosh:
        return cost + 210;
      case UnaryOpcode::log:
        return cost + 137;
      case UnaryOpcode::log10:
        return cost + 139;
      case UnaryOpcode::cos:
      case UnaryOpcode::sin:
        return cost + 160;
      case UnaryOpcode::tan:
        return cost + 170;
      case UnaryOpcode::acos:
      case UnaryOpcode::atan:
        return cost + 190;
      case UnaryOpcode::asin:
        return cost + 180;
      case UnaryOpcode::cosh:
      case UnaryOpcode::sinh:
      case UnaryOpcode::tanh:
      case UnaryOpcode::erf:
      case UnaryOpcode::erfc:
        return cost + 240;
      case UnaryOpcode::asinh:
        return cost + 220;
      case UnaryOpcode::atanh:
        return cost + 150;
      case UnaryOpcode::sqrt:
      case UnaryOpcode::cbrt:
      case UnaryOpcode::abs:
        return cost + 90;
      case UnaryOpcode::steadyState:
      case UnaryOpcode::steadyStateParamDeriv:
      case UnaryOpcode::steadyStateParam2ndDeriv:
      case UnaryOpcode::expectation:
        return cost;
      case UnaryOpcode::diff:
        cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::diff" << endl;
        exit(EXIT_FAILURE);
      case UnaryOpcode::adl:
        cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::adl" << endl;
        exit(EXIT_FAILURE);
      }
  // Unreachable for valid opcodes
  exit(EXIT_FAILURE);
}
void
UnaryOpNode::computeTemporaryTerms(const pair &derivOrder,
map, temporary_terms_t> &temp_terms_map,
map>> &reference_count,
bool is_matlab) const
{
expr_t this2 = const_cast(this);
if (auto it = reference_count.find(this2);
it == reference_count.end())
{
reference_count[this2] = { 1, derivOrder };
arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
}
else
{
auto &[nref, min_order] = it->second;
nref++;
if (nref * cost(temp_terms_map, is_matlab) > min_cost(is_matlab))
temp_terms_map[min_order].insert(this2);
}
}
void
UnaryOpNode::computeBlockTemporaryTerms(int blk, int eq, vector> &blocks_temporary_terms,
map> &reference_count) const
{
expr_t this2 = const_cast(this);
if (auto it = reference_count.find(this2);
it == reference_count.end())
{
reference_count[this2] = { 1, blk, eq };
arg->computeBlockTemporaryTerms(blk, eq, blocks_temporary_terms, reference_count);
}
else
{
auto &[nref, first_blk, first_eq] = it->second;
nref++;
if (nref * cost(blocks_temporary_terms, false) > min_cost_c)
blocks_temporary_terms[first_blk][first_eq].insert(this2);
}
}
bool
UnaryOpNode::containsExternalFunction() const
{
  // This node reports exactly what its argument reports.
  const bool arg_has_external_function = arg->containsExternalFunction();
  return arg_has_external_function;
}
void
UnaryOpNode::writeJsonAST(ostream &output) const
{
  /* Write this node as a JSON AST object:
       {"node_type" : "UnaryOpNode", "op" : "<opcode name>", "arg" : <argument AST>}
     adl nodes additionally carry "adl_param_name" and "lags" fields. */
  output << R"({"node_type" : "UnaryOpNode", "op" : ")";
  switch (op_code)
    {
    case UnaryOpcode::uminus:
      output << "uminus";
      break;
    case UnaryOpcode::exp:
      output << "exp";
      break;
    case UnaryOpcode::log:
      output << "log";
      break;
    case UnaryOpcode::log10:
      output << "log10";
      break;
    case UnaryOpcode::cos:
      output << "cos";
      break;
    case UnaryOpcode::sin:
      output << "sin";
      break;
    case UnaryOpcode::tan:
      output << "tan";
      break;
    case UnaryOpcode::acos:
      output << "acos";
      break;
    case UnaryOpcode::asin:
      output << "asin";
      break;
    case UnaryOpcode::atan:
      output << "atan";
      break;
    case UnaryOpcode::cosh:
      output << "cosh";
      break;
    case UnaryOpcode::sinh:
      output << "sinh";
      break;
    case UnaryOpcode::tanh:
      output << "tanh";
      break;
    case UnaryOpcode::acosh:
      output << "acosh";
      break;
    case UnaryOpcode::asinh:
      output << "asinh";
      break;
    case UnaryOpcode::atanh:
      output << "atanh";
      break;
    case UnaryOpcode::sqrt:
      output << "sqrt";
      break;
    case UnaryOpcode::cbrt:
      output << "cbrt";
      break;
    case UnaryOpcode::abs:
      output << "abs";
      break;
    case UnaryOpcode::sign:
      output << "sign";
      break;
    case UnaryOpcode::diff:
      output << "diff";
      break;
    case UnaryOpcode::adl:
      output << "adl";
      break;
    case UnaryOpcode::steadyState:
      output << "steady_state";
      /* Without this break the case falls through into steadyStateParamDeriv
         and emits the invalid op name "steady_statesteady_state_param_deriv". */
      break;
    case UnaryOpcode::steadyStateParamDeriv:
      output << "steady_state_param_deriv";
      break;
    case UnaryOpcode::steadyStateParam2ndDeriv:
      output << "steady_state_param_second_deriv";
      break;
    case UnaryOpcode::expectation:
      output << "expectation";
      break;
    case UnaryOpcode::erf:
      output << "erf";
      break;
    case UnaryOpcode::erfc:
      output << "erfc";
      break;
    }
  output << R"(", "arg" : )";
  arg->writeJsonAST(output);

  // Only adl nodes carry extra attributes: the parameter name prefix and the list of lags
  if (op_code == UnaryOpcode::adl)
    {
      output << R"(, "adl_param_name" : ")" << adl_param_name << R"(")"
             << R"(, "lags" : [)";
      for (auto it = adl_lags.begin(); it != adl_lags.end(); ++it)
        {
          if (it != adl_lags.begin())
            output << ", ";
          output << *it;
        }
      output << "]";
    }
  output << "}";
}
void
UnaryOpNode::writeJsonOutput(ostream &output,
                             const temporary_terms_t &temporary_terms,
                             const deriv_node_temp_terms_t &tef_terms,
                             bool isdynamic) const
{
  /* Write this node as an expression string for the JSON output.
     - Temporary terms are written as T<idx>.
     - uminus is always parenthesized. Its argument gets an extra pair of
       parentheses only when its precedence is lower.
     - adl, steady_state and the steady-state parameter derivatives are
       written in special forms.
     - Every other operator is written as "name(arg)". */
  if (temporary_terms.contains(const_cast(this)))
    {
      output << "T" << idx;
      return;
    }

  // Always put parenthesis around uminus nodes
  if (op_code == UnaryOpcode::uminus)
    output << "(";

  switch (op_code)
    {
    case UnaryOpcode::uminus:
      output << "-";
      break;
    case UnaryOpcode::exp:
      output << "exp";
      break;
    case UnaryOpcode::log:
      output << "log";
      break;
    case UnaryOpcode::log10:
      output << "log10";
      break;
    case UnaryOpcode::cos:
      output << "cos";
      break;
    case UnaryOpcode::sin:
      output << "sin";
      break;
    case UnaryOpcode::tan:
      output << "tan";
      break;
    case UnaryOpcode::acos:
      output << "acos";
      break;
    case UnaryOpcode::asin:
      output << "asin";
      break;
    case UnaryOpcode::atan:
      output << "atan";
      break;
    case UnaryOpcode::cosh:
      output << "cosh";
      break;
    case UnaryOpcode::sinh:
      output << "sinh";
      break;
    case UnaryOpcode::tanh:
      output << "tanh";
      break;
    case UnaryOpcode::acosh:
      output << "acosh";
      break;
    case UnaryOpcode::asinh:
      output << "asinh";
      break;
    case UnaryOpcode::atanh:
      output << "atanh";
      break;
    case UnaryOpcode::sqrt:
      output << "sqrt";
      break;
    case UnaryOpcode::cbrt:
      output << "cbrt";
      break;
    case UnaryOpcode::abs:
      output << "abs";
      break;
    case UnaryOpcode::sign:
      output << "sign";
      break;
    case UnaryOpcode::diff:
      output << "diff";
      break;
    case UnaryOpcode::adl:
      // Complete expression: adl(arg, 'param_name', [lag1, lag2, ...])
      output << "adl(";
      /* Forward isdynamic explicitly, as every other recursive call in this
         function does. Relying on the declaration's default would ignore the
         caller's choice of static/dynamic output. */
      arg->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
      output << ", '" << adl_param_name << "', [";
      for (auto it = adl_lags.begin(); it != adl_lags.end(); ++it)
        {
          if (it != adl_lags.begin())
            output << ", ";
          output << *it;
        }
      output << "])";
      return;
    case UnaryOpcode::steadyState:
      // The steady-state operator is written as its parenthesized argument
      output << "(";
      arg->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
      output << ")";
      return;
    case UnaryOpcode::steadyStateParamDeriv:
      // Reference to ss_param_deriv(endo, param), with 1-based type-specific indices
      {
        auto varg = dynamic_cast(arg);
        assert(varg);
        assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
        assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
        int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
        int tsid_param = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
        output << "ss_param_deriv(" << tsid_endo+1 << "," << tsid_param+1 << ")";
      }
      return;
    case UnaryOpcode::steadyStateParam2ndDeriv:
      // Reference to ss_param_2nd_deriv(endo, param1, param2), with 1-based type-specific indices
      {
        auto varg = dynamic_cast(arg);
        assert(varg);
        assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
        assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
        assert(datatree.symbol_table.getType(param2_symb_id) == SymbolType::parameter);
        int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
        int tsid_param1 = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
        int tsid_param2 = datatree.symbol_table.getTypeSpecificID(param2_symb_id);
        output << "ss_param_2nd_deriv(" << tsid_endo+1 << "," << tsid_param1+1
               << "," << tsid_param2+1 << ")";
      }
      return;
    case UnaryOpcode::expectation:
      output << "EXPECTATION(" << expectation_information_set << ")";
      break;
    case UnaryOpcode::erf:
      output << "erf";
      break;
    case UnaryOpcode::erfc:
      output << "erfc";
      break;
    }

  bool close_parenthesis = false;

  /* Enclose argument with parentheses if:
     - current opcode is not uminus, or
     - current opcode is uminus and argument has lowest precedence
  */
  if (op_code != UnaryOpcode::uminus
      || (op_code == UnaryOpcode::uminus
          && arg->precedenceJson(temporary_terms) < precedenceJson(temporary_terms)))
    {
      output << "(";
      close_parenthesis = true;
    }

  // Write argument
  arg->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);

  if (close_parenthesis)
    output << ")";

  // Close parenthesis for uminus
  if (op_code == UnaryOpcode::uminus)
    output << ")";
}
/* Writes this unary operation as source text for the target selected by
   output_type (MATLAB, C, Julia, LaTeX, ...).
   - A node stored as a temporary term is written by
     checkIfTemporaryTermThenWrite(), and the function returns immediately.
   - uminus is always wrapped in parentheses. Its argument gets an extra pair
     only when its precedence is lower than that of this node.
   - Most other operators are written as "name(arg)".
   - Some cases write the complete expression themselves and return early:
     LaTeX sqrt, MATLAB/LaTeX cbrt, steady_state, and the steady-state
     parameter derivatives. */
void
UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                         const temporary_terms_t &temporary_terms,
                         const temporary_terms_idxs_t &temporary_terms_idxs,
                         const deriv_node_temp_terms_t &tef_terms) const
{
  if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
    return;

  // Always put parenthesis around uminus nodes
  if (op_code == UnaryOpcode::uminus)
    output << LEFT_PAR(output_type);

  /* Write the operator. Cases that only write a function name "break" and
     leave the parenthesized argument to the common code after the switch.
     Cases that write the complete expression "return". */
  switch (op_code)
    {
    case UnaryOpcode::uminus:
      output << "-";
      break;
    case UnaryOpcode::exp:
      if (isLatexOutput(output_type))
        output << R"(\exp)";
      else
        output << "exp";
      break;
    case UnaryOpcode::log:
      if (isLatexOutput(output_type))
        output << R"(\log)";
      else
        output << "log";
      break;
    case UnaryOpcode::log10:
      if (isLatexOutput(output_type))
        output << R"(\log_{10})";
      else
        output << "log10";
      break;
    case UnaryOpcode::cos:
      if (isLatexOutput(output_type))
        output << R"(\cos)";
      else
        output << "cos";
      break;
    case UnaryOpcode::sin:
      if (isLatexOutput(output_type))
        output << R"(\sin)";
      else
        output << "sin";
      break;
    case UnaryOpcode::tan:
      if (isLatexOutput(output_type))
        output << R"(\tan)";
      else
        output << "tan";
      break;
    case UnaryOpcode::acos:
      output << "acos";
      break;
    case UnaryOpcode::asin:
      output << "asin";
      break;
    case UnaryOpcode::atan:
      output << "atan";
      break;
    case UnaryOpcode::cosh:
      output << "cosh";
      break;
    case UnaryOpcode::sinh:
      output << "sinh";
      break;
    case UnaryOpcode::tanh:
      output << "tanh";
      break;
    case UnaryOpcode::acosh:
      output << "acosh";
      break;
    case UnaryOpcode::asinh:
      output << "asinh";
      break;
    case UnaryOpcode::atanh:
      output << "atanh";
      break;
    case UnaryOpcode::sqrt:
      // LaTeX: the complete expression \sqrt{arg} is written here
      if (isLatexOutput(output_type))
        {
          output << R"(\sqrt{)";
          arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << "}";
          return;
        }
      output << "sqrt";
      break;
    case UnaryOpcode::cbrt:
      // MATLAB: nthroot(arg, 3); LaTeX: \sqrt[3]{arg}; other targets: cbrt(arg)
      if (isMatlabOutput(output_type))
        {
          output << "nthroot(";
          arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << ", 3)";
          return;
        }
      else if (isLatexOutput(output_type))
        {
          output << R"(\sqrt[3]{)";
          arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << "}";
          return;
        }
      else
        output << "cbrt";
      break;
    case UnaryOpcode::abs:
      output << "abs";
      break;
    case UnaryOpcode::sign:
      /* C output writes copysign(1.0, arg). The "1.0," first argument is
         emitted below, right after the opening parenthesis.
         NOTE(review): copysign(1.0, 0.0) is 1.0 (and -1.0 for -0.0), whereas
         eval_opcode() returns 0 for a zero argument. Confirm this difference
         is intended. */
      if (output_type == ExprNodeOutputType::CDynamicModel || output_type == ExprNodeOutputType::CStaticModel)
        output << "copysign";
      else
        output << "sign";
      break;
    case UnaryOpcode::steadyState:
      /* Write the parenthesized argument using the steady-state variant of
         the dynamic output type. Other output types are passed through
         unchanged. */
      ExprNodeOutputType new_output_type;
      switch (output_type)
        {
        case ExprNodeOutputType::matlabDynamicModel:
        case ExprNodeOutputType::occbinDifferenceFile:
          new_output_type = ExprNodeOutputType::matlabDynamicSteadyStateOperator;
          break;
        case ExprNodeOutputType::latexDynamicModel:
          new_output_type = ExprNodeOutputType::latexDynamicSteadyStateOperator;
          break;
        case ExprNodeOutputType::CDynamicModel:
          new_output_type = ExprNodeOutputType::CDynamicSteadyStateOperator;
          break;
        case ExprNodeOutputType::juliaDynamicModel:
          new_output_type = ExprNodeOutputType::juliaDynamicSteadyStateOperator;
          break;
        default:
          new_output_type = output_type;
          break;
        }
      output << "(";
      arg->writeOutput(output, new_output_type, temporary_terms, temporary_terms_idxs, tef_terms);
      output << ")";
      return;
    case UnaryOpcode::steadyStateParamDeriv:
      // MATLAB only: ss_param_deriv(endo, param), with 1-based type-specific indices
      {
        auto varg = dynamic_cast(arg);
        assert(varg);
        assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
        assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
        int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
        int tsid_param = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
        assert(isMatlabOutput(output_type));
        output << "ss_param_deriv(" << tsid_endo+1 << "," << tsid_param+1 << ")";
      }
      return;
    case UnaryOpcode::steadyStateParam2ndDeriv:
      // MATLAB only: ss_param_2nd_deriv(endo, param1, param2), with 1-based type-specific indices
      {
        auto varg = dynamic_cast(arg);
        assert(varg);
        assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
        assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
        assert(datatree.symbol_table.getType(param2_symb_id) == SymbolType::parameter);
        int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
        int tsid_param1 = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
        int tsid_param2 = datatree.symbol_table.getTypeSpecificID(param2_symb_id);
        assert(isMatlabOutput(output_type));
        output << "ss_param_2nd_deriv(" << tsid_endo+1 << "," << tsid_param1+1
               << "," << tsid_param2+1 << ")";
      }
      return;
    case UnaryOpcode::expectation:
      // Only LaTeX can represent the operator, as \mathbb{E}_{t+k}; other targets abort
      if (!isLatexOutput(output_type))
        {
          cerr << "UnaryOpNode::writeOutput: not implemented on UnaryOpcode::expectation" << endl;
          exit(EXIT_FAILURE);
        }
      output << R"(\mathbb{E}_{t)";
      if (expectation_information_set != 0)
        {
          if (expectation_information_set > 0)
            output << "+";
          output << expectation_information_set;
        }
      output << "}";
      break;
    case UnaryOpcode::erf:
      output << "erf";
      break;
    case UnaryOpcode::erfc:
      output << "erfc";
      break;
    case UnaryOpcode::diff:
      output << "diff";
      break;
    case UnaryOpcode::adl:
      output << "adl";
      break;
    }

  if (output_type == ExprNodeOutputType::juliaTimeDataFrame
      && op_code != UnaryOpcode::uminus)
    output << "."; // Use vectorized form of the function

  bool close_parenthesis = false;

  /* Enclose argument with parentheses if:
     - current opcode is not uminus, or
     - current opcode is uminus and argument has lowest precedence
  */
  if (op_code != UnaryOpcode::uminus
      || (op_code == UnaryOpcode::uminus
          && arg->precedence(output_type, temporary_terms) < precedence(output_type, temporary_terms)))
    {
      output << LEFT_PAR(output_type);
      // First argument of copysign(1.0, arg) for the C sign() translation
      if (op_code == UnaryOpcode::sign && (output_type == ExprNodeOutputType::CDynamicModel || output_type == ExprNodeOutputType::CStaticModel))
        output << "1.0,";
      close_parenthesis = true;
    }

  // Write argument
  arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);

  if (close_parenthesis)
    output << RIGHT_PAR(output_type);

  // Close parenthesis for uminus
  if (op_code == UnaryOpcode::uminus)
    output << RIGHT_PAR(output_type);
}
/* The three external-function passes below have nothing to emit for the unary
   operator itself. Each forwards all of its parameters unchanged to the
   argument. */
void
UnaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
                                         const temporary_terms_t &temporary_terms,
                                         const temporary_terms_idxs_t &temporary_terms_idxs,
                                         deriv_node_temp_terms_t &tef_terms) const
{
  return arg->writeExternalFunctionOutput(output, output_type, temporary_terms,
                                          temporary_terms_idxs, tef_terms);
}

void
UnaryOpNode::writeJsonExternalFunctionOutput(vector &efout,
                                             const temporary_terms_t &temporary_terms,
                                             deriv_node_temp_terms_t &tef_terms,
                                             bool isdynamic) const
{
  return arg->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
}

void
UnaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
                                           bool lhs_rhs, const temporary_terms_t &temporary_terms,
                                           const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic,
                                           deriv_node_temp_terms_t &tef_terms) const
{
  return arg->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs,
                                            temporary_terms, temporary_terms_idxs,
                                            dynamic, steady_dynamic, tef_terms);
}
/* Numerically evaluates the unary operator op_code applied to the value v.
   - sign returns 1, -1 or 0. A NaN argument also yields 0, because both
     comparisons are false.
   - steady_state returns v unchanged.
   - Operators without a numerical evaluation here (steady-state parameter
     derivatives, expectation, diff, adl) print an error and exit. */
double
UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) noexcept(false)
{
  switch (op_code)
    {
    case UnaryOpcode::uminus:
      return -v;
    case UnaryOpcode::exp:
      return exp(v);
    case UnaryOpcode::log:
      return log(v);
    case UnaryOpcode::log10:
      return log10(v);
    case UnaryOpcode::cos:
      return cos(v);
    case UnaryOpcode::sin:
      return sin(v);
    case UnaryOpcode::tan:
      return tan(v);
    case UnaryOpcode::acos:
      return acos(v);
    case UnaryOpcode::asin:
      return asin(v);
    case UnaryOpcode::atan:
      return atan(v);
    case UnaryOpcode::cosh:
      return cosh(v);
    case UnaryOpcode::sinh:
      return sinh(v);
    case UnaryOpcode::tanh:
      return tanh(v);
    case UnaryOpcode::acosh:
      return acosh(v);
    case UnaryOpcode::asinh:
      return asinh(v);
    case UnaryOpcode::atanh:
      return atanh(v);
    case UnaryOpcode::sqrt:
      return sqrt(v);
    case UnaryOpcode::cbrt:
      return cbrt(v);
    case UnaryOpcode::abs:
      return abs(v);
    case UnaryOpcode::sign:
      return (v > 0) ? 1 : ((v < 0) ? -1 : 0);
    case UnaryOpcode::steadyState:
      // Identity: the value passed in is returned as-is
      return v;
    case UnaryOpcode::steadyStateParamDeriv:
      cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::steadyStateParamDeriv" << endl;
      exit(EXIT_FAILURE);
    case UnaryOpcode::steadyStateParam2ndDeriv:
      cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::steadyStateParam2ndDeriv" << endl;
      exit(EXIT_FAILURE);
    case UnaryOpcode::expectation:
      cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::expectation" << endl;
      exit(EXIT_FAILURE);
    case UnaryOpcode::erf:
      return erf(v);
    case UnaryOpcode::erfc:
      return erfc(v);
    case UnaryOpcode::diff:
      cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::diff" << endl;
      exit(EXIT_FAILURE);
    case UnaryOpcode::adl:
      cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::adl" << endl;
      exit(EXIT_FAILURE);
    }
  // Suppress GCC warning (all enumerators are handled above, so this is unreachable)
  exit(EXIT_FAILURE);
}
double
UnaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false)
{
  // Evaluate the argument in the given context, then apply this node's operator to it.
  return eval_opcode(op_code, arg->eval(eval_context));
}
void
UnaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
                     bool lhs_rhs, const temporary_terms_t &temporary_terms,
                     const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic,
                     const deriv_node_temp_terms_t &tef_terms) const
{
  auto self = const_cast(this);
  if (temporary_terms.contains(self))
    {
      /* Already computed as a temporary term: emit a load of its slot,
         with FLDST_ when not dynamic and FLDT_ when dynamic. */
      if (!dynamic)
        {
          FLDST_ fldst(temporary_terms_idxs.at(self));
          fldst.write(CompileCode, instruction_number);
        }
      else
        {
          FLDT_ fldt(temporary_terms_idxs.at(self));
          fldt.write(CompileCode, instruction_number);
        }
      return;
    }

  /* steady_state emits no instruction of its own: it compiles its argument
     with steady_dynamic forced to true. Every other operator compiles its
     argument with the incoming steady_dynamic, then emits a FUNARY_
     instruction carrying the opcode. */
  const bool is_steady_state = op_code == UnaryOpcode::steadyState;
  arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs,
               dynamic, is_steady_state || steady_dynamic, tef_terms);
  if (!is_steady_state)
    {
      FUNARY_ funary{static_cast(op_code)};
      funary.write(CompileCode, instruction_number);
    }
}
void
UnaryOpNode::collectVARLHSVariable(set &result) const
{
  // A diff node is collected as a whole; any other unary operator is looked through.
  if (op_code != UnaryOpcode::diff)
    {
      arg->collectVARLHSVariable(result);
      return;
    }
  result.insert(const_cast(this));
}
void
UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set> &result) const
{
  // The operator contributes no variables of its own; collect from the argument.
  return arg->collectDynamicVariables(type_arg, result);
}
void
UnaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set &contain_var) const
{
arg->computeSubExprContainingVariable(symb_id, lag, contain_var);
if (contain_var.contains(arg))
contain_var.insert(const_cast(this));
}
/* Normalization step for a unary node that contains the variable being solved
   for (asserted through contain_var), where this node equals rhs.
   The step rewrites rhs by applying the inverse of this operator, then
   continues the normalization in the argument.
   Operators not inverted here (abs, sign, steady_state, expectation, diff,
   adl, erf, erfc, ...) throw NormalizationFailed.
   NOTE(review): the trigonometric/hyperbolic inverses are principal branches
   (e.g. cos(x) = r is solved as x = acos(r)), so other solutions are not
   produced. */
BinaryOpNode *
UnaryOpNode::normalizeEquationHelper(const set &contain_var, expr_t rhs) const
{
  assert(contain_var.contains(const_cast(this)));

  switch (op_code)
    {
    case UnaryOpcode::uminus:
      rhs = datatree.AddUMinus(rhs);
      break;
    case UnaryOpcode::exp:
      rhs = datatree.AddLog(rhs);
      break;
    case UnaryOpcode::log:
      rhs = datatree.AddExp(rhs);
      break;
    case UnaryOpcode::log10:
      // log10(x) = rhs  =>  x = 10^rhs
      rhs = datatree.AddPower(datatree.AddNonNegativeConstant("10"), rhs);
      break;
    case UnaryOpcode::cos:
      rhs = datatree.AddAcos(rhs);
      break;
    case UnaryOpcode::sin:
      rhs = datatree.AddAsin(rhs);
      break;
    case UnaryOpcode::tan:
      rhs = datatree.AddAtan(rhs);
      break;
    case UnaryOpcode::acos:
      rhs = datatree.AddCos(rhs);
      break;
    case UnaryOpcode::asin:
      rhs = datatree.AddSin(rhs);
      break;
    case UnaryOpcode::atan:
      rhs = datatree.AddTan(rhs);
      break;
    case UnaryOpcode::cosh:
      rhs = datatree.AddAcosh(rhs);
      break;
    case UnaryOpcode::sinh:
      rhs = datatree.AddAsinh(rhs);
      break;
    case UnaryOpcode::tanh:
      rhs = datatree.AddAtanh(rhs);
      break;
    case UnaryOpcode::acosh:
      rhs = datatree.AddCosh(rhs);
      break;
    case UnaryOpcode::asinh:
      rhs = datatree.AddSinh(rhs);
      break;
    case UnaryOpcode::atanh:
      rhs = datatree.AddTanh(rhs);
      break;
    case UnaryOpcode::sqrt:
      // sqrt(x) = rhs  =>  x = rhs^2
      rhs = datatree.AddPower(rhs, datatree.Two);
      break;
    case UnaryOpcode::cbrt:
      // cbrt(x) = rhs  =>  x = rhs^3
      rhs = datatree.AddPower(rhs, datatree.Three);
      break;
    default:
      throw NormalizationFailed();
    }

  return arg->normalizeEquationHelper(contain_var, rhs);
}
expr_t
UnaryOpNode::getChainRuleDerivative(int deriv_id, const map &recursive_variables)
{
  // Chain rule: differentiate the argument, then let composeDerivatives apply f'(arg).
  return composeDerivatives(arg->getChainRuleDerivative(deriv_id, recursive_variables), deriv_id);
}
/* Builds, inside alt_datatree, a unary node with the same operator as this one
   but with alt_arg as its argument. Operator-specific data (expectation
   information set, adl parameter name and lags) is carried over.
   Used by the tree transformations below (toStatic, clone, substitutions).
   steadyStateParamDeriv and steadyStateParam2ndDeriv nodes cannot be rebuilt
   this way; they print an error and exit. */
expr_t
UnaryOpNode::buildSimilarUnaryOpNode(expr_t alt_arg, DataTree &alt_datatree) const
{
  switch (op_code)
    {
    case UnaryOpcode::uminus:
      return alt_datatree.AddUMinus(alt_arg);
    case UnaryOpcode::exp:
      return alt_datatree.AddExp(alt_arg);
    case UnaryOpcode::log:
      return alt_datatree.AddLog(alt_arg);
    case UnaryOpcode::log10:
      return alt_datatree.AddLog10(alt_arg);
    case UnaryOpcode::cos:
      return alt_datatree.AddCos(alt_arg);
    case UnaryOpcode::sin:
      return alt_datatree.AddSin(alt_arg);
    case UnaryOpcode::tan:
      return alt_datatree.AddTan(alt_arg);
    case UnaryOpcode::acos:
      return alt_datatree.AddAcos(alt_arg);
    case UnaryOpcode::asin:
      return alt_datatree.AddAsin(alt_arg);
    case UnaryOpcode::atan:
      return alt_datatree.AddAtan(alt_arg);
    case UnaryOpcode::cosh:
      return alt_datatree.AddCosh(alt_arg);
    case UnaryOpcode::sinh:
      return alt_datatree.AddSinh(alt_arg);
    case UnaryOpcode::tanh:
      return alt_datatree.AddTanh(alt_arg);
    case UnaryOpcode::acosh:
      return alt_datatree.AddAcosh(alt_arg);
    case UnaryOpcode::asinh:
      return alt_datatree.AddAsinh(alt_arg);
    case UnaryOpcode::atanh:
      return alt_datatree.AddAtanh(alt_arg);
    case UnaryOpcode::sqrt:
      return alt_datatree.AddSqrt(alt_arg);
    case UnaryOpcode::cbrt:
      return alt_datatree.AddCbrt(alt_arg);
    case UnaryOpcode::abs:
      return alt_datatree.AddAbs(alt_arg);
    case UnaryOpcode::sign:
      return alt_datatree.AddSign(alt_arg);
    case UnaryOpcode::steadyState:
      return alt_datatree.AddSteadyState(alt_arg);
    case UnaryOpcode::steadyStateParamDeriv:
      cerr << "UnaryOpNode::buildSimilarUnaryOpNode: UnaryOpcode::steadyStateParamDeriv can't be translated" << endl;
      exit(EXIT_FAILURE);
    case UnaryOpcode::steadyStateParam2ndDeriv:
      cerr << "UnaryOpNode::buildSimilarUnaryOpNode: UnaryOpcode::steadyStateParam2ndDeriv can't be translated" << endl;
      exit(EXIT_FAILURE);
    case UnaryOpcode::expectation:
      return alt_datatree.AddExpectation(expectation_information_set, alt_arg);
    case UnaryOpcode::erf:
      return alt_datatree.AddErf(alt_arg);
    case UnaryOpcode::erfc:
      return alt_datatree.AddErfc(alt_arg);
    case UnaryOpcode::diff:
      return alt_datatree.AddDiff(alt_arg);
    case UnaryOpcode::adl:
      return alt_datatree.AddAdl(alt_arg, adl_param_name, adl_lags);
    }
  // Suppress GCC warning (all enumerators are handled above, so this is unreachable)
  exit(EXIT_FAILURE);
}
expr_t
UnaryOpNode::toStatic(DataTree &static_datatree) const
{
  // Convert the argument, then rebuild the same operator inside the static tree.
  return buildSimilarUnaryOpNode(arg->toStatic(static_datatree), static_datatree);
}
void
UnaryOpNode::computeXrefs(EquationInfo &ei) const
{
  // Cross-references come only from the argument.
  return arg->computeXrefs(ei);
}
expr_t
UnaryOpNode::clone(DataTree &datatree) const
{
  // Deep copy: clone the argument into the target tree, then wrap it in the same operator.
  return buildSimilarUnaryOpNode(arg->clone(datatree), datatree);
}
/* The lead/lag extrema below are answered entirely by the argument.
   (diff is only accounted for in maxLagWithDiffsExpanded.) */
int
UnaryOpNode::maxEndoLead() const
{
  const int endo_lead = arg->maxEndoLead();
  return endo_lead;
}

int
UnaryOpNode::maxExoLead() const
{
  const int exo_lead = arg->maxExoLead();
  return exo_lead;
}

int
UnaryOpNode::maxEndoLag() const
{
  const int endo_lag = arg->maxEndoLag();
  return endo_lag;
}

int
UnaryOpNode::maxExoLag() const
{
  const int exo_lag = arg->maxExoLag();
  return exo_lag;
}

int
UnaryOpNode::maxLead() const
{
  const int any_lead = arg->maxLead();
  return any_lead;
}

int
UnaryOpNode::maxLag() const
{
  const int any_lag = arg->maxLag();
  return any_lag;
}
int
UnaryOpNode::maxLagWithDiffsExpanded() const
{
  // Each diff() operator adds one extra lag once expanded (x - x(-1)).
  const int inner_lag = arg->maxLagWithDiffsExpanded();
  return op_code == UnaryOpcode::diff ? inner_lag + 1 : inner_lag;
}
expr_t
UnaryOpNode::undiff() const
{
  /* A diff node yields its argument directly. Any other unary node is
     dropped, and the search continues inside its argument. */
  return op_code == UnaryOpcode::diff ? arg : arg->undiff();
}
int
UnaryOpNode::VarMaxLag(const set &lhs_lag_equiv) const
{
  /* Only nodes whose lag-equivalence class appears among the VAR LHS classes
     contribute their argument's max lag; others contribute 0. */
  auto [lag_equiv_repr, index] = getLagEquivalenceClass();
  if (!lhs_lag_equiv.contains(lag_equiv_repr))
    return 0;
  return arg->maxLag();
}
expr_t
UnaryOpNode::substituteAdl() const
{
  /* Expand adl(x, 'p', [l1, l2, ...]) into the polynomial
       p_lag_l1 * x(-l1) + p_lag_l2 * x(-l2) + ...
     Each p_lag_<l> is looked up in the symbol table by name. Nested adl
     operators inside x are expanded first. Non-adl nodes are rebuilt around
     their substituted argument. */
  if (op_code != UnaryOpcode::adl)
    {
      expr_t argsubst = arg->substituteAdl();
      return buildSimilarUnaryOpNode(argsubst, datatree);
    }

  expr_t arg1subst = arg->substituteAdl();

  expr_t retval = nullptr;
  for (int lag : adl_lags)
    {
      expr_t coeff = datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + to_string(lag)), 0);
      expr_t term = datatree.AddTimes(coeff, arg1subst->decreaseLeadsLags(lag));
      retval = retval ? datatree.AddPlus(retval, term) : term;
    }

  /* With an empty list of lags the polynomial is the empty sum. Returning 0
     avoids handing a null expression back to the caller. */
  if (!retval)
    return datatree.Zero;
  return retval;
}
/* Both substitutions are purely structural for a unary node: substitute in the
   argument, then rebuild the same operator in the current tree. */
expr_t
UnaryOpNode::substituteModelLocalVariables() const
{
  return buildSimilarUnaryOpNode(arg->substituteModelLocalVariables(), datatree);
}

expr_t
UnaryOpNode::substituteVarExpectation(const map &subst_table) const
{
  return buildSimilarUnaryOpNode(arg->substituteVarExpectation(subst_table), datatree);
}
int
UnaryOpNode::countDiffs() const
{
  // Number of diff operators along this path: the argument's count, plus one if this is a diff.
  const int inner_diffs = arg->countDiffs();
  return op_code == UnaryOpcode::diff ? inner_diffs + 1 : inner_diffs;
}
bool
UnaryOpNode::createAuxVarForUnaryOpNode() const
{
  /* Tells whether this operator is eligible for replacement by an auxiliary
     variable. findUnaryOpNodesForAuxVarCreation() uses it to select nodes.
     Only the ordinary mathematical functions qualify. uminus and the
     structural operators (steady_state and its parameter derivatives,
     expectation, diff, adl) do not. */
  bool eligible = false;
  switch (op_code)
    {
    case UnaryOpcode::exp:
    case UnaryOpcode::log:
    case UnaryOpcode::log10:
    case UnaryOpcode::cos:
    case UnaryOpcode::sin:
    case UnaryOpcode::tan:
    case UnaryOpcode::acos:
    case UnaryOpcode::asin:
    case UnaryOpcode::atan:
    case UnaryOpcode::cosh:
    case UnaryOpcode::sinh:
    case UnaryOpcode::tanh:
    case UnaryOpcode::acosh:
    case UnaryOpcode::asinh:
    case UnaryOpcode::atanh:
    case UnaryOpcode::sqrt:
    case UnaryOpcode::cbrt:
    case UnaryOpcode::abs:
    case UnaryOpcode::sign:
    case UnaryOpcode::erf:
    case UnaryOpcode::erfc:
      eligible = true;
      break;
    default:
      break;
    }
  return eligible;
}
void
UnaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
  // Visit the subtree first, then register this node under its lag-equivalence class if eligible.
  arg->findUnaryOpNodesForAuxVarCreation(nodes);
  if (createAuxVarForUnaryOpNode())
    {
      auto [lag_equiv_repr, index] = getLagEquivalenceClass();
      nodes[lag_equiv_repr][index] = const_cast(this);
    }
}
void
UnaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
  // Visit the subtree first, then register this node if it is a diff.
  arg->findDiffNodes(nodes);
  if (op_code == UnaryOpcode::diff)
    {
      auto [lag_equiv_repr, index] = getLagEquivalenceClass();
      nodes[lag_equiv_repr][index] = const_cast(this);
    }
}
optional
UnaryOpNode::findTargetVariable(int lhs_symb_id) const
{
  // The operator itself is transparent for this search; only the argument can hold the target.
  auto target = arg->findTargetVariable(lhs_symb_id);
  return target;
}
/* Replaces diff(x) nodes by auxiliary variables, appending their defining
   equations to neweqs.
   - Non-diff nodes: substitute recursively in the argument and rebuild.
   - Nodes already in subst_table: return the recorded replacement.
   - diff nodes that are not the registered member of their lag-equivalence
     class in `nodes`: create a single aux var AUX = x - x(-1).
   - Otherwise the whole equivalence class is processed at once, from the
     highest index to the lowest. The first member gets AUX = x - x(-1). Each
     one-period step down to the next member adds a new aux var equal to the
     previous aux var lagged by one period. */
expr_t
UnaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
                            vector &neweqs) const
{
  // If this is not a diff node, then substitute recursively and return
  expr_t argsubst = arg->substituteDiff(nodes, subst_table, neweqs);
  if (op_code != UnaryOpcode::diff)
    return buildSimilarUnaryOpNode(argsubst, datatree);

  // Already substituted, e.g. as a member of a class processed earlier
  if (auto sit = subst_table.find(this);
      sit != subst_table.end())
    return const_cast(sit->second);

  auto [lag_equiv_repr, index] = getLagEquivalenceClass();
  auto it = nodes.find(lag_equiv_repr);
  if (it == nodes.end() || it->second.find(index) == it->second.end()
      || it->second.at(index) != this)
    {
      /* diff does not appear in VAR equations, so simply create aux var and return.
         Once the comparison of expression nodes works, come back and remove
         this part, folding into the next loop. */
      int symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, const_cast(this));
      VariableNode *aux_var = datatree.AddVariable(symb_id, 0);
      // AUX_DIFF = argsubst - argsubst(-1)
      neweqs.push_back(datatree.AddEqual(aux_var,
                                         datatree.AddMinus(argsubst,
                                                           argsubst->decreaseLeadsLags(1))));
      subst_table[this] = dynamic_cast(aux_var);
      return const_cast(subst_table[this]);
    }

  /* At this point, we know that this node (and its lagged/leaded brothers)
     must be substituted. We create the auxiliary variable and fill the
     substitution table for all those similar nodes, in an iteration going from
     leads to lags. */
  int last_index = 0;
  VariableNode *last_aux_var = nullptr;
  for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
    {
      // Substituted argument of the current class member (shadows the outer argsubst)
      expr_t argsubst = dynamic_cast(rit->second)->
        arg->substituteDiff(nodes, subst_table, neweqs);
      auto vn = dynamic_cast(argsubst);
      int symb_id;
      if (rit == it->second.rbegin())
        {
          // When the argument is a plain variable, its symb_id and lag are also passed to the symbol table
          if (vn)
            symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, rit->second, vn->symb_id, vn->lag);
          else
            symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, rit->second);

          // make originating aux var & equation
          last_index = rit->first;
          last_aux_var = datatree.AddVariable(symb_id, 0);
          //ORIG_AUX_DIFF = argsubst - argsubst(-1)
          neweqs.push_back(datatree.AddEqual(last_aux_var,
                                             datatree.AddMinus(argsubst,
                                                               argsubst->decreaseLeadsLags(1))));
          subst_table[rit->second] = dynamic_cast(last_aux_var);
        }
      else
        {
          /* just add equation of form: AUX_DIFF = LAST_AUX_VAR(-1)
             Map keys are visited in strictly decreasing order, so this loop
             runs last_index - rit->first >= 1 times and new_aux_var is
             non-null after it. */
          VariableNode *new_aux_var = nullptr;
          for (int i = last_index; i > rit->first; i--)
            {
              if (i == last_index)
                symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(argsubst->idx, rit->second,
                                                                       last_aux_var->symb_id, -1);
              else
                symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(new_aux_var->idx, rit->second,
                                                                       last_aux_var->symb_id, -1);

              new_aux_var = datatree.AddVariable(symb_id, 0);
              neweqs.push_back(datatree.AddEqual(new_aux_var,
                                                 last_aux_var->decreaseLeadsLags(1)));
              last_aux_var = new_aux_var;
            }
          subst_table[rit->second] = dynamic_cast(new_aux_var);
          last_index = rit->first;
        }
    }
  return const_cast(subst_table[this]);
}
/* Replaces unary operator nodes whose lag-equivalence class is present in
   `nodes` by auxiliary variables. `nodes` is presumably built by
   findUnaryOpNodesForAuxVarCreation(); verify against callers.
   One aux var AUX = op(arg) is created per class. It is defined at the
   class's highest index (base_index). Every member of the class is then
   replaced by AUX lagged by (base_index - its index).
   Nodes already in subst_table return their recorded replacement. Nodes
   whose class is not in `nodes` are rebuilt around their substituted
   argument. The process aborts if the highest-index member contains a lead. */
expr_t
UnaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector &neweqs) const
{
  if (auto sit = subst_table.find(this);
      sit != subst_table.end())
    return const_cast(sit->second);

  /* If the equivalence class of this node is not marked for substitution,
     then substitute recursively and return. */
  auto [lag_equiv_repr, index] = getLagEquivalenceClass();
  auto it = nodes.find(lag_equiv_repr);
  expr_t argsubst = arg->substituteUnaryOpNodes(nodes, subst_table, neweqs);
  if (it == nodes.end())
    return buildSimilarUnaryOpNode(argsubst, datatree);

  // Operator name recorded alongside the aux var in the symbol table
  string unary_op;
  switch (op_code)
    {
    case UnaryOpcode::exp:
      unary_op = "exp";
      break;
    case UnaryOpcode::log:
      unary_op = "log";
      break;
    case UnaryOpcode::log10:
      unary_op = "log10";
      break;
    case UnaryOpcode::cos:
      unary_op = "cos";
      break;
    case UnaryOpcode::sin:
      unary_op = "sin";
      break;
    case UnaryOpcode::tan:
      unary_op = "tan";
      break;
    case UnaryOpcode::acos:
      unary_op = "acos";
      break;
    case UnaryOpcode::asin:
      unary_op = "asin";
      break;
    case UnaryOpcode::atan:
      unary_op = "atan";
      break;
    case UnaryOpcode::cosh:
      unary_op = "cosh";
      break;
    case UnaryOpcode::sinh:
      unary_op = "sinh";
      break;
    case UnaryOpcode::tanh:
      unary_op = "tanh";
      break;
    case UnaryOpcode::acosh:
      unary_op = "acosh";
      break;
    case UnaryOpcode::asinh:
      unary_op = "asinh";
      break;
    case UnaryOpcode::atanh:
      unary_op = "atanh";
      break;
    case UnaryOpcode::sqrt:
      unary_op = "sqrt";
      break;
    case UnaryOpcode::cbrt:
      unary_op = "cbrt";
      break;
    case UnaryOpcode::abs:
      unary_op = "abs";
      break;
    case UnaryOpcode::sign:
      unary_op = "sign";
      break;
    case UnaryOpcode::erf:
      unary_op = "erf";
      break;
    case UnaryOpcode::erfc:
      unary_op = "erfc";
      break;
    default:
      cerr << "UnaryOpNode::substituteUnaryOpNodes: Shouldn't arrive here" << endl;
      exit(EXIT_FAILURE);
    }

  /* At this point, we know that this node (and its lagged/leaded brothers)
     must be substituted. We create the auxiliary variable and fill the
     substitution table for all those similar nodes, in an iteration going from
     leads to lags. */
  int base_index = it->second.rbegin()->first; // Within the equivalence class,
                                               // index of the node that will
                                               // be used as the definition for
                                               // the aux var.
  VariableNode *aux_var = nullptr;
  for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
    if (rit == it->second.rbegin())
      {
        /* Verify that we’re not operating on a node with leads, since the
           transformation does not take into account the expectation operator. We only
           need to do this for the first iteration of the loop, because we’re
           going from leads to lags. */
        if (rit->second->maxLead() > 0)
          {
            cerr << "Cannot substitute unary operations that contain leads" << endl;
            exit(EXIT_FAILURE);
          }
        // Shift this node's substituted argument to the base node's position
        auto argsubst_shifted = argsubst->decreaseLeadsLags(index - base_index);
        auto aux_def = buildSimilarUnaryOpNode(argsubst_shifted, datatree);
        int symb_id;
        // When the shifted argument is a plain variable, its symb_id and lag are also passed to the symbol table
        if (auto vn = dynamic_cast(argsubst_shifted); !vn)
          symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, aux_def, unary_op);
        else
          symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, aux_def, unary_op,
                                                                 vn->symb_id, vn->lag);
        aux_var = datatree.AddVariable(symb_id, 0);
        neweqs.push_back(datatree.AddEqual(aux_var, aux_def));
        subst_table[rit->second] = dynamic_cast(aux_var);
      }
    else
      // Other members: the aux var lagged by the distance to the base index
      subst_table[rit->second] = dynamic_cast(aux_var->decreaseLeadsLags(base_index - rit->first));

  assert(subst_table.contains(this));
  return const_cast(subst_table.at(this));
}
/* The four transformations below act only on the argument. The unary operator
   is then rebuilt, unchanged, around the transformed argument in the current
   tree. */
expr_t
UnaryOpNode::substitutePacExpectation(const string &name, expr_t subexpr)
{
  return buildSimilarUnaryOpNode(arg->substitutePacExpectation(name, subexpr), datatree);
}

expr_t
UnaryOpNode::substitutePacTargetNonstationary(const string &name, expr_t subexpr)
{
  return buildSimilarUnaryOpNode(arg->substitutePacTargetNonstationary(name, subexpr), datatree);
}

expr_t
UnaryOpNode::decreaseLeadsLags(int n) const
{
  return buildSimilarUnaryOpNode(arg->decreaseLeadsLags(n), datatree);
}

expr_t
UnaryOpNode::decreaseLeadsLagsPredeterminedVariables() const
{
  return buildSimilarUnaryOpNode(arg->decreaseLeadsLagsPredeterminedVariables(), datatree);
}
expr_t
UnaryOpNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector &neweqs, bool deterministic_model) const
{
  /* In a stochastic model, a non-uminus operator is treated as a whole. If it
     contains an endogenous lead of 2 or more, the entire node is replaced by
     an aux var; otherwise it is left untouched. Otherwise (uminus, or a
     deterministic model), substitute inside the argument and rebuild. */
  if (op_code != UnaryOpcode::uminus && !deterministic_model)
    {
      if (maxEndoLead() < 2)
        return const_cast(this);
      return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
    }

  return buildSimilarUnaryOpNode(arg->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model), datatree);
}
expr_t
UnaryOpNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector &neweqs) const
{
  // Lags are substituted inside the argument; the operator is kept as-is.
  return buildSimilarUnaryOpNode(arg->substituteEndoLagGreaterThanTwo(subst_table, neweqs), datatree);
}
expr_t
UnaryOpNode::substituteExoLead(subst_table_t &subst_table, vector &neweqs, bool deterministic_model) const
{
  /* In a stochastic model, a non-uminus operator containing an exogenous
     lead is replaced as a whole by an aux var (or left untouched if there is
     no lead). Otherwise (uminus, or a deterministic model), substitute inside
     the argument and rebuild. */
  if (op_code != UnaryOpcode::uminus && !deterministic_model)
    {
      if (maxExoLead() < 1)
        return const_cast(this);
      return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
    }

  return buildSimilarUnaryOpNode(arg->substituteExoLead(subst_table, neweqs, deterministic_model), datatree);
}
expr_t
UnaryOpNode::substituteExoLag(subst_table_t &subst_table, vector &neweqs) const
{
  // Exogenous lags are substituted inside the argument; the operator is kept as-is.
  return buildSimilarUnaryOpNode(arg->substituteExoLag(subst_table, neweqs), datatree);
}
/* Replaces each EXPECTATION(k)(x) node by an auxiliary variable.
   A new aux var AUX is defined by the equation AUX = x(-k), where nested
   expectation operators in x are substituted first. The node itself is
   replaced by AUX(k). Results are memoized in subst_table.
   Any other unary node is rebuilt around its substituted argument. */
expr_t
UnaryOpNode::substituteExpectation(subst_table_t &subst_table, vector &neweqs, bool partial_information_model) const
{
  if (op_code == UnaryOpcode::expectation)
    {
      if (auto it = subst_table.find(const_cast(this)); it != subst_table.end())
        return const_cast(it->second);

      //Arriving here, we need to create an auxiliary variable for this Expectation Operator:
      //AUX_EXPECT_(LEAD/LAG)_(period)_(arg.idx) OR
      //AUX_EXPECT_(info_set_name)_(arg.idx)
      int symb_id = datatree.symbol_table.addExpectationAuxiliaryVar(expectation_information_set, arg->idx, const_cast(this));
      expr_t newAuxE = datatree.AddVariable(symb_id, 0);

      // Partial information models only accept EXPECTATION(0) applied to a single variable
      if (partial_information_model && expectation_information_set == 0)
        if (!dynamic_cast(arg))
          {
            cerr << "ERROR: In Partial Information models, EXPECTATION(0)(X) "
                 << "can only be used when X is a single variable." << endl;
            exit(EXIT_FAILURE);
          }

      //take care of any nested expectation operators by calling arg->substituteExpectation(.), then decreaseLeadsLags for this UnaryOpcode::expectation operator
      //arg(lag-period) (holds entire subtree of arg(lag-period)
      expr_t substexpr = (arg->substituteExpectation(subst_table, neweqs, partial_information_model))->decreaseLeadsLags(expectation_information_set);
      assert(substexpr);
      neweqs.push_back(datatree.AddEqual(newAuxE, substexpr)); //AUXE_period_arg.idx = arg(lag-period)
      // The node is replaced by the aux var shifted forward by the information set period
      newAuxE = datatree.AddVariable(symb_id, expectation_information_set);

      assert(dynamic_cast(newAuxE));
      subst_table[this] = dynamic_cast(newAuxE);
      return newAuxE;
    }
  else
    {
      expr_t argsubst = arg->substituteExpectation(subst_table, neweqs, partial_information_model);
      return buildSimilarUnaryOpNode(argsubst, datatree);
    }
}
// Forward-variable differentiation is performed inside the argument; the operator is rebuilt around it.
expr_t
UnaryOpNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  return buildSimilarUnaryOpNode(arg->differentiateForwardVars(subset, subst_table, neweqs), datatree);
}
// A unary operation node is never a numerical constant, whatever value is asked for.
bool
UnaryOpNode::isNumConstNodeEqualTo([[maybe_unused]] double value) const
{
  return false;
}
bool
UnaryOpNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{
return false;
}
// The answer depends solely on the argument subtree.
bool
UnaryOpNode::containsPacExpectation(const string &pac_model_name) const
{
  const bool found_in_arg = arg->containsPacExpectation(pac_model_name);
  return found_in_arg;
}
// The answer depends solely on the argument subtree.
bool
UnaryOpNode::containsPacTargetNonstationary(const string &pac_model_name) const
{
  const bool found_in_arg = arg->containsPacTargetNonstationary(pac_model_name);
  return found_in_arg;
}
// Trend-variable replacement happens in the argument; this node is rebuilt with the same operator.
expr_t
UnaryOpNode::replaceTrendVar() const
{
  return buildSimilarUnaryOpNode(arg->replaceTrendVar(), datatree);
}
// Detrends the argument with respect to symb_id (see arg->detrend) and rebuilds the unary operation.
expr_t
UnaryOpNode::detrend(int symb_id, bool log_trend, expr_t trend) const
{
  expr_t detrended_arg = arg->detrend(symb_id, log_trend, trend);
  return buildSimilarUnaryOpNode(detrended_arg, datatree);
}
// Delegates trend lead/lag removal to the argument and wraps the result in the same operator.
expr_t
UnaryOpNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
{
  return buildSimilarUnaryOpNode(arg->removeTrendLeadLag(trend_symbols_map), datatree);
}
/* Steady-state operators (and their parameter derivatives) and the
   expectation operator are never considered to be in static form; any other
   operator is in static form iff its argument is. */
bool
UnaryOpNode::isInStaticForm() const
{
  switch (op_code)
    {
    case UnaryOpcode::steadyState:
    case UnaryOpcode::steadyStateParamDeriv:
    case UnaryOpcode::steadyStateParam2ndDeriv:
    case UnaryOpcode::expectation:
      return false;
    default:
      return arg->isInStaticForm();
    }
}
// Transparent with respect to the operator: only the argument subtree is inspected.
bool
UnaryOpNode::isParamTimesEndogExpr() const
{
  const bool arg_matches = arg->isParamTimesEndogExpr();
  return arg_matches;
}
/* Static-model version of this node, with auxiliary variables substituted
   in the argument.
   - diff(...) becomes Zero (presumably because x - x(-1) vanishes in a
     static model);
   - an expectation operator is dropped and only its substituted argument
     is kept;
   - any other operator is rebuilt around the substituted argument. */
expr_t
UnaryOpNode::substituteStaticAuxiliaryVariable() const
{
  if (op_code == UnaryOpcode::diff)
    return datatree.Zero;

  expr_t static_arg = arg->substituteStaticAuxiliaryVariable();
  return op_code == UnaryOpcode::expectation ? static_arg
    : buildSimilarUnaryOpNode(static_arg, datatree);
}
// Applies the variable -> constant replacement table to the argument and rebuilds the operator.
expr_t
UnaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
  return buildSimilarUnaryOpNode(arg->replaceVarsInEquation(table), datatree);
}
// Performs the log-transform substitution (orig_symb_id -> aux_symb_id) inside the argument.
expr_t
UnaryOpNode::substituteLogTransform(int orig_symb_id, int aux_symb_id) const
{
  expr_t transformed_arg = arg->substituteLogTransform(orig_symb_id, aux_symb_id);
  return buildSimilarUnaryOpNode(transformed_arg, datatree);
}
/* Builds a binary operation node "arg1 op_code arg2" owned by datatree_arg
   with index idx_arg. powerDerivOrder_arg is the derivative order used by
   the power/powerDeriv rules and must not be negative. */
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
                           BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg) :
  ExprNode{datatree_arg, idx_arg},
  arg1{arg1_arg},
  arg2{arg2_arg},
  op_code{op_code_arg},
  powerDerivOrder{powerDerivOrder_arg}
{
  // Negative orders are meaningless for the powerDeriv construction
  assert(powerDerivOrder >= 0);
}
/* One-time preparation before differentiation: prepares both operands, then
   records as possibly non-null every derivation ID for which at least one
   operand has a possibly non-null derivative (union of the operands' sets). */
void
BinaryOpNode::prepareForDerivation()
{
  if (!preparedForDerivation)
    {
      preparedForDerivation = true;

      arg1->prepareForDerivation();
      arg2->prepareForDerivation();

      non_null_derivatives.insert(arg1->non_null_derivatives.begin(), arg1->non_null_derivatives.end());
      non_null_derivatives.insert(arg2->non_null_derivatives.begin(), arg2->non_null_derivatives.end());
    }
}
/* For a node whose left or right operand is datatree.Zero (asserted),
   returns the other operand. */
expr_t
BinaryOpNode::getNonZeroPartofEquation() const
{
  assert(arg1 == datatree.Zero || arg2 == datatree.Zero);
  return arg1 == datatree.Zero ? arg2 : arg1;
}
/* Combines the derivatives of the two operands (darg1 = d(arg1),
   darg2 = d(arg2), both taken with respect to the same deriv_id) into the
   derivative of this node, according to op_code. Below, a = arg1, b = arg2,
   da = darg1, db = darg2, k = powerDerivOrder.

   NOTE: intermediate results are mostly kept in the named temporaries
   t11..t15 so that the datatree.Add* calls happen in a fixed order (C++
   leaves the evaluation order of function arguments unspecified). Each call
   may create a new node in the DataTree, presumably numbered in creation
   order through ExprNode::idx, so keep the call order when editing. */
expr_t
BinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2)
{
expr_t t11, t12, t13, t14, t15;
switch (op_code)
{
case BinaryOpcode::plus:
// d(a + b) = da + db
return datatree.AddPlus(darg1, darg2);
case BinaryOpcode::minus:
case BinaryOpcode::equal:
// d(a - b) = da - db; an equation a = b is differentiated as a - b
return datatree.AddMinus(darg1, darg2);
case BinaryOpcode::times:
// d(a * b) = da * b + db * a
t11 = datatree.AddTimes(darg1, arg2);
t12 = datatree.AddTimes(darg2, arg1);
return datatree.AddPlus(t11, t12);
case BinaryOpcode::divide:
// d(a / b) = (da * b - db * a) / (b * b), reduced to da / b when db is Zero
if (darg2 != datatree.Zero)
{
t11 = datatree.AddTimes(darg1, arg2);
t12 = datatree.AddTimes(darg2, arg1);
t13 = datatree.AddMinus(t11, t12);
t14 = datatree.AddTimes(arg2, arg2);
return datatree.AddDivide(t13, t14);
}
else
return datatree.AddDivide(darg1, arg2);
case BinaryOpcode::less:
case BinaryOpcode::greater:
case BinaryOpcode::lessEqual:
case BinaryOpcode::greaterEqual:
case BinaryOpcode::equalEqual:
case BinaryOpcode::different:
// Comparison operators are piecewise constant: their derivative is taken as zero
return datatree.Zero;
case BinaryOpcode::power:
/* d(a^b):
   - db == 0 and da == 0: Zero
   - db == 0 and b is a numerical constant (NOTE(review): the cast target
     type is missing in this copy of the source, presumably
     dynamic_cast<NumConstNode *>; TODO confirm): da * b * a^(b-1)
   - db == 0 otherwise: da * powerDeriv(a, b, k+1)
   - general case: (db * log(a) + da * b / a) * a^b
   The final "else {" below binds to "if (darg2 == datatree.Zero)". */
if (darg2 == datatree.Zero)
if (darg1 == datatree.Zero)
return datatree.Zero;
else
if (dynamic_cast(arg2))
{
t11 = datatree.AddMinus(arg2, datatree.One);
t12 = datatree.AddPower(arg1, t11);
t13 = datatree.AddTimes(arg2, t12);
return datatree.AddTimes(darg1, t13);
}
else
return datatree.AddTimes(darg1, datatree.AddPowerDeriv(arg1, arg2, powerDerivOrder + 1));
else
{
t11 = datatree.AddLog(arg1);
t12 = datatree.AddTimes(darg2, t11);
t13 = datatree.AddTimes(darg1, arg2);
t14 = datatree.AddDivide(t13, arg1);
t15 = datatree.AddPlus(t12, t14);
return datatree.AddTimes(t15, this);
}
case BinaryOpcode::powerDeriv:
/* This node stands for P(b) * a^(b-k), with P(b) = b(b-1)...(b-k+1)
   (see unpackPowerDeriv). Its derivative is
     P(b) * a^(b-k) * (db * log(a) + da * (b-k) / a)      [first_part]
   + a^(b-k) * db * sum_i prod_{j != i} (b - j)            [t14]
   and reduces to da * powerDeriv(a, b, k+1) when db is Zero. */
if (darg2 == datatree.Zero)
return datatree.AddTimes(darg1, datatree.AddPowerDeriv(arg1, arg2, powerDerivOrder + 1));
else
{
t11 = datatree.AddTimes(darg2, datatree.AddLog(arg1));
t12 = datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(powerDerivOrder));
t13 = datatree.AddTimes(darg1, t12);
t14 = datatree.AddDivide(t13, arg1);
t15 = datatree.AddPlus(t11, t14);
expr_t f = datatree.AddPower(arg1, t12);
expr_t first_part = datatree.AddTimes(f, t15);
// Multiply by the falling factorial P(b) = prod_{i<k} (b - i)
for (int i = 0; i < powerDerivOrder; i++)
first_part = datatree.AddTimes(first_part, datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(i)));
// t13 = dP(b)/db = sum over i of prod_{j != i} (b - j)
t13 = datatree.Zero;
for (int i = 0; i < powerDerivOrder; i++)
{
t11 = datatree.One;
for (int j = 0; j < powerDerivOrder; j++)
if (i != j)
{
t12 = datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(j));
t11 = datatree.AddTimes(t11, t12);
}
t13 = datatree.AddPlus(t13, t11);
}
t13 = datatree.AddTimes(darg2, t13);
t14 = datatree.AddTimes(f, t13);
return datatree.AddPlus(first_part, t14);
}
case BinaryOpcode::max:
// d max(a, b) = (a > b) * da + (1 - (a > b)) * db, the comparison acting as a 0/1 indicator
t11 = datatree.AddGreater(arg1, arg2);
t12 = datatree.AddTimes(t11, darg1);
t13 = datatree.AddMinus(datatree.One, t11);
t14 = datatree.AddTimes(t13, darg2);
return datatree.AddPlus(t14, t12);
case BinaryOpcode::min:
// d min(a, b) = (b > a) * da + (1 - (b > a)) * db
t11 = datatree.AddGreater(arg2, arg1);
t12 = datatree.AddTimes(t11, darg1);
t13 = datatree.AddMinus(datatree.One, t11);
t14 = datatree.AddTimes(t13, darg2);
return datatree.AddPlus(t14, t12);
}
// Not reached for valid op codes; only there to suppress GCC's missing-return warning
// Suppress GCC warning
exit(EXIT_FAILURE);
}
/* Expands a powerDeriv node of order k into explicit arithmetic:
     (1 * (b - 0) * (b - 1) * ... * (b - (k-1))) * a^(b - k)
   with a = arg1, b = arg2. Any other operator is returned unchanged.
   The node-creating calls are issued in the same order as before (factors
   first, then the power term). */
expr_t
BinaryOpNode::unpackPowerDeriv() const
{
  if (op_code != BinaryOpcode::powerDeriv)
    return const_cast<BinaryOpNode *>(this);

  expr_t coefficient = datatree.One;
  for (int i = 0; i < powerDerivOrder; i++)
    {
      expr_t factor = datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(i));
      coefficient = datatree.AddTimes(coefficient, factor);
    }

  expr_t reduced_exponent = datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(powerDerivOrder));
  expr_t power_term = datatree.AddPower(arg1, reduced_exponent);
  return datatree.AddTimes(coefficient, power_term);
}
/* Derivative with respect to deriv_id: differentiate the operands (arg1
   first, then arg2; the order is kept explicit on purpose) and combine them
   through composeDerivatives. */
expr_t
BinaryOpNode::computeDerivative(int deriv_id)
{
  expr_t d_left = arg1->getDerivative(deriv_id);
  expr_t d_right = arg2->getDerivative(deriv_id);
  return composeDerivatives(d_left, d_right);
}
/* Operator precedence used when printing for output_type (higher binds
   tighter). A node that is a temporary term is printed as an identifier and
   therefore gets the maximal value 100, as do min/max (function-call syntax)
   and power/powerDeriv in C-like outputs, where they are written pow(a, b). */
int
BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const
{
  if (temporary_terms.contains(const_cast<BinaryOpNode *>(this)))
    return 100;

  switch (op_code)
    {
    case BinaryOpcode::min:
    case BinaryOpcode::max:
      return 100;
    case BinaryOpcode::power:
    case BinaryOpcode::powerDeriv:
      return isCOutput(output_type) ? 100 : 5;
    case BinaryOpcode::times:
    case BinaryOpcode::divide:
      return 4;
    case BinaryOpcode::plus:
    case BinaryOpcode::minus:
      return 3;
    case BinaryOpcode::lessEqual:
    case BinaryOpcode::greaterEqual:
    case BinaryOpcode::less:
    case BinaryOpcode::greater:
      return 2;
    case BinaryOpcode::equalEqual:
    case BinaryOpcode::different:
      return 1;
    case BinaryOpcode::equal:
      return 0;
    }
  // Not reached for valid op codes; keeps GCC from warning about a missing return
  exit(EXIT_FAILURE);
}
int
BinaryOpNode::precedenceJson(const temporary_terms_t &temporary_terms) const
{
// A temporary term behaves as a variable
if (temporary_terms.contains(const_cast(this)))
return 100;
switch (op_code)
{
case BinaryOpcode::equal:
return 0;
case BinaryOpcode::equalEqual:
case BinaryOpcode::different:
return 1;
case BinaryOpcode::lessEqual:
case BinaryOpcode::greaterEqual:
case BinaryOpcode::less:
case BinaryOpcode::greater:
return 2;
case BinaryOpcode::plus:
case BinaryOpcode::minus:
return 3;
case BinaryOpcode::times:
case BinaryOpcode::divide:
return 4;
case BinaryOpcode::power:
case BinaryOpcode::powerDeriv:
return 5;
case BinaryOpcode::min:
case BinaryOpcode::max:
return 100;
}
// Suppress GCC warning
exit(EXIT_FAILURE);
}
int
BinaryOpNode::cost(const map, temporary_terms_t> &temp_terms_map, bool is_matlab) const
{
// For a temporary term, the cost is null
for (const auto &it : temp_terms_map)
if (it.second.contains(const_cast(this)))
return 0;
int arg_cost = arg1->cost(temp_terms_map, is_matlab) + arg2->cost(temp_terms_map, is_matlab);
return cost(arg_cost, is_matlab);
}
int
BinaryOpNode::cost(const vector> &blocks_temporary_terms, bool is_matlab) const
{
// For a temporary term, the cost is null
for (const auto &blk_tt : blocks_temporary_terms)
for (const auto &eq_tt : blk_tt)
if (eq_tt.contains(const_cast(this)))
return 0;
int arg_cost = arg1->cost(blocks_temporary_terms, is_matlab) + arg2->cost(blocks_temporary_terms, is_matlab);
return cost(arg_cost, is_matlab);
}
/* Adds the relative cost of this node's operator to base_cost (the cost
   already accumulated for its operands), using one cost table for MATLAB
   output and another for C output. Equal nodes add nothing. */
int
BinaryOpNode::cost(int base_cost, bool is_matlab) const
{
  if (!is_matlab)
    {
      // Cost for C files
      switch (op_code)
        {
        case BinaryOpcode::equal:
          return base_cost;
        case BinaryOpcode::less:
        case BinaryOpcode::greater:
        case BinaryOpcode::lessEqual:
        case BinaryOpcode::greaterEqual:
        case BinaryOpcode::equalEqual:
        case BinaryOpcode::different:
          return base_cost + 2;
        case BinaryOpcode::plus:
        case BinaryOpcode::minus:
        case BinaryOpcode::times:
          return base_cost + 4;
        case BinaryOpcode::max:
        case BinaryOpcode::min:
          return base_cost + 5;
        case BinaryOpcode::divide:
          return base_cost + 15;
        case BinaryOpcode::power:
          return base_cost + 520;
        case BinaryOpcode::powerDeriv:
          return base_cost + (min_cost_c/2+1);
        }
    }
  else
    {
      // Cost for Matlab files
      switch (op_code)
        {
        case BinaryOpcode::equal:
          return base_cost;
        case BinaryOpcode::less:
        case BinaryOpcode::greater:
        case BinaryOpcode::lessEqual:
        case BinaryOpcode::greaterEqual:
        case BinaryOpcode::equalEqual:
        case BinaryOpcode::different:
          return base_cost + 60;
        case BinaryOpcode::plus:
        case BinaryOpcode::minus:
        case BinaryOpcode::times:
          return base_cost + 90;
        case BinaryOpcode::max:
        case BinaryOpcode::min:
          return base_cost + 110;
        case BinaryOpcode::divide:
          return base_cost + 990;
        case BinaryOpcode::power:
        case BinaryOpcode::powerDeriv:
          return base_cost + (min_cost_matlab/2+1);
        }
    }
  // Not reached for valid op codes; keeps GCC from warning about a missing return
  exit(EXIT_FAILURE);
}
void
BinaryOpNode::computeTemporaryTerms(const pair &derivOrder,
map, temporary_terms_t> &temp_terms_map,
map>> &reference_count,
bool is_matlab) const
{
expr_t this2 = const_cast(this);
if (auto it = reference_count.find(this2);
it == reference_count.end())
{
// If this node has never been encountered, set its ref count to one,
// and travel through its children
reference_count[this2] = { 1, derivOrder };
arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
}
else
{
/* If the node has already been encountered, increment its ref count
and declare it as a temporary term if it is too costly (except if it is
an equal node: we don't want them as temporary terms) */
auto &[nref, min_order] = it->second;
nref++;
if (nref * cost(temp_terms_map, is_matlab) > min_cost(is_matlab)
&& op_code != BinaryOpcode::equal)
temp_terms_map[min_order].insert(this2);
}
}
void
BinaryOpNode::computeBlockTemporaryTerms(int blk, int eq, vector