fix bug in var max lag and simplify code

issue#70
Houtan Bastani 2018-05-31 15:34:25 +02:00
parent e532ed9bab
commit 1b952a12e6
3 changed files with 72 additions and 77 deletions

View File

@ -3355,23 +3355,19 @@ DynamicModel::checkVarMinLag(vector<int> &eqnumber) const
int int
DynamicModel::getVarMaxLag(StaticModel &static_model, vector<int> &eqnumber) const DynamicModel::getVarMaxLag(StaticModel &static_model, vector<int> &eqnumber) const
{ {
vector<expr_t> lhs; set<expr_t> lhs;
for (vector<int>::const_iterator it = eqnumber.begin(); for (vector<int>::const_iterator it = eqnumber.begin();
it != eqnumber.end(); it++) it != eqnumber.end(); it++)
equations[*it]->get_arg1()->collectVARLHSVariable(lhs);
if (eqnumber.size() != lhs.size())
{ {
set<expr_t> lhs_set; cerr << "The LHS variables of the VAR are not unique" << endl;
equations[*it]->get_arg1()->collectVARLHSVariable(lhs_set); exit(EXIT_FAILURE);
if (lhs_set.size() != 1)
{
cerr << "ERROR: in Equation "
<< ". A VAR may only have one endogenous variable on the LHS. " << endl;
exit(EXIT_FAILURE);
}
lhs.push_back(*(lhs_set.begin()));
} }
set<expr_t> lhs_static; set<expr_t> lhs_static;
for(vector<expr_t>::const_iterator it = lhs.begin(); for(set<expr_t>::const_iterator it = lhs.begin();
it != lhs.end(); it++) it != lhs.end(); it++)
lhs_static.insert((*it)->toStatic(static_model)); lhs_static.insert((*it)->toStatic(static_model));
@ -3390,7 +3386,7 @@ DynamicModel::getVarLhsDiffAndInfo(vector<int> &eqnumber, vector<bool> &diff,
for (vector<int>::const_iterator it = eqnumber.begin(); for (vector<int>::const_iterator it = eqnumber.begin();
it != eqnumber.end(); it++) it != eqnumber.end(); it++)
{ {
diff.push_back(equations[*it]->get_arg1()->isDiffPresent()); equations[*it]->get_arg1()->countDiffs() > 0 ? diff.push_back(true) : diff.push_back(false);
if (diff.back()) if (diff.back())
{ {
set<pair<int, int> > diff_set; set<pair<int, int> > diff_set;

View File

@ -294,12 +294,6 @@ ExprNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_ar
return false; return false;
} }
bool
ExprNode::isDiffPresent() const
{
return false;
}
void void
ExprNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const ExprNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{ {
@ -313,10 +307,10 @@ NumConstNode::NumConstNode(DataTree &datatree_arg, int id_arg) :
datatree.num_const_node_map[id] = this; datatree.num_const_node_map[id] = this;
} }
bool int
NumConstNode::isDiffPresent() const NumConstNode::countDiffs() const
{ {
return false; return 0;
} }
void void
@ -389,6 +383,8 @@ NumConstNode::compile(ostream &CompileCode, unsigned int &instruction_number,
void void
NumConstNode::collectVARLHSVariable(set<expr_t> &result) const NumConstNode::collectVARLHSVariable(set<expr_t> &result) const
{ {
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
} }
void void
@ -1144,7 +1140,7 @@ VariableNode::collectVARLHSVariable(set<expr_t> &result) const
result.insert(const_cast<VariableNode *>(this)); result.insert(const_cast<VariableNode *>(this));
else else
{ {
cerr << "ERROR: A VAR must have one endogenous variable on the LHS." << endl; cerr << "ERROR: you can only have endogenous variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
} }
@ -1730,10 +1726,10 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
} }
} }
bool int
VariableNode::isDiffPresent() const VariableNode::countDiffs() const
{ {
return false; return 0;
} }
expr_t expr_t
@ -2965,19 +2961,20 @@ UnaryOpNode::VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs, int &
arg->VarMaxLag(static_datatree, static_lhs, max_lag); arg->VarMaxLag(static_datatree, static_lhs, max_lag);
else else
{ {
for (set<expr_t>::const_iterator it = static_lhs.begin(); set<expr_t>::const_iterator it = static_lhs.find(this->toStatic(static_datatree));
it != static_lhs.end(); it++) if (it != static_lhs.end())
if (*it == this->toStatic(static_datatree)) {
{ int max_lag_tmp = arg->maxLag() - arg->countDiffs();
int max_lag_tmp = arg->maxLag(); if (max_lag_tmp > max_lag)
if (max_lag_tmp > max_lag) max_lag = max_lag_tmp;
max_lag = max_lag_tmp; }
return; else
} {
int max_lag_tmp = 0; int max_lag_tmp = 0;
arg->VarMaxLag(static_datatree, static_lhs, max_lag_tmp); arg->VarMaxLag(static_datatree, static_lhs, max_lag_tmp);
if (max_lag_tmp + 1 > max_lag) if (max_lag_tmp + 1 > max_lag)
max_lag = max_lag_tmp + 1; max_lag = max_lag_tmp + 1;
}
} }
} }
@ -3027,12 +3024,12 @@ UnaryOpNode::substituteAdl() const
return retval; return retval;
} }
bool int
UnaryOpNode::isDiffPresent() const UnaryOpNode::countDiffs() const
{ {
if (op_code == oDiff) if (op_code == oDiff)
return true; return arg->countDiffs() + 1;
return arg->isDiffPresent(); return arg->countDiffs();
} }
bool bool
@ -4414,8 +4411,8 @@ BinaryOpNode::VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs, int
void void
BinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const BinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
{ {
arg1->collectVARLHSVariable(result); cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
arg2->collectVARLHSVariable(result); exit(EXIT_FAILURE);
} }
void void
@ -5011,10 +5008,10 @@ BinaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &no
return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree); return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
} }
bool int
BinaryOpNode::isDiffPresent() const BinaryOpNode::countDiffs() const
{ {
return arg1->isDiffPresent() || arg2->isDiffPresent(); return arg1->countDiffs() + arg2->countDiffs();
} }
expr_t expr_t
@ -5657,9 +5654,8 @@ TrinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int
void void
TrinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const TrinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
{ {
arg1->collectVARLHSVariable(result); cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
arg2->collectVARLHSVariable(result); exit(EXIT_FAILURE);
arg3->collectVARLHSVariable(result);
} }
void void
@ -5923,10 +5919,10 @@ TrinaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &n
return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree); return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
} }
bool int
TrinaryOpNode::isDiffPresent() const TrinaryOpNode::countDiffs() const
{ {
return arg1->isDiffPresent() || arg2->isDiffPresent() || arg3->isDiffPresent(); return arg1->countDiffs() + arg2->countDiffs() + arg3->countDiffs();
} }
expr_t expr_t
@ -6127,9 +6123,8 @@ AbstractExternalFunctionNode::compileExternalFunctionArguments(ostream &CompileC
void void
AbstractExternalFunctionNode::collectVARLHSVariable(set<expr_t> &result) const AbstractExternalFunctionNode::collectVARLHSVariable(set<expr_t> &result) const
{ {
for (vector<expr_t>::const_iterator it = arguments.begin(); cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
it != arguments.end(); it++) exit(EXIT_FAILURE);
(*it)->collectVARLHSVariable(result);
} }
void void
@ -6362,13 +6357,13 @@ AbstractExternalFunctionNode::substituteUnaryOpNodes(DataTree &static_datatree,
return buildSimilarExternalFunctionNode(arguments_subst, datatree); return buildSimilarExternalFunctionNode(arguments_subst, datatree);
} }
bool int
AbstractExternalFunctionNode::isDiffPresent() const AbstractExternalFunctionNode::countDiffs() const
{ {
bool result = false; int ndiffs = 0;
for (vector<expr_t>::const_iterator it = arguments.begin(); it != arguments.end(); it++) for (vector<expr_t>::const_iterator it = arguments.begin(); it != arguments.end(); it++)
result = result || (*it)->isDiffPresent(); ndiffs += (*it)->countDiffs();
return result; return ndiffs;
} }
expr_t expr_t
@ -7829,10 +7824,10 @@ VarExpectationNode::eval(const eval_context_t &eval_context) const throw (EvalEx
return it->second; return it->second;
} }
bool int
VarExpectationNode::isDiffPresent() const VarExpectationNode::countDiffs() const
{ {
return false; return 0;
} }
void void
@ -7843,6 +7838,8 @@ VarExpectationNode::computeXrefs(EquationInfo &ei) const
void void
VarExpectationNode::collectVARLHSVariable(set<expr_t> &result) const VarExpectationNode::collectVARLHSVariable(set<expr_t> &result) const
{ {
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
} }
void void
@ -8281,6 +8278,8 @@ PacExpectationNode::computeXrefs(EquationInfo &ei) const
void void
PacExpectationNode::collectVARLHSVariable(set<expr_t> &result) const PacExpectationNode::collectVARLHSVariable(set<expr_t> &result) const
{ {
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
} }
void void
@ -8306,10 +8305,10 @@ PacExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_numb
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
bool int
PacExpectationNode::isDiffPresent() const PacExpectationNode::countDiffs() const
{ {
return false; return 0;
} }
pair<int, expr_t > pair<int, expr_t >

