Implements a Fortran update of the states variance-covariance matrix for the Kalman filter
parent
c74cba4ab0
commit
a245fbb390
|
@ -1,6 +1,6 @@
|
|||
ACLOCAL_AMFLAGS = -I ../../../m4
|
||||
|
||||
SUBDIRS = mjdgges kronecker bytecode block_kalman_filter sobol perfect_foresight_problem num_procs block_trust_region disclyap_fast libkordersim local_state_space_iterations folded_to_unfolded_dr k_order_simul k_order_mean cycle_reduction logarithmic_reduction
|
||||
SUBDIRS = mjdgges kronecker bytecode block_kalman_filter sobol perfect_foresight_problem num_procs block_trust_region disclyap_fast libkordersim local_state_space_iterations folded_to_unfolded_dr k_order_simul k_order_mean cycle_reduction logarithmic_reduction riccati_update
|
||||
|
||||
# libdynare++ must come before gensylv and k_order_perturbation
|
||||
if ENABLE_MEX_DYNAREPLUSPLUS
|
||||
|
|
|
@ -171,7 +171,7 @@ AC_CONFIG_FILES([Makefile
|
|||
block_trust_region/Makefile
|
||||
disclyap_fast/Makefile
|
||||
cycle_reduction/Makefile
|
||||
logarithmic_reduction/Makefile])
|
||||
|
||||
logarithmic_reduction/Makefile
|
||||
riccati_update/Makefile])
|
||||
|
||||
AC_OUTPUT
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
include ../mex.am
|
||||
include ../../riccati_update.am
|
|
@ -1,6 +1,6 @@
|
|||
ACLOCAL_AMFLAGS = -I ../../../m4
|
||||
|
||||
SUBDIRS = mjdgges kronecker bytecode block_kalman_filter sobol perfect_foresight_problem num_procs block_trust_region disclyap_fast libkordersim local_state_space_iterations folded_to_unfolded_dr k_order_simul k_order_mean cycle_reduction logarithmic_reduction
|
||||
SUBDIRS = mjdgges kronecker bytecode block_kalman_filter sobol perfect_foresight_problem num_procs block_trust_region disclyap_fast libkordersim local_state_space_iterations folded_to_unfolded_dr k_order_simul k_order_mean cycle_reduction logarithmic_reduction riccati_update
|
||||
|
||||
# libdynare++ must come before gensylv and k_order_perturbation
|
||||
if ENABLE_MEX_DYNAREPLUSPLUS
|
||||
|
|
|
@ -174,6 +174,7 @@ AC_CONFIG_FILES([Makefile
|
|||
block_trust_region/Makefile
|
||||
disclyap_fast/Makefile
|
||||
cycle_reduction/Makefile
|
||||
logarithmic_reduction/Makefile])
|
||||
logarithmic_reduction/Makefile
|
||||
riccati_update/Makefile])
|
||||
|
||||
AC_OUTPUT
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
EXEEXT = .mex
|
||||
include ../mex.am
|
||||
include ../../riccati_update.am
|
|
@ -0,0 +1,18 @@
|
|||
mex_PROGRAMS = riccati_update
|
||||
|
||||
nodist_riccati_update_SOURCES = \
|
||||
matlab_mex.F08 \
|
||||
blas_lapack.F08 \
|
||||
mexFunction.f08
|
||||
|
||||
BUILT_SOURCES = $(nodist_riccati_update_SOURCES)
|
||||
CLEANFILES = $(nodist_riccati_update_SOURCES)
|
||||
|
||||
mexFunction.o: matlab_mex.mod blas.mod lapack.mod
|
||||
|
||||
mexFunction.mod: mexFunction.o
|
||||
|
||||
%.f08: $(top_srcdir)/../../sources/riccati_update/%.f08
|
||||
$(LN_S) -f $< $@
|
||||
%.F08: $(top_srcdir)/../../sources/riccati_update/%.F08
|
||||
$(LN_S) -f $< $@
|
|
@ -27,7 +27,8 @@ EXTRA_DIST = \
|
|||
block_trust_region \
|
||||
disclyap_fast \
|
||||
cycle_reduction \
|
||||
logarithmic_reduction
|
||||
logarithmic_reduction \
|
||||
riccati_update
|
||||
|
||||
clean-local:
|
||||
rm -rf `find mex/sources -name *.o`
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
! Copyright © 2022 Dynare Team
|
||||
!
|
||||
! This file is part of Dynare.
|
||||
!
|
||||
! Dynare is free software: you can redistribute it and/or modify
|
||||
! it under the terms of the GNU General Public License as published by
|
||||
! the Free Software Foundation, either version 3 of the License, or
|
||||
! (at your option) any later version.
|
||||
!
|
||||
! Dynare is distributed in the hope that it will be useful,
|
||||
! but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
! GNU General Public License for more details.
|
||||
!
|
||||
! You should have received a copy of the GNU General Public License
|
||||
! along with Dynare. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
! Implements Ptmp = T*(P-K*Z*P)*transpose(T)+Q where
|
||||
! P is the (r x r) variance-covariance matrix of the state vector
|
||||
! T is the (r x r) transition matrix of the state vector
|
||||
! K is the (r x n) gain matrix
|
||||
! Z is the (n x r) matrix linking observable variables to state variables
|
||||
! Q is the (r x r) variance-covariance matrix of innovations in the state equation
|
||||
! and accounting for different properties:
|
||||
! P is a (symmetric) positive semi-definite matrix
|
||||
! T can be triangular
|
||||
|
||||
subroutine mexFunction(nlhs, plhs, nrhs, prhs) bind(c, name='mexFunction')
|
||||
use matlab_mex
|
||||
use blas
|
||||
implicit none
|
||||
|
||||
type(c_ptr), dimension(*), intent(in), target :: prhs
|
||||
type(c_ptr), dimension(*), intent(out) :: plhs
|
||||
integer(c_int), intent(in), value :: nlhs, nrhs
|
||||
|
||||
real(real64), dimension(:,:), pointer, contiguous :: P, T, K, Z, Q, Pnew
|
||||
real(real64), dimension(:,:), allocatable :: tmp1, tmp2
|
||||
integer :: i, j, n, r
|
||||
character(kind=c_char, len=2) :: num2str
|
||||
|
||||
! 0. Checking the consistency and validity of input arguments
|
||||
if (nrhs /= 5_c_int) then
|
||||
call mexErrMsgTxt("Must have 5 input arguments")
|
||||
end if
|
||||
if (nlhs > 1_c_int) then
|
||||
call mexErrMsgTxt("Too many output arguments")
|
||||
end if
|
||||
|
||||
do i=1,5
|
||||
if (.not. (c_associated(prhs(i)) .and. mxIsDouble(prhs(i)) .and. &
|
||||
(.not. mxIsComplex(prhs(i))) .and. (.not. mxIsSparse(prhs(i))))) then
|
||||
write (num2str,"(i2)") i
|
||||
call mexErrMsgTxt("Argument " // trim(num2str) // " should be a real dense matrix")
|
||||
end if
|
||||
end do
|
||||
|
||||
r = int(mxGetM(prhs(1))) ! Number of states
|
||||
n = int(mxGetN(prhs(3))) ! Number of observables
|
||||
|
||||
if ((r /= mxGetN(prhs(1))) & ! Number of columns of P
|
||||
&.or. (r /= mxGetM(prhs(2))) & ! Number of lines of T
|
||||
&.or. (r /= mxGetN(prhs(2))) & ! Number of columns of T
|
||||
&.or. (r /= mxGetM(prhs(3))) & ! Number of lines of K
|
||||
&.or. (n /= mxGetM(prhs(4))) & ! Number of lines of Z
|
||||
&.or. (r /= mxGetN(prhs(4))) & ! Number of columns of Z
|
||||
&.or. (r /= mxGetM(prhs(5))) & ! Number of lines of Q
|
||||
&.or. (r /= mxGetN(prhs(5))) & ! Number of columns of Q
|
||||
) then
|
||||
call mexErrMsgTxt("Input dimension mismatch")
|
||||
end if
|
||||
|
||||
! 1. Storing the relevant information in Fortran format
|
||||
P(1:r,1:r) => mxGetPr(prhs(1))
|
||||
T(1:r,1:r) => mxGetPr(prhs(2))
|
||||
K(1:r,1:n) => mxGetPr(prhs(3))
|
||||
Z(1:n,1:r) => mxGetPr(prhs(4))
|
||||
Q(1:r,1:r) => mxGetPr(prhs(5))
|
||||
|
||||
plhs(1) = mxCreateDoubleMatrix(int(r, mwSize), int(r, mwSize), mxREAL)
|
||||
Pnew(1:r, 1:r) => mxGetPr(plhs(1))
|
||||
|
||||
! 2. Computing the Riccati update of the P matrix
|
||||
allocate(tmp1(r,r), tmp2(r,r))
|
||||
! Pnew <- Q
|
||||
Pnew = Q
|
||||
! tmp1 <- K*Z
|
||||
call matmul_add("N", "N", 1._real64, K, Z, 0._real64, tmp1)
|
||||
! tmp2 <- P
|
||||
tmp2 = P
|
||||
! tmp2 <- tmp2 - tmp1*P
|
||||
call matmul_add("N", "N", -1._real64, tmp1, P, 1._real64, tmp2)
|
||||
! tmp1 <- T*tmp2
|
||||
call matmul_add("N", "N", 1._real64, T, tmp2, 0._real64, tmp1)
|
||||
! Pnew <- tmp1*T' + Pnew
|
||||
call matmul_add("N", "T", 1._real64, tmp1, T, 1._real64, Pnew)
|
||||
|
||||
end subroutine mexFunction
|
|
@ -1227,8 +1227,8 @@ M_TRS_FILES += run_block_byte_tests_matlab.m.trs \
|
|||
run_kronecker_tests.m.trs \
|
||||
nonlinearsolvers.m.trs \
|
||||
cyclereduction.m.trs \
|
||||
logarithmicreduction.m.trs
|
||||
|
||||
logarithmicreduction.m.trs \
|
||||
riccatiupdate.m.trs
|
||||
|
||||
M_XFAIL_TRS_FILES = $(patsubst %.mod, %.m.trs, $(XFAIL_MODFILES))
|
||||
|
||||
|
@ -1242,8 +1242,8 @@ O_TRS_FILES += run_block_byte_tests_octave.o.trs \
|
|||
run_kronecker_tests.o.trs \
|
||||
nonlinearsolvers.o.trs \
|
||||
cyclereduction.o.trs \
|
||||
logarithmicreduction.o.trs
|
||||
|
||||
logarithmicreduction.o.trs \
|
||||
riccatiupdate.o.trs
|
||||
|
||||
O_XFAIL_TRS_FILES = $(patsubst %.mod, %.o.trs, $(XFAIL_MODFILES))
|
||||
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
debug = true;
|
||||
|
||||
if debug
|
||||
[top_test_dir, ~, ~] = fileparts(mfilename('fullpath'));
|
||||
else
|
||||
top_test_dir = getenv('TOP_TEST_DIR');
|
||||
end
|
||||
|
||||
addpath([top_test_dir filesep '..' filesep 'matlab']);
|
||||
|
||||
if ~debug
|
||||
% Test Dynare Version
|
||||
if ~strcmp(dynare_version(), getenv('DYNARE_VERSION'))
|
||||
error('Incorrect version of Dynare is being tested')
|
||||
end
|
||||
end
|
||||
|
||||
dynare_config;
|
||||
|
||||
NumberOfTests = 0;
|
||||
testFailed = 0;
|
||||
|
||||
if ~debug
|
||||
skipline()
|
||||
disp('*** TESTING: riccatiupdate.m ***');
|
||||
end
|
||||
|
||||
if isoctave
|
||||
addpath([top_test_dir filesep '..' filesep 'mex' filesep 'octave']);
|
||||
else
|
||||
addpath([top_test_dir filesep '..' filesep 'mex' filesep 'matlab']);
|
||||
end
|
||||
|
||||
t0 = clock;
|
||||
|
||||
% Set the number of experiments for time measurement
|
||||
N = 5000;
|
||||
% Set the dimension of the problem to be solved.
|
||||
r = 50;
|
||||
n = 100;
|
||||
tol = 1e-15;
|
||||
% Set the input arguments
|
||||
% P, Q: use the fact that for any real matrix A, A'*A is positive semidefinite
|
||||
P = rand(n,r);
|
||||
P = P'*P;
|
||||
Q = rand(n,r);
|
||||
Q = Q'*Q;
|
||||
K = rand(r,n);
|
||||
Z = rand(n,r);
|
||||
T = rand(r,r);
|
||||
|
||||
% 1. Update the state vairance-covariance matrix with Matlab
|
||||
tElapsed1 = 0.;
|
||||
tic;
|
||||
for i=1:N
|
||||
Ptmp_matlab = T*(P-K*Z*P)*transpose(T)+Q;
|
||||
end
|
||||
tElapsed1 = toc;
|
||||
disp(['Elapsed time for the Matlab Riccati update is: ' num2str(tElapsed1) ' (N=' int2str(N) ').'])
|
||||
|
||||
% 2. Update the state varance-covariance matrix with the mex routine
|
||||
NumberOfTests = NumberOfTests+1;
|
||||
tElapsed2 = 0.;
|
||||
Ptmp_fortran = P;
|
||||
try
|
||||
tic;
|
||||
for i=1:N
|
||||
Ptmp_fortran = riccati_update(P, T, K, Z, Q);
|
||||
end
|
||||
tElapsed2 = toc;
|
||||
disp(['Elapsed time for the Fortran Riccati update is: ' num2str(tElapsed2) ' (N=' int2str(N) ').'])
|
||||
R = norm(Ptmp_fortran-Ptmp_matlab,1);
|
||||
if (R > tol)
|
||||
testFailed = testFailed+1;
|
||||
if debug
|
||||
dprintf('The Fortran Riccati update is wrong')
|
||||
end
|
||||
end
|
||||
catch
|
||||
testFailed = testFailed+1;
|
||||
if debug
|
||||
dprintf('Fortran Riccati update failed')
|
||||
end
|
||||
end
|
||||
|
||||
% Compare the Fortran and Matlab execution time
|
||||
if debug
|
||||
if tElapsed1<tElapsed2
|
||||
skipline()
|
||||
dprintf('Matlab Riccati update is %5.2f times faster than its Fortran counterpart.', tElapsed2/tElapsed1)
|
||||
skipline()
|
||||
else
|
||||
skipline()
|
||||
dprintf('Fortran Riccati update is %5.2f times faster than its Matlab counterpart.', tElapsed1/tElapsed2)
|
||||
skipline()
|
||||
end
|
||||
end
|
||||
|
||||
% Compare results after multiple calls
|
||||
N = 50;
|
||||
disp(['After 1 update using the Riccati formula, the norm-1 discrepancy is ' num2str(norm(Ptmp_fortran-Ptmp_matlab,1)) '.']);
|
||||
for i=2:N
|
||||
Ptmp_matlab_ini = Ptmp_matlab;
|
||||
Ptmp_fortran_ini = Ptmp_fortran;
|
||||
Ptmp_matlab = T*(Ptmp_matlab_ini-K*Z*Ptmp_matlab_ini)*transpose(T)+Q;
|
||||
Ptmp_fortran = riccati_update(Ptmp_fortran_ini, T, K, Z, Q);
|
||||
end
|
||||
disp(['After ' int2str(N) ' updates using the Riccati formula, the norm-1 discrepancy is ' num2str(norm(Ptmp_fortran-Ptmp_matlab,1)) '.'])
|
||||
|
||||
t1 = clock;
|
||||
|
||||
if ~debug
|
||||
cd(getenv('TOP_TEST_DIR'));
|
||||
else
|
||||
dprintf('FAILED tests: %i', testFailed)
|
||||
end
|
||||
|
||||
if isoctave
|
||||
fid = fopen('riccatiupdate.o.trs', 'w+');
|
||||
else
|
||||
fid = fopen('riccatiupdate.m.trs', 'w+');
|
||||
end
|
||||
if testFailed
|
||||
fprintf(fid,':test-result: FAIL\n');
|
||||
else
|
||||
fprintf(fid,':test-result: PASS\n');
|
||||
end
|
||||
fprintf(fid,':number-tests: %i\n', NumberOfTests);
|
||||
fprintf(fid,':number-failed-tests: %i\n', testFailed);
|
||||
fprintf(fid,':list-of-passed-tests: riccatiupdate.m\n');
|
||||
fprintf(fid,':elapsed-time: %f\n', etime(t1, t0));
|
||||
fclose(fid);
|
||||
|
||||
if ~debug
|
||||
exit;
|
||||
end
|
||||
|
Loading…
Reference in New Issue