ExprNode and its subclasses are no longer friends of DataTree

This ensures a better isolation between the container and the containees.
issue#70
Sébastien Villemot 2018-09-05 17:41:11 +02:00
parent 004d909621
commit c6d4cb88c3
2 changed files with 47 additions and 39 deletions

View File

@ -38,19 +38,7 @@ using namespace std;
class DataTree
{
friend class ExprNode;
friend class NumConstNode;
friend class VariableNode;
friend class UnaryOpNode;
friend class BinaryOpNode;
friend class TrinaryOpNode;
friend class AbstractExternalFunctionNode;
friend class ExternalFunctionNode;
friend class FirstDerivExternalFunctionNode;
friend class SecondDerivExternalFunctionNode;
friend class VarExpectationNode;
friend class PacExpectationNode;
protected:
public:
//! A reference to the symbol table
SymbolTable &symbol_table;
//! Reference to numerical constants table
@ -62,6 +50,7 @@ protected:
//! A reference to the VAR model table
VarModelTable &var_model_table;
protected:
//! num_constant_id -> NumConstNode
using num_const_node_map_t = map<int, NumConstNode *>;
num_const_node_map_t num_const_node_map;
@ -126,7 +115,6 @@ private:
//! The list of nodes
vector<unique_ptr<ExprNode>> node_list;
inline expr_t AddPossiblyNegativeConstant(double val);
inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0, const string &adl_param_name = "", const vector<int> &adl_lags = vector<int>());
inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0);
inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
@ -157,6 +145,7 @@ public:
{
};
inline expr_t AddPossiblyNegativeConstant(double val);
//! Adds a non-negative numerical constant (possibly Inf or NaN)
expr_t AddNonNegativeConstant(const string &value);
//! Adds a variable
@ -316,6 +305,25 @@ public:
{
return false;
};
class UnknownLocalVariableException
{
public:
//! Symbol ID
int id;
UnknownLocalVariableException(int id_arg) : id(id_arg)
{
}
};
expr_t getLocalVariable(int symb_id) const
{
auto it = local_variables_table.find(symb_id);
if (it == local_variables_table.end())
throw UnknownLocalVariableException(symb_id);
return it->second;
}
};
inline expr_t

View File

