simplify code to find autoregressive matrix for VARs

issue#70
Houtan Bastani 2019-03-14 17:20:45 +01:00
parent 0544141545
commit a3c08f932e
No known key found for this signature in database
GPG Key ID: 000094FB955BE169
6 changed files with 128 additions and 15 deletions

View File

@ -4043,9 +4043,7 @@ DynamicModel::fillVarModelTable() const
var_model_table.setLhsExprT(lhs_expr_tr);
// Fill AR Matrix
map<string, map<tuple<int, int, int>, expr_t>> ARr;
fillAutoregressiveMatrix(ARr, false);
var_model_table.setAR(ARr);
var_model_table.setAR(fillAutoregressiveMatrixForVAR());
}
void
@ -4115,17 +4113,32 @@ DynamicModel::fillVarModelTableFromOrigModel(StaticModel &static_model) const
var_model_table.setOrigDiffVar(orig_diff_var);
}
void
DynamicModel::fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr, bool is_trend_component_model) const
map<string, map<tuple<int, int, int>, expr_t>>
DynamicModel::fillAutoregressiveMatrixForVAR() const
{
auto eqnums = is_trend_component_model ?
trend_component_model_table.getNonTargetEqNums() : var_model_table.getEqNums();
for (const auto & it : eqnums)
map<string, map<tuple<int, int, int>, expr_t>> ARr;
for (const auto & it : var_model_table.getEqNums())
{
int i = 0;
map<tuple<int, int, int>, expr_t> AR;
vector<int> lhs = is_trend_component_model ?
trend_component_model_table.getNonTargetLhs(it.first) : var_model_table.getLhs(it.first);
for (auto eqn : it.second)
{
auto *bopn = dynamic_cast<BinaryOpNode *>(equations[eqn]->arg2);
bopn->fillAutoregressiveRowForVAR(i++, var_model_table.getLhsOrigIds(it.first), AR);
}
ARr[it.first] = AR;
}
return ARr;
}
void
DynamicModel::fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr) const
{
for (const auto & it : trend_component_model_table.getNonTargetEqNums())
{
int i = 0;
map<tuple<int, int, int>, expr_t> AR;
vector<int> lhs = trend_component_model_table.getNonTargetLhs(it.first);
for (auto eqn : it.second)
equations[eqn]->arg2->fillAutoregressiveRow(i++, lhs, AR);
ARr[it.first] = AR;
@ -4335,7 +4348,7 @@ void
DynamicModel::fillTrendComponentmodelTableAREC(ExprNode::subst_table_t &diff_subst_table) const
{
map<string, map<tuple<int, int, int>, expr_t>> ARr, A0r, A0starr;
fillAutoregressiveMatrix(ARr, true);
fillAutoregressiveMatrix(ARr);
trend_component_model_table.setAR(ARr);
fillErrorComponentMatrix(A0r, A0starr, diff_subst_table);
trend_component_model_table.setA0(A0r, A0starr);

View File

@ -316,8 +316,11 @@ public:
//! Set the equations that have non-zero second derivatives
void setNonZeroHessianEquations(map<int, string> &eqs);
//! Fill Autoregressive Matrix for var_model/trend_component_model
void fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr, bool is_trend_component_model) const;
//! Fill Autoregressive Matrix for trend_component_model
void fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr) const;
//! Fill Autoregressive Matrix for var_model
map<string, map<tuple<int, int, int>, expr_t>> fillAutoregressiveMatrixForVAR() const;
//! Fill Error Component Matrix for trend_component_model
void fillErrorComponentMatrix(map<string, map<tuple<int, int, int>, expr_t>> &A0r, map<string, map<tuple<int, int, int>, expr_t>> &A0starr, ExprNode::subst_table_t &diff_subst_table) const;

View File

@ -5918,6 +5918,74 @@ BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<i
arg2->fillAutoregressiveRow(eqn, lhs, AR);
}
void
BinaryOpNode::fillAutoregressiveRowForVAR(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const
{
vector<pair<expr_t, int>> terms;
decomposeAdditiveTerms(terms, 1);
for (const auto & it : terms)
{
auto bopn = dynamic_cast<BinaryOpNode *>(it.first);
if (bopn != nullptr)
{
auto vn1 = dynamic_cast<VariableNode *>(bopn->arg1);
auto vn2 = dynamic_cast<VariableNode *>(bopn->arg2);
if (vn1 && vn2)
{
int vid, lag;
VariableNode *param;
vid = -1;
if (datatree.symbol_table.getType(vn1->symb_id) == SymbolType::parameter
&& (datatree.symbol_table.getType(vn2->symb_id) == SymbolType::endogenous
|| datatree.symbol_table.getType(vn2->symb_id) == SymbolType::exogenous))
{
param = vn1;
vid = vn2->symb_id;
lag = vn2->lag;
}
else if (datatree.symbol_table.getType(vn2->symb_id) == SymbolType::parameter
&& (datatree.symbol_table.getType(vn1->symb_id) == SymbolType::endogenous
|| datatree.symbol_table.getType(vn1->symb_id) == SymbolType::exogenous))
{
param = vn2;
vid = vn1->symb_id;
lag = vn1->lag;
}
if (vid >= 0)
{
int vidineq = vid;
while (datatree.symbol_table.isAuxiliaryVariable(vid))
try
{
vid = datatree.symbol_table.getOrigSymbIdForAuxVar(vid);
}
catch (...)
{
break;
}
if (vidineq != vid)
{
vid = datatree.symbol_table.getOrigSymbIdForDiffAuxVar(vidineq);
lag = -datatree.symbol_table.getOrigLeadLagForDiffAuxVar(vidineq);
}
if (find(lhs.begin(), lhs.end(), vid) == lhs.end())
continue;
if (AR.find({eqn, lag, vid}) != AR.end())
{
cerr << "BinaryOpNode::fillAutoregressiveRow: Error filling AR matrix: "
<< "lag/symb_id encountered more than once in equtaion" << endl;
exit(EXIT_FAILURE);
}
AR[{eqn, -lag, vid}] = param;
}
}
}
}
}
void
BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
int eqn,

View File

@ -1041,6 +1041,7 @@ public:
void fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2,
int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillAutoregressiveRowForVAR(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const;
void fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs,
map<tuple<int, int, int>, expr_t> &A0, map<tuple<int, int, int>, expr_t> &A0star) const;

View File

@ -520,7 +520,7 @@ VarModelTable::writeOutput(const string &basename, ostream &output) const
i++;
}
vector<int> lhs = getLhs(name);
vector<int> lhs = getLhsOrigIds(name);
ar_output << "if strcmp(model_name, '" << name << "')" << endl
<< " ar = zeros(" << lhs.size() << ", " << lhs.size() << ", " << getMaxLag(name) << ");" << endl;
for (const auto & it : AR.at(name))
@ -600,6 +600,26 @@ void
VarModelTable::setLhs(map<string, vector<int>> lhs_arg)
{
lhs = move(lhs_arg);
for (auto it : lhs)
{
vector<int> lhsvec;
for (auto ids : it.second)
{
int lhs_last_orig_symb_id = ids;
int lhs_orig_symb_id = ids;
if (symbol_table.isAuxiliaryVariable(lhs_orig_symb_id))
try
{
lhs_last_orig_symb_id = lhs_orig_symb_id;
lhs_orig_symb_id = symbol_table.getOrigSymbIdForAuxVar(lhs_orig_symb_id);
}
catch (...)
{
}
lhsvec.emplace_back(lhs_last_orig_symb_id);
}
lhs_orig_symb_ids[it.first] = lhsvec;
}
}
void
@ -681,6 +701,13 @@ VarModelTable::getLhs(const string &name_arg) const
return lhs.find(name_arg)->second;
}
vector<int>
VarModelTable::getLhsOrigIds(const string &name_arg) const
{
checkModelName(name_arg);
return lhs_orig_symb_ids.find(name_arg)->second;
}
vector<set<pair<int, int>>>
VarModelTable::getRhs(const string &name_arg) const
{

View File

@ -118,7 +118,7 @@ private:
set<string> names;
map<string, pair<SymbolList, int>> symbol_list_and_order;
map<string, vector<string>> eqtags;
map<string, vector<int>> eqnums, max_lags, lhs, orig_diff_var;
map<string, vector<int>> eqnums, max_lags, lhs, lhs_orig_symb_ids, orig_diff_var;
map<string, vector<set<pair<int, int>>>> rhs;
map<string, vector<bool>> diff;
map<string, vector<expr_t>> lhs_expr_t;
@ -141,6 +141,7 @@ public:
vector<int> getMaxLags(const string &name_arg) const;
int getMaxLag(const string &name_arg) const;
vector<int> getLhs(const string &name_arg) const;
vector<int> getLhsOrigIds(const string &name_arg) const;
map<string, pair<SymbolList, int>> getSymbolListAndOrder() const;
vector<set<pair<int, int>>> getRhs(const string &name_arg) const;
vector<expr_t> getLhsExprT(const string &name_arg) const;