changed expectation operator from BinaryOpNode to UnaryOpNode

git-svn-id: https://www.dynare.org/svn/dynare/trunk@3102 ac1d8469-bf42-47a9-8791-bf33cf982152
issue#70
houtanb 2009-10-30 05:21:54 +00:00
parent 760085d6fe
commit e45d3a4cb2
6 changed files with 88 additions and 81 deletions

View File

@ -149,7 +149,8 @@ enum UnaryOpcode
oAsinh, oAsinh,
oAtanh, oAtanh,
oSqrt, oSqrt,
oSteadyState oSteadyState,
oExpectation
}; };
enum BinaryOpcode enum BinaryOpcode
@ -167,8 +168,7 @@ enum BinaryOpcode
oLessEqual, oLessEqual,
oGreaterEqual, oGreaterEqual,
oEqualEqual, oEqualEqual,
oDifferent, oDifferent
oExpectation
}; };
enum TrinaryOpcode enum TrinaryOpcode

View File

@ -244,17 +244,6 @@ DataTree::AddPower(NodeID iArg1, NodeID iArg2)
return Zero; return Zero;
} }
NodeID
DataTree::AddExpectation(int iArg1, NodeID iArg2)
{
ostringstream period;
period << abs(iArg1);
if (iArg1 >= 0)
return AddBinaryOp(AddNumConstant(period.str()), oExpectation, iArg2);
else
return AddBinaryOp(AddUMinus(AddNumConstant(period.str())), oExpectation, iArg2);
}
NodeID NodeID
DataTree::AddExp(NodeID iArg1) DataTree::AddExp(NodeID iArg1)
{ {
@ -433,6 +422,12 @@ DataTree::AddSteadyState(NodeID iArg1)
return AddUnaryOp(oSteadyState, iArg1); return AddUnaryOp(oSteadyState, iArg1);
} }
NodeID
DataTree::AddExpectation(int iArg1, NodeID iArg2)
{
return AddUnaryOp(oExpectation, iArg2, iArg1);
}
NodeID NodeID
DataTree::AddEqual(NodeID iArg1, NodeID iArg2) DataTree::AddEqual(NodeID iArg1, NodeID iArg2)
{ {

View File

@ -78,7 +78,7 @@ private:
int node_counter; int node_counter;
inline NodeID AddPossiblyNegativeConstant(double val); inline NodeID AddPossiblyNegativeConstant(double val);
inline NodeID AddUnaryOp(UnaryOpcode op_code, NodeID arg); inline NodeID AddUnaryOp(UnaryOpcode op_code, NodeID arg, int arg_exp_info_set = 0);
inline NodeID AddBinaryOp(NodeID arg1, BinaryOpcode op_code, NodeID arg2); inline NodeID AddBinaryOp(NodeID arg1, BinaryOpcode op_code, NodeID arg2);
inline NodeID AddTrinaryOp(NodeID arg1, TrinaryOpcode op_code, NodeID arg2, NodeID arg3); inline NodeID AddTrinaryOp(NodeID arg1, TrinaryOpcode op_code, NodeID arg2, NodeID arg3);
@ -216,7 +216,7 @@ DataTree::AddPossiblyNegativeConstant(double v)
} }
inline NodeID inline NodeID
DataTree::AddUnaryOp(UnaryOpcode op_code, NodeID arg) DataTree::AddUnaryOp(UnaryOpcode op_code, NodeID arg, int arg_exp_info_set)
{ {
// If the node already exists in tree, share it // If the node already exists in tree, share it
unary_op_node_map_type::iterator it = unary_op_node_map.find(make_pair(arg, op_code)); unary_op_node_map_type::iterator it = unary_op_node_map.find(make_pair(arg, op_code));
@ -238,7 +238,7 @@ DataTree::AddUnaryOp(UnaryOpcode op_code, NodeID arg)
{ {
} }
} }
return new UnaryOpNode(*this, op_code, arg); return new UnaryOpNode(*this, op_code, arg, arg_exp_info_set);
} }
inline NodeID inline NodeID

View File

