pac_expectation: find rho_0 in pac equation

time-shift
Houtan Bastani 2018-02-15 17:06:30 +01:00
parent 5c50435479
commit d0ac1ffbcb
4 changed files with 69 additions and 22 deletions

View File

@ -4599,32 +4599,64 @@ BinaryOpNode::setVarExpectationIndex(map<string, pair<SymbolList, int> > &var_mo
arg2->setVarExpectationIndex(var_model_info);
}
void
BinaryOpNode::walkPacParametersHelper(const expr_t arg1, const expr_t arg2,
pair<int, int> &lhs,
set<pair<int, pair<int, int> > > &params_and_vals) const
{
set<int> params;
set<pair<int, int> > endogs;
arg1->collectVariables(eParameter, params);
arg2->collectDynamicVariables(eEndogenous, endogs);
if (params.size() == 1)
if (endogs.size() == 1)
params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin())));
else
if (endogs.size() == 2)
{
BinaryOpNode *testarg2 = dynamic_cast<BinaryOpNode *>(arg2);
VariableNode *test_arg1 = dynamic_cast<VariableNode *>(testarg2->get_arg1());
VariableNode *test_arg2 = dynamic_cast<VariableNode *>(testarg2->get_arg2());
if (testarg2 != NULL && testarg2->get_op_code() == oMinus
&& test_arg1 != NULL &&test_arg2 != NULL
&& lhs.first != -1)
{
int find_symb_id = -1;
try
{
// lhs is an aux var (diff)
find_symb_id = datatree.symbol_table.getOrigSymbIdForAuxVar(lhs.first);
}
catch (...)
{
//lhs is not an aux var
find_symb_id = lhs.first;
}
endogs.clear();
if (test_arg1->get_symb_id() == find_symb_id)
{
test_arg1->collectDynamicVariables(eEndogenous, endogs);
params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin())));
}
else if (test_arg2->get_symb_id() == find_symb_id)
{
test_arg2->collectDynamicVariables(eEndogenous, endogs);
params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin())));
}
}
}
}
void
BinaryOpNode::walkPacParameters(bool &pac_encountered, pair<int, int> &lhs, set<pair<int, pair<int, int> > > &params_and_vals) const
{
if (op_code == oTimes)
{
set<int> params;
set<pair<int, int> > endogs;
arg1->collectVariables(eParameter, params);
arg2->collectDynamicVariables(eEndogenous, endogs);
if (params.size() == 1 && endogs.size() == 1)
{
params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin())));
return;
}
else
{
params.clear();
endogs.clear();
arg1->collectDynamicVariables(eEndogenous, endogs);
arg2->collectVariables(eParameter, params);
if (params.size() == 1 && endogs.size() == 1)
{
params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin())));
return;
}
}
int orig_params_and_vals_size = params_and_vals.size();
walkPacParametersHelper(arg1, arg2, lhs, params_and_vals);
if (params_and_vals.size() == orig_params_and_vals_size)
walkPacParametersHelper(arg2, arg1, lhs, params_and_vals);
}
else if (op_code == oEqual)
{
@ -7738,7 +7770,7 @@ PacExpectationNode::addParamInfoToPac(pair<int, int> &lhs_arg, set<pair<int, pai
exit(EXIT_FAILURE);
}
if (params_and_vals_arg.size() != 2)
if (params_and_vals_arg.size() != 3)
{
cerr << "Pac Expectation: error in obtaining RHS parameters." << endl;
exit(EXIT_FAILURE);

View File

@ -834,6 +834,9 @@ public:
{
return powerDerivOrder;
}
void walkPacParametersHelper(const expr_t arg1, const expr_t arg2,
pair<int, int> &lhs,
set<pair<int, pair<int, int> > > &params_and_vals) const;
virtual expr_t toStatic(DataTree &static_datatree) const;
virtual void computeXrefs(EquationInfo &ei) const;
virtual pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<pair<int, pair<expr_t, expr_t> > > &List_of_Op_RHS) const;

View File

@ -809,6 +809,16 @@ SymbolTable::searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const thro
throw SearchFailedException(orig_symb_id, orig_lead_lag);
}
int
SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const throw (UnknownSymbolIDException)
{
for (size_t i = 0; i < aux_vars.size(); i++)
if ((aux_vars[i].get_type() == avEndoLag || aux_vars[i].get_type() == avExoLag || aux_vars[i].get_type() == avDiff)
&& aux_vars[i].get_symb_id() == aux_var_symb_id)
return aux_vars[i].get_orig_symb_id();
throw UnknownSymbolIDException(aux_var_symb_id);
}
expr_t
SymbolTable::getAuxiliaryVarsExprNode(int symb_id) const throw (SearchFailedException)
// throw exception if it is a Lagrange multiplier

View File

@ -281,6 +281,8 @@ public:
Throws an exception if match not found.
*/
int searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const throw (SearchFailedException);
//! Serches aux_vars for the aux var represented by aux_var_symb_id and returns its associated orig_symb_id
int getOrigSymbIdForAuxVar(int aux_var_symb_id) const throw (UnknownSymbolIDException);
//! Adds an auxiliary variable when var_model is used with an order that is greater in absolute value
//! than the largest lag present in the model.
int addVarModelEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) throw (AlreadyDeclaredException, FrozenException);