ExprNode and its subclasses are no longer friends of DataTree
This ensures a better isolation between the container and the containees.issue#70
parent
004d909621
commit
c6d4cb88c3
|
@ -38,19 +38,7 @@ using namespace std;
|
||||||
|
|
||||||
class DataTree
|
class DataTree
|
||||||
{
|
{
|
||||||
friend class ExprNode;
|
public:
|
||||||
friend class NumConstNode;
|
|
||||||
friend class VariableNode;
|
|
||||||
friend class UnaryOpNode;
|
|
||||||
friend class BinaryOpNode;
|
|
||||||
friend class TrinaryOpNode;
|
|
||||||
friend class AbstractExternalFunctionNode;
|
|
||||||
friend class ExternalFunctionNode;
|
|
||||||
friend class FirstDerivExternalFunctionNode;
|
|
||||||
friend class SecondDerivExternalFunctionNode;
|
|
||||||
friend class VarExpectationNode;
|
|
||||||
friend class PacExpectationNode;
|
|
||||||
protected:
|
|
||||||
//! A reference to the symbol table
|
//! A reference to the symbol table
|
||||||
SymbolTable &symbol_table;
|
SymbolTable &symbol_table;
|
||||||
//! Reference to numerical constants table
|
//! Reference to numerical constants table
|
||||||
|
@ -62,6 +50,7 @@ protected:
|
||||||
//! A reference to the VAR model table
|
//! A reference to the VAR model table
|
||||||
VarModelTable &var_model_table;
|
VarModelTable &var_model_table;
|
||||||
|
|
||||||
|
protected:
|
||||||
//! num_constant_id -> NumConstNode
|
//! num_constant_id -> NumConstNode
|
||||||
using num_const_node_map_t = map<int, NumConstNode *>;
|
using num_const_node_map_t = map<int, NumConstNode *>;
|
||||||
num_const_node_map_t num_const_node_map;
|
num_const_node_map_t num_const_node_map;
|
||||||
|
@ -126,7 +115,6 @@ private:
|
||||||
//! The list of nodes
|
//! The list of nodes
|
||||||
vector<unique_ptr<ExprNode>> node_list;
|
vector<unique_ptr<ExprNode>> node_list;
|
||||||
|
|
||||||
inline expr_t AddPossiblyNegativeConstant(double val);
|
|
||||||
inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0, const string &adl_param_name = "", const vector<int> &adl_lags = vector<int>());
|
inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0, const string &adl_param_name = "", const vector<int> &adl_lags = vector<int>());
|
||||||
inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0);
|
inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0);
|
||||||
inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
|
inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
|
||||||
|
@ -157,6 +145,7 @@ public:
|
||||||
{
|
{
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline expr_t AddPossiblyNegativeConstant(double val);
|
||||||
//! Adds a non-negative numerical constant (possibly Inf or NaN)
|
//! Adds a non-negative numerical constant (possibly Inf or NaN)
|
||||||
expr_t AddNonNegativeConstant(const string &value);
|
expr_t AddNonNegativeConstant(const string &value);
|
||||||
//! Adds a variable
|
//! Adds a variable
|
||||||
|
@ -316,6 +305,25 @@ public:
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class UnknownLocalVariableException
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
//! Symbol ID
|
||||||
|
int id;
|
||||||
|
UnknownLocalVariableException(int id_arg) : id(id_arg)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
expr_t getLocalVariable(int symb_id) const
|
||||||
|
{
|
||||||
|
auto it = local_variables_table.find(symb_id);
|
||||||
|
if (it == local_variables_table.end())
|
||||||
|
throw UnknownLocalVariableException(symb_id);
|
||||||
|
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline expr_t
|
inline expr_t
|
||||||
|
|
|
@ -727,9 +727,9 @@ VariableNode::prepareForDerivation()
|
||||||
non_null_derivatives.insert(datatree.getDerivID(symb_id, lag));
|
non_null_derivatives.insert(datatree.getDerivID(symb_id, lag));
|
||||||
break;
|
break;
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
datatree.local_variables_table[symb_id]->prepareForDerivation();
|
datatree.getLocalVariable(symb_id)->prepareForDerivation();
|
||||||
// Non null derivatives are those of the value of the local parameter
|
// Non null derivatives are those of the value of the local parameter
|
||||||
non_null_derivatives = datatree.local_variables_table[symb_id]->non_null_derivatives;
|
non_null_derivatives = datatree.getLocalVariable(symb_id)->non_null_derivatives;
|
||||||
break;
|
break;
|
||||||
case SymbolType::modFileLocalVariable:
|
case SymbolType::modFileLocalVariable:
|
||||||
case SymbolType::statementDeclaredVariable:
|
case SymbolType::statementDeclaredVariable:
|
||||||
|
@ -762,7 +762,7 @@ VariableNode::computeDerivative(int deriv_id)
|
||||||
else
|
else
|
||||||
return datatree.Zero;
|
return datatree.Zero;
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->getDerivative(deriv_id);
|
return datatree.getLocalVariable(symb_id)->getDerivative(deriv_id);
|
||||||
case SymbolType::modFileLocalVariable:
|
case SymbolType::modFileLocalVariable:
|
||||||
cerr << "ModFileLocalVariable is not derivable" << endl;
|
cerr << "ModFileLocalVariable is not derivable" << endl;
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
|
@ -791,7 +791,7 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t
|
||||||
if (it != temporary_terms.end())
|
if (it != temporary_terms.end())
|
||||||
temporary_terms_inuse.insert(idx);
|
temporary_terms_inuse.insert(idx);
|
||||||
if (type == SymbolType::modelLocalVariable)
|
if (type == SymbolType::modelLocalVariable)
|
||||||
datatree.local_variables_table[symb_id]->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
|
datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool
|
bool
|
||||||
|
@ -866,7 +866,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
|
||||||
|| output_type == oCDynamicSteadyStateOperator)
|
|| output_type == oCDynamicSteadyStateOperator)
|
||||||
{
|
{
|
||||||
output << "(";
|
output << "(";
|
||||||
datatree.local_variables_table[symb_id]->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
|
datatree.getLocalVariable(symb_id)->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
|
||||||
output << ")";
|
output << ")";
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@ -1108,7 +1108,7 @@ VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number,
|
||||||
const deriv_node_temp_terms_t &tef_terms) const
|
const deriv_node_temp_terms_t &tef_terms) const
|
||||||
{
|
{
|
||||||
if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
|
if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
|
||||||
datatree.local_variables_table[symb_id]->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
|
datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
|
int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
|
||||||
|
@ -1184,7 +1184,7 @@ VariableNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
|
||||||
int equation) const
|
int equation) const
|
||||||
{
|
{
|
||||||
if (type == SymbolType::modelLocalVariable)
|
if (type == SymbolType::modelLocalVariable)
|
||||||
datatree.local_variables_table[symb_id]->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
|
datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
|
@ -1205,7 +1205,7 @@ VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &
|
||||||
if (type == type_arg)
|
if (type == type_arg)
|
||||||
result.emplace(symb_id, lag);
|
result.emplace(symb_id, lag);
|
||||||
if (type == SymbolType::modelLocalVariable)
|
if (type == SymbolType::modelLocalVariable)
|
||||||
datatree.local_variables_table[symb_id]->collectDynamicVariables(type_arg, result);
|
datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
pair<int, expr_t>
|
pair<int, expr_t>
|
||||||
|
@ -1229,14 +1229,14 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
|
||||||
/* the endogenous variable */
|
/* the endogenous variable */
|
||||||
return { 1, nullptr };
|
return { 1, nullptr };
|
||||||
else
|
else
|
||||||
return { 0, datatree.AddVariableInternal(symb_id, lag) };
|
return { 0, datatree.AddVariable(symb_id, lag) };
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if (type == SymbolType::parameter)
|
if (type == SymbolType::parameter)
|
||||||
return { 0, datatree.AddVariableInternal(symb_id, 0) };
|
return { 0, datatree.AddVariable(symb_id, 0) };
|
||||||
else
|
else
|
||||||
return { 0, datatree.AddVariableInternal(symb_id, lag) };
|
return { 0, datatree.AddVariable(symb_id, lag) };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1277,7 +1277,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur
|
||||||
return datatree.Zero;
|
return datatree.Zero;
|
||||||
}
|
}
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->getChainRuleDerivative(deriv_id, recursive_variables);
|
return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
|
||||||
case SymbolType::modFileLocalVariable:
|
case SymbolType::modFileLocalVariable:
|
||||||
cerr << "ModFileLocalVariable is not derivable" << endl;
|
cerr << "ModFileLocalVariable is not derivable" << endl;
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
|
@ -1351,7 +1351,7 @@ VariableNode::maxEndoLead() const
|
||||||
case SymbolType::endogenous:
|
case SymbolType::endogenous:
|
||||||
return max(lag, 0);
|
return max(lag, 0);
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->maxEndoLead();
|
return datatree.getLocalVariable(symb_id)->maxEndoLead();
|
||||||
default:
|
default:
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1365,7 +1365,7 @@ VariableNode::maxExoLead() const
|
||||||
case SymbolType::exogenous:
|
case SymbolType::exogenous:
|
||||||
return max(lag, 0);
|
return max(lag, 0);
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->maxExoLead();
|
return datatree.getLocalVariable(symb_id)->maxExoLead();
|
||||||
default:
|
default:
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1379,7 +1379,7 @@ VariableNode::maxEndoLag() const
|
||||||
case SymbolType::endogenous:
|
case SymbolType::endogenous:
|
||||||
return max(-lag, 0);
|
return max(-lag, 0);
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->maxEndoLag();
|
return datatree.getLocalVariable(symb_id)->maxEndoLag();
|
||||||
default:
|
default:
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1393,7 +1393,7 @@ VariableNode::maxExoLag() const
|
||||||
case SymbolType::exogenous:
|
case SymbolType::exogenous:
|
||||||
return max(-lag, 0);
|
return max(-lag, 0);
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->maxExoLag();
|
return datatree.getLocalVariable(symb_id)->maxExoLag();
|
||||||
default:
|
default:
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1409,7 +1409,7 @@ VariableNode::maxLead() const
|
||||||
case SymbolType::exogenous:
|
case SymbolType::exogenous:
|
||||||
return lag;
|
return lag;
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->maxLead();
|
return datatree.getLocalVariable(symb_id)->maxLead();
|
||||||
default:
|
default:
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1428,7 +1428,7 @@ VariableNode::VarMinLag() const
|
||||||
else
|
else
|
||||||
return 1; // Can have contemporaneus exog in VAR
|
return 1; // Can have contemporaneus exog in VAR
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->VarMinLag();
|
return datatree.getLocalVariable(symb_id)->VarMinLag();
|
||||||
default:
|
default:
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -1444,7 +1444,7 @@ VariableNode::maxLag() const
|
||||||
case SymbolType::exogenous:
|
case SymbolType::exogenous:
|
||||||
return -lag;
|
return -lag;
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->maxLag();
|
return datatree.getLocalVariable(symb_id)->maxLag();
|
||||||
default:
|
default:
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1532,7 +1532,7 @@ VariableNode::decreaseLeadsLags(int n) const
|
||||||
case SymbolType::logTrend:
|
case SymbolType::logTrend:
|
||||||
return datatree.AddVariable(symb_id, lag-n);
|
return datatree.AddVariable(symb_id, lag-n);
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
return datatree.local_variables_table[symb_id]->decreaseLeadsLags(n);
|
return datatree.getLocalVariable(symb_id)->decreaseLeadsLags(n);
|
||||||
default:
|
default:
|
||||||
return const_cast<VariableNode *>(this);
|
return const_cast<VariableNode *>(this);
|
||||||
}
|
}
|
||||||
|
@ -1559,7 +1559,7 @@ VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vecto
|
||||||
else
|
else
|
||||||
return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
|
return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
value = datatree.local_variables_table[symb_id];
|
value = datatree.getLocalVariable(symb_id);
|
||||||
if (value->maxEndoLead() <= 1)
|
if (value->maxEndoLead() <= 1)
|
||||||
return const_cast<VariableNode *>(this);
|
return const_cast<VariableNode *>(this);
|
||||||
else
|
else
|
||||||
|
@ -1610,7 +1610,7 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector
|
||||||
return substexpr;
|
return substexpr;
|
||||||
|
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
value = datatree.local_variables_table[symb_id];
|
value = datatree.getLocalVariable(symb_id);
|
||||||
if (value->maxEndoLag() <= 1)
|
if (value->maxEndoLag() <= 1)
|
||||||
return const_cast<VariableNode *>(this);
|
return const_cast<VariableNode *>(this);
|
||||||
else
|
else
|
||||||
|
@ -1632,7 +1632,7 @@ VariableNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode
|
||||||
else
|
else
|
||||||
return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
|
return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
value = datatree.local_variables_table[symb_id];
|
value = datatree.getLocalVariable(symb_id);
|
||||||
if (value->maxExoLead() == 0)
|
if (value->maxExoLead() == 0)
|
||||||
return const_cast<VariableNode *>(this);
|
return const_cast<VariableNode *>(this);
|
||||||
else
|
else
|
||||||
|
@ -1683,7 +1683,7 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
|
||||||
return substexpr;
|
return substexpr;
|
||||||
|
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
value = datatree.local_variables_table[symb_id];
|
value = datatree.getLocalVariable(symb_id);
|
||||||
if (value->maxExoLag() == 0)
|
if (value->maxExoLag() == 0)
|
||||||
return const_cast<VariableNode *>(this);
|
return const_cast<VariableNode *>(this);
|
||||||
else
|
else
|
||||||
|
@ -1729,7 +1729,7 @@ VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table
|
||||||
return datatree.AddPlus(datatree.AddVariable(symb_id, 0), diffvar);
|
return datatree.AddPlus(datatree.AddVariable(symb_id, 0), diffvar);
|
||||||
}
|
}
|
||||||
case SymbolType::modelLocalVariable:
|
case SymbolType::modelLocalVariable:
|
||||||
value = datatree.local_variables_table[symb_id];
|
value = datatree.getLocalVariable(symb_id);
|
||||||
if (value->maxEndoLead() <= 0)
|
if (value->maxEndoLead() <= 0)
|
||||||
return const_cast<VariableNode *>(this);
|
return const_cast<VariableNode *>(this);
|
||||||
else
|
else
|
||||||
|
|
Loading…
Reference in New Issue