@ -3031,6 +3031,9 @@ DynamicModel::substituteLeadLagInternal(aux_var_t type)
case avExoLag: case avExoLag:
cout << "exo lags"; cout << "exo lags";
break; break;
case avExpectation:
cout << "expectation";
break;
} }
cout << ": added " << neweqs.size() << " auxiliary variables and equations." << endl; cout << ": added " << neweqs.size() << " auxiliary variables and equations." << endl;
} }
@ -3043,8 +3046,8 @@ DynamicModel::substituteExpectation(bool partial_information_model)
vector<BinaryOpNode *> neweqs; vector<BinaryOpNode *> neweqs;
// Substitute in model binary op node map // Substitute in model binary op node map
for(binary_op_node_map_type::reverse_iterator it = binary_op_node_map.rbegin(); for(unary_op_node_map_type::reverse_iterator it = unary_op_node_map.rbegin();
it != binary_op_node_map.rend(); it++) it != unary_op_node_map.rend(); it++)
it->second->substituteExpectation(subst_table, neweqs, partial_information_model); it->second->substituteExpectation(subst_table, neweqs, partial_information_model);
// Substitute in equations // Substitute in equations
@ -3064,7 +3067,7 @@ DynamicModel::substituteExpectation(bool partial_information_model)
if (neweqs.size() > 0) if (neweqs.size() > 0)
if (partial_information_model) if (partial_information_model)
cout << "Substitution of Expectation operator: added auxiliary variables and equations." << endl; //FIX to reflect correct number of equations cout << "Substitution of Expectation operator: added " << subst_table.size() << " auxiliary variables and " << neweqs.size() << " auxiliary equations." << endl;
else else
cout << "Substitution of Expectation operator: added " << neweqs.size() << " auxiliary variables and equations." << endl; cout << "Substitution of Expectation operator: added " << neweqs.size() << " auxiliary variables and equations." << endl;
} }

View File