@ -727,9 +727,9 @@ VariableNode::prepareForDerivation()
non_null_derivatives.insert(datatree.getDerivID(symb_id, lag));
break;
case SymbolType::modelLocalVariable:
datatree.local_variables_table[symb_id]->prepareForDerivation();
datatree.getLocalVariable(symb_id)->prepareForDerivation();
// Non null derivatives are those of the value of the local parameter
non_null_derivatives = datatree.local_variables_table[symb_id]->non_null_derivatives;
non_null_derivatives = datatree.getLocalVariable(symb_id)->non_null_derivatives;
break;
case SymbolType::modFileLocalVariable:
case SymbolType::statementDeclaredVariable:
@ -762,7 +762,7 @@ VariableNode::computeDerivative(int deriv_id)
else
return datatree.Zero;
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->getDerivative(deriv_id);
return datatree.getLocalVariable(symb_id)->getDerivative(deriv_id);
case SymbolType::modFileLocalVariable:
cerr << "ModFileLocalVariable is not derivable" << endl;
exit(EXIT_FAILURE);
@ -791,7 +791,7 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t
if (it != temporary_terms.end())
temporary_terms_inuse.insert(idx);
if (type == SymbolType::modelLocalVariable)
datatree.local_variables_table[symb_id]->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
}
bool
@ -866,7 +866,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
|| output_type == oCDynamicSteadyStateOperator)
{
output << "(";
datatree.local_variables_table[symb_id]->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
datatree.getLocalVariable(symb_id)->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
output << ")";
}
else
@ -1108,7 +1108,7 @@ VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number,
const deriv_node_temp_terms_t &tef_terms) const
{
if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
datatree.local_variables_table[symb_id]->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
else
{
int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
@ -1184,7 +1184,7 @@ VariableNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
int equation) const
{
if (type == SymbolType::modelLocalVariable)
datatree.local_variables_table[symb_id]->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
}
void
@ -1205,7 +1205,7 @@ VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &
if (type == type_arg)
result.emplace(symb_id, lag);
if (type == SymbolType::modelLocalVariable)
datatree.local_variables_table[symb_id]->collectDynamicVariables(type_arg, result);
datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
}
pair<int, expr_t>
@ -1229,14 +1229,14 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
/* the endogenous variable */
return { 1, nullptr };
else
return { 0, datatree.AddVariableInternal(symb_id, lag) };
return { 0, datatree.AddVariable(symb_id, lag) };
}
else
{
if (type == SymbolType::parameter)
return { 0, datatree.AddVariableInternal(symb_id, 0) };
return { 0, datatree.AddVariable(symb_id, 0) };
else
return { 0, datatree.AddVariableInternal(symb_id, lag) };
return { 0, datatree.AddVariable(symb_id, lag) };
}
}
@ -1277,7 +1277,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur
return datatree.Zero;
}
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->getChainRuleDerivative(deriv_id, recursive_variables);
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
case SymbolType::modFileLocalVariable:
cerr << "ModFileLocalVariable is not derivable" << endl;
exit(EXIT_FAILURE);
@ -1351,7 +1351,7 @@ VariableNode::maxEndoLead() const
case SymbolType::endogenous:
return max(lag, 0);
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->maxEndoLead();
return datatree.getLocalVariable(symb_id)->maxEndoLead();
default:
return 0;
}
@ -1365,7 +1365,7 @@ VariableNode::maxExoLead() const
case SymbolType::exogenous:
return max(lag, 0);
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->maxExoLead();
return datatree.getLocalVariable(symb_id)->maxExoLead();
default:
return 0;
}
@ -1379,7 +1379,7 @@ VariableNode::maxEndoLag() const
case SymbolType::endogenous:
return max(-lag, 0);
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->maxEndoLag();
return datatree.getLocalVariable(symb_id)->maxEndoLag();
default:
return 0;
}
@ -1393,7 +1393,7 @@ VariableNode::maxExoLag() const
case SymbolType::exogenous:
return max(-lag, 0);
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->maxExoLag();
return datatree.getLocalVariable(symb_id)->maxExoLag();
default:
return 0;
}
@ -1409,7 +1409,7 @@ VariableNode::maxLead() const
case SymbolType::exogenous:
return lag;
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->maxLead();
return datatree.getLocalVariable(symb_id)->maxLead();
default:
return 0;
}
@ -1428,7 +1428,7 @@ VariableNode::VarMinLag() const
else
return 1; // Can have contemporaneus exog in VAR
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->VarMinLag();
return datatree.getLocalVariable(symb_id)->VarMinLag();
default:
return 1;
}
@ -1444,7 +1444,7 @@ VariableNode::maxLag() const
case SymbolType::exogenous:
return -lag;
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->maxLag();
return datatree.getLocalVariable(symb_id)->maxLag();
default:
return 0;
}
@ -1532,7 +1532,7 @@ VariableNode::decreaseLeadsLags(int n) const
case SymbolType::logTrend:
return datatree.AddVariable(symb_id, lag-n);
case SymbolType::modelLocalVariable:
return datatree.local_variables_table[symb_id]->decreaseLeadsLags(n);
return datatree.getLocalVariable(symb_id)->decreaseLeadsLags(n);
default:
return const_cast<VariableNode *>(this);
}
@ -1559,7 +1559,7 @@ VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vecto
else
return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
case SymbolType::modelLocalVariable:
value = datatree.local_variables_table[symb_id];
value = datatree.getLocalVariable(symb_id);
if (value->maxEndoLead() <= 1)
return const_cast<VariableNode *>(this);
else
@ -1610,7 +1610,7 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector
return substexpr;
case SymbolType::modelLocalVariable:
value = datatree.local_variables_table[symb_id];
value = datatree.getLocalVariable(symb_id);
if (value->maxEndoLag() <= 1)
return const_cast<VariableNode *>(this);
else
@ -1632,7 +1632,7 @@ VariableNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode
else
return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
case SymbolType::modelLocalVariable:
value = datatree.local_variables_table[symb_id];
value = datatree.getLocalVariable(symb_id);
if (value->maxExoLead() == 0)
return const_cast<VariableNode *>(this);
else
@ -1683,7 +1683,7 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
return substexpr;
case SymbolType::modelLocalVariable:
value = datatree.local_variables_table[symb_id];
value = datatree.getLocalVariable(symb_id);
if (value->maxExoLag() == 0)
return const_cast<VariableNode *>(this);
else
@ -1729,7 +1729,7 @@ VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table
return datatree.AddPlus(datatree.AddVariable(symb_id, 0), diffvar);
}
case SymbolType::modelLocalVariable:
value = datatree.local_variables_table[symb_id];
value = datatree.getLocalVariable(symb_id);
if (value->maxEndoLead() <= 0)
return const_cast<VariableNode *>(this);
else