View File

@ -465,8 +465,8 @@ class ExprNode
//! Returns true if the expression contains one or several exogenous variable //! Returns true if the expression contains one or several exogenous variable
virtual bool containsExogenous() const = 0; virtual bool containsExogenous() const = 0;
//! Returns true if the expression contains a diff operator //! Returns the number of diffs present
virtual bool isDiffPresent(void) const = 0; virtual int countDiffs() const = 0;
//! Return true if the nodeID is a variable withe a type equal to type_arg, a specific variable id aqual to varfiable_id and a lag equal to lag_arg and false otherwise //! Return true if the nodeID is a variable withe a type equal to type_arg, a specific variable id aqual to varfiable_id and a lag equal to lag_arg and false otherwise
/*! /*!
@ -595,7 +595,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const; virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const; virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const; virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const; virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
@ -685,7 +685,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const; virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const; virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const; virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const; virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
@ -799,7 +799,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const; virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const; virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const; virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const; virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
@ -928,7 +928,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const; virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const; virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const; virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const; virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
@ -1033,7 +1033,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const; virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const; virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const; virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const; virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const; virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
@ -1139,7 +1139,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const; virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const; virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const; virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const; virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, deriv_node_temp_terms_t &tef_terms, const string &ending) const; virtual void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, deriv_node_temp_terms_t &tef_terms, const string &ending) const;
virtual expr_t replaceTrendVar() const; virtual expr_t replaceTrendVar() const;
@ -1338,7 +1338,7 @@ public:
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const; virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual bool containsEndogenous(void) const; virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const; virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const; virtual int countDiffs() const;
virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isNumConstNodeEqualTo(double value) const;
virtual expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
virtual expr_t decreaseLeadsLagsPredeterminedVariables() const; virtual expr_t decreaseLeadsLagsPredeterminedVariables() const;
@ -1423,7 +1423,7 @@ public:
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const; virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual bool containsEndogenous(void) const; virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const; virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const; virtual int countDiffs() const;
virtual bool isNumConstNodeEqualTo(double value) const; virtual bool isNumConstNodeEqualTo(double value) const;
virtual expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
virtual expr_t decreaseLeadsLagsPredeterminedVariables() const; virtual expr_t decreaseLeadsLagsPredeterminedVariables() const;