@ -966,9 +966,10 @@ VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpN
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg) : UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg, const int arg_exp_info_set) :
ExprNode(datatree_arg), ExprNode(datatree_arg),
arg(arg_arg), arg(arg_arg),
Expectation_information_set(arg_exp_info_set),
op_code(op_code_arg) op_code(op_code_arg)
{ {
// Add myself to the unary op map // Add myself to the unary op map
@ -1057,6 +1058,8 @@ UnaryOpNode::composeDerivatives(NodeID darg)
return datatree.Zero; return datatree.Zero;
else else
return darg; return darg;
case oExpectation:
assert(0);
} }
// Suppress GCC warning // Suppress GCC warning
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -1115,7 +1118,8 @@ UnaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) c
return cost + 350; return cost + 350;
case oSqrt: case oSqrt:
return cost + 570; return cost + 570;
case oSteadyState: case oSteadyState:
case oExpectation:
return cost; return cost;
} }
else else
@ -1151,7 +1155,8 @@ UnaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) c
return cost + 150; return cost + 150;
case oSqrt: case oSqrt:
return cost + 90; return cost + 90;
case oSteadyState: case oSteadyState:
case oExpectation:
return cost; return cost;
} }
// Suppress GCC warning // Suppress GCC warning
@ -1314,6 +1319,8 @@ UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
} }
arg->writeOutput(output, new_output_type, temporary_terms); arg->writeOutput(output, new_output_type, temporary_terms);
return; return;
case oExpectation:
assert(0);
} }
bool close_parenthesis = false; bool close_parenthesis = false;
@ -1389,8 +1396,10 @@ UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) throw (EvalException)
#endif #endif
case oSqrt: case oSqrt:
return(sqrt(v)); return(sqrt(v));
case oSteadyState: case oSteadyState:
return(v); return(v);
case oExpectation:
throw EvalException();
} }
// Suppress GCC warning // Suppress GCC warning
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -1492,6 +1501,8 @@ UnaryOpNode::normalizeEquation(int var_endo, vector<pair<int, pair<NodeID, NodeI
return(make_pair(1, (NodeID)NULL)); return(make_pair(1, (NodeID)NULL));
case oSteadyState: case oSteadyState:
return(make_pair(1, (NodeID)NULL)); return(make_pair(1, (NodeID)NULL));
case oExpectation:
assert(0);
} }
} }
else else
@ -1534,6 +1545,8 @@ UnaryOpNode::normalizeEquation(int var_endo, vector<pair<int, pair<NodeID, NodeI
return(make_pair(0, datatree.AddSqrt(New_NodeID))); return(make_pair(0, datatree.AddSqrt(New_NodeID)));
case oSteadyState: case oSteadyState:
return(make_pair(0, datatree.AddSteadyState(New_NodeID))); return(make_pair(0, datatree.AddSteadyState(New_NodeID)));
case oExpectation:
assert(0);
} }
} }
return(make_pair(1, (NodeID)NULL)); return(make_pair(1, (NodeID)NULL));
@ -1588,6 +1601,8 @@ UnaryOpNode::buildSimilarUnaryOpNode(NodeID alt_arg, DataTree &alt_datatree) con
return alt_datatree.AddSqrt(alt_arg); return alt_datatree.AddSqrt(alt_arg);
case oSteadyState: case oSteadyState:
return alt_datatree.AddSteadyState(alt_arg); return alt_datatree.AddSteadyState(alt_arg);
case oExpectation:
return alt_datatree.AddExpectation(Expectation_information_set, alt_arg);
} }
// Suppress GCC warning // Suppress GCC warning
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -1670,8 +1685,52 @@ UnaryOpNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *>
NodeID NodeID
UnaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const UnaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
{ {
NodeID argsubst = arg->substituteExpectation(subst_table, neweqs, partial_information_model); switch(op_code)
return buildSimilarUnaryOpNode(argsubst, datatree); {
case oExpectation:
{
subst_table_t::iterator it = subst_table.find(const_cast<UnaryOpNode *>(this));
//This if statement should evaluate to true when substituting Exp operators out of equations in second pass
if (it != subst_table.end())
return const_cast<VariableNode *>(it->second);
//Arriving here, we need to create an auxiliary variable for this Expectation Operator:
int symb_id = datatree.symbol_table.addExpectationAuxiliaryVar(Expectation_information_set, arg->idx); //AUXE_period_arg.idx
NodeID newAuxE = datatree.AddVariable(symb_id, 0);
assert(dynamic_cast<VariableNode *>(newAuxE) != NULL);
if (partial_information_model && Expectation_information_set==0)
{
//Ensure x is a single variable as opposed to an expression
if (dynamic_cast<VariableNode *>(arg) == NULL)
{
cerr << "In Partial Information models, EXPECTATION(0)(X) can only be used when X is a single variable." << endl;
exit(EXIT_FAILURE);
}
}
else
{
//take care of any nested expectation operators by calling arg->substituteExpectation(.), then decreaseLeadsLags for this oExp operator
//arg(lag-period) (holds entire subtree of arg(lag-period)
NodeID substexpr = (arg->substituteExpectation(subst_table, neweqs, partial_information_model))->decreaseLeadsLags(Expectation_information_set);
assert(substexpr != NULL);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(newAuxE, substexpr))); //AUXE_period_arg.idx = arg(lag-period)
newAuxE = newAuxE->decreaseLeadsLags(-1*Expectation_information_set);
assert(dynamic_cast<VariableNode *>(newAuxE) != NULL);
}
subst_table[this] = dynamic_cast<VariableNode *>(newAuxE);
return newAuxE;
}
default:
{
NodeID argsubst = arg->substituteExpectation(subst_table, neweqs, partial_information_model);
return buildSimilarUnaryOpNode(argsubst, datatree);
}
}
} }
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
@ -1820,7 +1879,6 @@ BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t
return 5; return 5;
case oMin: case oMin:
case oMax: case oMax:
case oExpectation:
return 100; return 100;
} }
// Suppress GCC warning // Suppress GCC warning
@ -1861,7 +1919,6 @@ BinaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab)
case oPower: case oPower:
return cost + 1160; return cost + 1160;
case oEqual: case oEqual:
case oExpectation:
return cost; return cost;
} }
else else
@ -1887,7 +1944,6 @@ BinaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab)
case oPower: case oPower:
return cost + 520; return cost + 520;
case oEqual: case oEqual:
case oExpectation:
return cost; return cost;
} }
// Suppress GCC warning // Suppress GCC warning
@ -1986,7 +2042,6 @@ BinaryOpNode::eval_opcode(double v1, BinaryOpcode op_code, double v2) throw (Eva
case oDifferent: case oDifferent:
return (v1 != v2); return (v1 != v2);
case oEqual: case oEqual:
case oExpectation:
throw EvalException(); throw EvalException();
} }
// Suppress GCC warning // Suppress GCC warning
@ -2524,8 +2579,6 @@ BinaryOpNode::buildSimilarBinaryOpNode(NodeID alt_arg1, NodeID alt_arg2, DataTre
return alt_datatree.AddEqualEqual(alt_arg1, alt_arg2); return alt_datatree.AddEqualEqual(alt_arg1, alt_arg2);
case oDifferent: case oDifferent:
return alt_datatree.AddDifferent(alt_arg1, alt_arg2); return alt_datatree.AddDifferent(alt_arg1, alt_arg2);
case oExpectation:
return alt_datatree.AddExpectation((int)(alt_arg1->eval(map<int, double>())), alt_arg2);
} }
// Suppress GCC warning // Suppress GCC warning
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
@ -2650,54 +2703,9 @@ BinaryOpNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
NodeID NodeID
BinaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const BinaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
{ {
switch(op_code) NodeID arg1subst = arg1->substituteExpectation(subst_table, neweqs, partial_information_model);
{ NodeID arg2subst = arg2->substituteExpectation(subst_table, neweqs, partial_information_model);
case oExpectation: return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
{
int period = (int)(arg1->eval(map<int, double>()));
subst_table_t::iterator it = subst_table.find(const_cast<BinaryOpNode *>(this));
//IF should evaluate to true when substituting Exp operators out of equations in second pass
if (it != subst_table.end())
return const_cast<VariableNode *>(it->second);
//Arriving here, we need to create an auxiliary variable for this Expectation Operator:
int symb_id = datatree.symbol_table.addExpectationAuxiliaryVar(arg1->idx, arg2->idx); //AUXE_arg1.idx_arg2.idx
NodeID newAuxE = datatree.AddVariable(symb_id, 0);
assert(dynamic_cast<VariableNode *>(newAuxE) != NULL);
if (partial_information_model && period==0)
{
//Ensure x is a single variable as opposed to an expression
if (dynamic_cast<VariableNode *>(arg2) == NULL)
{
cerr << "In Partial Information models, EXPECTATION(0)(X) can only be used when X is a single variable." << endl;
exit(EXIT_FAILURE);
}
}
else
{
//take care of any nested expectation operators by calling arg2->substituteExpectation(.), then decreaseLeadsLags for this oExp operator
//arg2(lag-period) (holds entire subtree of arg2(lag-period)
NodeID substexpr = (arg2->substituteExpectation(subst_table, neweqs, partial_information_model))->decreaseLeadsLags(period);
assert(substexpr != NULL);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(newAuxE, substexpr))); //AUXE_arg1.idx_arg2.idx = arg2(lag-period)
newAuxE = newAuxE->decreaseLeadsLags(-1*period);
assert(dynamic_cast<VariableNode *>(newAuxE) != NULL);
}
subst_table[this] = dynamic_cast<VariableNode *>(newAuxE);
return newAuxE;
}
default:
{
NodeID arg1subst = arg1->substituteExpectation(subst_table, neweqs, partial_information_model);
NodeID arg2subst = arg2->substituteExpectation(subst_table, neweqs, partial_information_model);
return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
}
}
} }
TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg, TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,

View File

@ -387,13 +387,14 @@ class UnaryOpNode : public ExprNode
{ {
private: private:
const NodeID arg; const NodeID arg;
const int Expectation_information_set;
const UnaryOpcode op_code; const UnaryOpcode op_code;
virtual NodeID computeDerivative(int deriv_id); virtual NodeID computeDerivative(int deriv_id);
virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const; virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
//! Returns the derivative of this node if darg is the derivative of the argument //! Returns the derivative of this node if darg is the derivative of the argument
NodeID composeDerivatives(NodeID darg); NodeID composeDerivatives(NodeID darg);
public: public:
UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg); UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg, const int arg_exp_info_set);
virtual void prepareForDerivation(); virtual void prepareForDerivation();
virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const; virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;