pooled_ols: fix up and use common parsing

time-shift
Houtan Bastani 2019-01-14 15:23:30 +01:00
parent 1e41235b6b
commit 5d80fc903b
No known key found for this signature in database
GPG Key ID: 000094FB955BE169
1 changed files with 58 additions and 71 deletions

View File

@ -21,7 +21,7 @@ function varargout = pooled_ols(ds, param_common, param_regex, overlapping_dates
% SPECIAL REQUIREMENTS
% dynare must be run with the option: json=compute
% Copyright (C) 2017-2018 Dynare Team
% Copyright (C) 2017-2019 Dynare Team
%
% This file is part of Dynare.
%
@ -50,8 +50,13 @@ if isempty(param_common) && isempty(param_regex)
else
dyn_ols(ds, {}, eqtags);
end
return;
return
end
if nargin < 5
eqtags = {};
end
assert(~isempty(param_common) && iscellstr(param_common), 'The second argument must be a cellstr');
assert(~isempty(param_regex) && iscellstr(param_regex), 'The third argument must be a cellstr');
@ -61,39 +66,17 @@ else
assert(islogical(overlapping_dates) && length(overlapping_dates) == 1, 'The fourth argument must be a bool');
end
%% Read JSON
jsonfile = [M_.fname filesep() 'model' filesep() 'json' filesep() 'modfile-original.json'];
if exist(jsonfile, 'file') ~= 2
error('Could not find %s! Please use the json=compute option (See the Dynare invocation section in the reference manual).', jsonfile);
end
jsonmodel = loadjson(jsonfile);
jsonmodel = jsonmodel.model;
if nargin < 5
eqtags ={};
else
jsonmodel = getEquationsByTags(jsonmodel, 'name', eqtags);
end
%% Get Equation(s)
[ast, jsonmodel] = get_ast_jsonmodel(eqtags);
neqs = length(jsonmodel);
%% Replace parameter names in equations
country_name = param_common{1};
regexcountries = ['(' strjoin(param_common(2:end),'|') ')'];
param_regex_idx = false(length(param_regex), 1);
for i = 1:length(param_regex)
splitp = strsplit(param_regex{i}, '*');
assert(length(splitp) >= 2);
for j = 1:length(jsonmodel)
rhstmp = regexprep(jsonmodel{j}.rhs, ...
strjoin(splitp, regexcountries), ...
strjoin(splitp, country_name));
if length(intersect(jsonmodel{j}.rhs, rhstmp)) ~= length(jsonmodel{j}.rhs)
jsonmodel{j}.rhs = rhstmp;
param_regex_idx(i) = true;
end
end
end
param_regex = param_regex(param_regex_idx);
ast = replace_parameters(ast, country_name, regexcountries, param_regex);
%% Handle FGLS
st = dbstack(1);
save_structure_name = 'pooled_ols';
if strcmp(st(1).name, 'pooled_fgls')
@ -102,46 +85,23 @@ if strcmp(st(1).name, 'pooled_fgls')
end
%% Find parameters and variable names in every equation & Setup estimation matrices
[X, Y, startdates, enddates, startidxs, residnames, pbeta, vars, surpidxs, surconstrainedparams] = ...
pooled_sur_common(ds, jsonmodel);
[Y, ~, X] = common_parsing(ds, ast, jsonmodel, overlapping_dates);
clear ast jsonmodel;
nobs = Y{1}.nobs;
[Y, X] = put_in_sur_form(Y, X);
if overlapping_dates
maxfp = max([startdates{:}]);
minlp = min([enddates{:}]);
nobs = minlp - maxfp;
newY = zeros(nobs*length(jsonmodel), 1);
newX = zeros(nobs*length(jsonmodel), columns(X));
newstartidxs = zeros(size(startidxs));
newstartidxs(1) = 1;
for i = 1:length(jsonmodel)
if i == length(jsonmodel)
yds = dseries(Y(startidxs(i):end), startdates{i});
xds = dseries(X(startidxs(i):end, :), startdates{i});
else
yds = dseries(Y(startidxs(i):startidxs(i+1)-1), startdates{i});
xds = dseries(X(startidxs(i):startidxs(i+1)-1, :), startdates{i});
end
newY(newstartidxs(i):newstartidxs(i) + nobs, 1) = yds(maxfp:minlp).data;
newX(newstartidxs(i):newstartidxs(i) + nobs, :) = xds(maxfp:minlp, :).data;
if i ~= length(jsonmodel)
newstartidxs(i+1) = newstartidxs(i) + nobs + 1;
end
end
Y = newY;
X = newX;
startidxs = newstartidxs;
oo_.(save_structure_name).sample_range = maxfp:minlp;
oo_.(save_structure_name).residnames = residnames;
oo_.(save_structure_name).Y = Y;
oo_.(save_structure_name).X = X;
oo_.(save_structure_name).pbeta = pbeta;
oo_.(save_structure_name).country_name = country_name;
end
%% Save
oo_.(save_structure_name).sample_range = X.firstdate:X.firstdate+nobs;
%oo_.(save_structure_name).residnames = residnames;
oo_.(save_structure_name).Y = Y.data;
oo_.(save_structure_name).X = X.data;
oo_.(save_structure_name).pbeta = X.name;
oo_.(save_structure_name).country_name = country_name;
%% Estimation
% Estimated Parameters
[q, r] = qr(X, 0);
oo_.(save_structure_name).beta = r\(q'*Y);
[q, r] = qr(X.data, 0);
oo_.(save_structure_name).beta = r\(q'*Y.data);
if strcmp(st(1).name, 'pooled_fgls')
return
@ -149,9 +109,9 @@ end
% Assign parameter values back to parameters using param_regex & param_common
regexcountries = ['(' strjoin(param_common(1:end),'|') ')'];
assigned_idxs = false(size(pbeta));
assigned_idxs = false(size(X.name));
for i = 1:length(param_regex)
beta_idx = strcmp(pbeta, strrep(param_regex{i}, '*', country_name));
beta_idx = strcmp(X.name, strrep(param_regex{i}, '*', country_name));
assigned_idxs = assigned_idxs | beta_idx;
value = oo_.(save_structure_name).beta(beta_idx);
if isempty(eqtags)
@ -164,15 +124,15 @@ for i = 1:length(param_regex)
end
idxs = find(assigned_idxs == 0);
values = oo_.(save_structure_name).beta(idxs);
names = pbeta(idxs);
names = X.name(idxs);
assert(length(values) == length(names));
for i = 1:length(idxs)
M_.params(strcmp(M_.param_names, names{i})) = values(i);
end
residuals = Y - X * oo_.(save_structure_name).beta;
for i = 1:length(jsonmodel)
if i == length(jsonmodel)
residuals = Y.data - X.data * oo_.(save_structure_name).beta;
for i = 1:neqs
if i == neqs
oo_.(save_structure_name).resid.(residnames{i}{:}) = residuals(startidxs(i):end);
else
oo_.(save_structure_name).resid.(residnames{i}{:}) = residuals(startidxs(i):startidxs(i+1)-1);
@ -182,3 +142,30 @@ for i = 1:length(jsonmodel)
M_.Sigma_e(idx, idx) = var(oo_.(save_structure_name).resid.(residnames{i}{:}));
end
end
function ast = replace_parameters(ast, country_name, regexcountries, param_regex)
for i = 1:length(ast)
ast{i}.AST = replace_parameters_recursive(ast{i}.AST, country_name, regexcountries, param_regex);
end
end
function node = replace_parameters_recursive(node, country_name, regexcountries, param_regex)
if strcmp(node.node_type, 'VariableNode')
if strcmp(node.type, 'parameter')
for i = 1:length(param_regex)
splitp = strsplit(param_regex{i}, '*');
assert(length(splitp) >= 2);
tmp = regexprep(node.name, strjoin(splitp, regexcountries), strjoin(splitp, country_name));
if ~strcmp(tmp, node.name)
node.name = tmp;
return
end
end
end
elseif strcmp(node.node_type, 'UnaryOpNode')
node.arg = replace_parameters_recursive(node.arg, country_name, regexcountries, param_regex);
elseif strcmp(node.node_type, 'BinaryOpNode')
node.arg1 = replace_parameters_recursive(node.arg1, country_name, regexcountries, param_regex);
node.arg2 = replace_parameters_recursive(node.arg2, country_name, regexcountries, param_regex);
end
end