Estimation C++ DLL: enhance KalmanFilter performance by using symmetric-matrix and vector BLAS routines (calls added or enhanced in BlasBindings.hh and dynblas.h), plus some minor fixes to Matrix.hh and LogPriorDensity.cc

time-shift
George Perendia 2010-09-02 17:40:52 +01:00
parent b16c56b71c
commit 73fb122e1e
6 changed files with 93 additions and 30 deletions

View File

@ -81,6 +81,11 @@ extern "C" {
CONST_BLDOU a, CONST_BLINT lda, CONST_BLDOU x, CONST_BLINT incx, CONST_BLDOU a, CONST_BLINT lda, CONST_BLDOU x, CONST_BLINT incx,
CONST_BLDOU beta, BLDOU y, CONST_BLINT incy); CONST_BLDOU beta, BLDOU y, CONST_BLINT incy);
#define dsymv FORTRAN_WRAPPER(dsymv)
/* Symmetric matrix-vector product: y := alpha*A*x + beta*y, where A is an
   n-by-n symmetric matrix of which only the triangle selected by uplo is
   referenced, and x, y are vectors accessed with increments incx, incy.
   Parameter names follow the reference BLAS DSYMV interface. */
void dsymv(BLCHAR uplo, CONST_BLINT n, CONST_BLDOU alpha, CONST_BLDOU a,
           CONST_BLINT lda, CONST_BLDOU x, CONST_BLINT incx, CONST_BLDOU beta,
           BLDOU y, CONST_BLINT incy);
#define dtrsv FORTRAN_WRAPPER(dtrsv) #define dtrsv FORTRAN_WRAPPER(dtrsv)
void dtrsv(BLCHAR uplo, BLCHAR trans, BLCHAR diag, CONST_BLINT n, void dtrsv(BLCHAR uplo, BLCHAR trans, BLCHAR diag, CONST_BLINT n,
CONST_BLDOU a, CONST_BLINT lda, BLDOU x, CONST_BLINT incx); CONST_BLDOU a, CONST_BLINT lda, BLDOU x, CONST_BLINT incx);

View File

@ -37,12 +37,12 @@ KalmanFilter::KalmanFilter(const std::string &dynamicDllFile, size_t n_endo, siz
double qz_criterium_arg, const std::vector<size_t> &varobs_arg, double qz_criterium_arg, const std::vector<size_t> &varobs_arg,
double riccati_tol_arg, double lyapunov_tol_arg, int &info) : double riccati_tol_arg, double lyapunov_tol_arg, int &info) :
zeta_varobs_back_mixed(compute_zeta_varobs_back_mixed(zeta_back_arg, zeta_mixed_arg, varobs_arg)), zeta_varobs_back_mixed(compute_zeta_varobs_back_mixed(zeta_back_arg, zeta_mixed_arg, varobs_arg)),
Z(varobs_arg.size(), zeta_varobs_back_mixed.size()), T(zeta_varobs_back_mixed.size()), R(zeta_varobs_back_mixed.size(), n_exo), Z(varobs_arg.size(), zeta_varobs_back_mixed.size()), Zt(Z.getCols(),Z.getRows()), T(zeta_varobs_back_mixed.size()), R(zeta_varobs_back_mixed.size(), n_exo),
Pstar(zeta_varobs_back_mixed.size(), zeta_varobs_back_mixed.size()), Pinf(zeta_varobs_back_mixed.size(), zeta_varobs_back_mixed.size()), Pstar(zeta_varobs_back_mixed.size(), zeta_varobs_back_mixed.size()), Pinf(zeta_varobs_back_mixed.size(), zeta_varobs_back_mixed.size()),
RQRt(zeta_varobs_back_mixed.size(), zeta_varobs_back_mixed.size()), Ptmp(zeta_varobs_back_mixed.size(), zeta_varobs_back_mixed.size()), F(varobs_arg.size(), varobs_arg.size()), RQRt(zeta_varobs_back_mixed.size(), zeta_varobs_back_mixed.size()), Ptmp(zeta_varobs_back_mixed.size(), zeta_varobs_back_mixed.size()), F(varobs_arg.size(), varobs_arg.size()),
Finv(varobs_arg.size(), varobs_arg.size()), K(zeta_varobs_back_mixed.size(), varobs_arg.size()), KFinv(zeta_varobs_back_mixed.size(), varobs_arg.size()), Finv(varobs_arg.size(), varobs_arg.size()), K(zeta_varobs_back_mixed.size(), varobs_arg.size()), KFinv(zeta_varobs_back_mixed.size(), varobs_arg.size()),
oldKFinv(zeta_varobs_back_mixed.size(), varobs_arg.size()), a_init(zeta_varobs_back_mixed.size(), 1), oldKFinv(zeta_varobs_back_mixed.size(), varobs_arg.size()), a_init(zeta_varobs_back_mixed.size()),
a_new(zeta_varobs_back_mixed.size(), 1), vt(varobs_arg.size(), 1), vtFinv(1, varobs_arg.size()), vtFinvVt(1), riccati_tol(riccati_tol_arg), a_new(zeta_varobs_back_mixed.size()), vt(varobs_arg.size()), vtFinv(varobs_arg.size()), riccati_tol(riccati_tol_arg),
initKalmanFilter(dynamicDllFile, n_endo, n_exo, zeta_fwrd_arg, zeta_back_arg, zeta_mixed_arg, initKalmanFilter(dynamicDllFile, n_endo, n_exo, zeta_fwrd_arg, zeta_back_arg, zeta_mixed_arg,
zeta_static_arg, zeta_varobs_back_mixed, qz_criterium_arg, lyapunov_tol_arg, info), zeta_static_arg, zeta_varobs_back_mixed, qz_criterium_arg, lyapunov_tol_arg, info),
FUTP(varobs_arg.size()*(varobs_arg.size()+1)/2) FUTP(varobs_arg.size()*(varobs_arg.size()+1)/2)
@ -55,6 +55,8 @@ KalmanFilter::KalmanFilter(const std::string &dynamicDllFile, size_t n_endo, siz
varobs_arg[i]) - zeta_varobs_back_mixed.begin(); varobs_arg[i]) - zeta_varobs_back_mixed.begin();
Z(i, j) = 1.0; Z(i, j) = 1.0;
} }
mat::transpose(Zt,Z);
} }
std::vector<size_t> std::vector<size_t>
@ -100,7 +102,7 @@ KalmanFilter::compute(const MatrixConstView &dataView, VectorView &steadyState,
double double
KalmanFilter::filter(const MatrixView &detrendedDataView, const Matrix &H, VectorView &vll, size_t start, int &info) KalmanFilter::filter(const MatrixView &detrendedDataView, const Matrix &H, VectorView &vll, size_t start, int &info)
{ {
double loglik=0.0, ll, logFdet, Fdet; double loglik=0.0, ll, logFdet, Fdet, dvtFinvVt;
size_t p = Finv.getRows(); size_t p = Finv.getRows();
bool nonstationary = true; bool nonstationary = true;
for (size_t t = 0; t < detrendedDataView.getCols(); ++t) for (size_t t = 0; t < detrendedDataView.getCols(); ++t)
@ -108,7 +110,7 @@ KalmanFilter::filter(const MatrixView &detrendedDataView, const Matrix &H, Vect
if (nonstationary) if (nonstationary)
{ {
// K=PZ' // K=PZ'
blas::gemm("N", "T", 1.0, Pstar, Z, 0.0, K); blas::symm("L", "U", 1.0, Pstar, Zt, 0.0, K);
//F=ZPZ' +H = ZK+H //F=ZPZ' +H = ZK+H
F = H; F = H;
@ -139,7 +141,7 @@ KalmanFilter::filter(const MatrixView &detrendedDataView, const Matrix &H, Vect
Pstar(i,j)*=0.5; Pstar(i,j)*=0.5;
// K=PZ' // K=PZ'
blas::gemm("N", "T", 1.0, Pstar, Z, 0.0, K); blas::symm("L", "U", 1.0, Pstar, Zt, 0.0, K);
//F=ZPZ' +H = ZK+H //F=ZPZ' +H = ZK+H
F = H; F = H;
@ -160,15 +162,12 @@ KalmanFilter::filter(const MatrixView &detrendedDataView, const Matrix &H, Vect
} }
// KFinv gain matrix // KFinv gain matrix
blas::gemm("N", "N", 1.0, K, Finv, 0.0, KFinv); blas::symm("R", "U", 1.0, Finv, K, 0.0, KFinv);
// determinant of F: // determinant of F:
Fdet = 1; Fdet = 1;
for (size_t d = 1; d <= p; ++d) for (size_t d = 1; d <= p; ++d)
Fdet *= FUTP(d + (d-1)*d/2 -1); Fdet *= FUTP(d + (d-1)*d/2 -1);
Fdet *=Fdet;//*pow(-1.0,p); Fdet *=Fdet;
// for (size_t d = 0; d < p; ++d)
// Fdet *= (-F(d, d));
logFdet=log(fabs(Fdet)); logFdet=log(fabs(Fdet));
Ptmp = Pstar; Ptmp = Pstar;
@ -177,7 +176,8 @@ KalmanFilter::filter(const MatrixView &detrendedDataView, const Matrix &H, Vect
blas::gemm("N", "T", -1.0, KFinv, K, 1.0, Ptmp); blas::gemm("N", "T", -1.0, KFinv, K, 1.0, Ptmp);
// 2) Ptmp= T*Ptmp // 2) Ptmp= T*Ptmp
Pstar = Ptmp; Pstar = Ptmp;
blas::gemm("N", "N", 1.0, T, Pstar, 0.0, Ptmp); //blas::gemm("N", "N", 1.0, T, Pstar, 0.0, Ptmp);
blas::symm("R", "U", 1.0, Pstar, T, 0.0, Ptmp);
// 3) Pt+1= Ptmp*T' +RQR' // 3) Pt+1= Ptmp*T' +RQR'
Pstar = RQRt; Pstar = RQRt;
blas::gemm("N", "T", 1.0, Ptmp, T, 1.0, Pstar); blas::gemm("N", "T", 1.0, Ptmp, T, 1.0, Pstar);
@ -192,22 +192,22 @@ KalmanFilter::filter(const MatrixView &detrendedDataView, const Matrix &H, Vect
} }
// err= Yt - Za // err= Yt - Za
MatrixConstView yt(detrendedDataView, 0, t, detrendedDataView.getRows(), 1); // current observation vector VectorConstView yt=mat::get_col(detrendedDataView, t);
vt = yt; vt = yt;
blas::gemm("N", "N", -1.0, Z, a_init, 1.0, vt); blas::gemv("N", -1.0, Z, a_init, 1.0, vt);
// at+1= T(at+ KFinv *err) // at+1= T(at+ KFinv *err)
blas::gemm("N", "N", 1.0, KFinv, vt, 1.0, a_init); blas::gemv("N", 1.0, KFinv, vt, 1.0, a_init);
blas::gemm("N", "N", 1.0, T, a_init, 0.0, a_new); blas::gemv("N", 1.0, T, a_init, 0.0, a_new);
a_init = a_new; a_init = a_new;
/***************** /*****************
Here we calc likelihood and store results. Here we calc likelihood and store results.
*****************/ *****************/
blas::gemm("T", "N", 1.0, vt, Finv, 0.0, vtFinv); blas::symv("U", 1.0, Finv, vt, 0.0, vtFinv);
blas::gemm("N", "N", 1.0, vtFinv, vt, 0.0, vtFinvVt); dvtFinvVt=blas::dot(vtFinv, vt );
ll = -0.5*(p*log(2*M_PI)+logFdet+(*(vtFinvVt.getData()))); ll = -0.5*(p*log(2*M_PI)+logFdet+dvtFinvVt);
vll(t) = ll; vll(t) = ll;
if (t >= start) loglik += ll; if (t >= start) loglik += ll;

View File

@ -61,7 +61,7 @@ public:
private: private:
const std::vector<size_t> zeta_varobs_back_mixed; const std::vector<size_t> zeta_varobs_back_mixed;
static std::vector<size_t> compute_zeta_varobs_back_mixed(const std::vector<size_t> &zeta_back_arg, const std::vector<size_t> &zeta_mixed_arg, const std::vector<size_t> &varobs_arg); static std::vector<size_t> compute_zeta_varobs_back_mixed(const std::vector<size_t> &zeta_back_arg, const std::vector<size_t> &zeta_mixed_arg, const std::vector<size_t> &varobs_arg);
Matrix Z; //nob*mm matrix mapping endogeneous variables and observations Matrix Z, Zt; //nob*mm matrix mapping endogeneous variables and observations and its transpose
Matrix T; //mm*mm transition matrix of the state equation. Matrix T; //mm*mm transition matrix of the state equation.
Matrix R; //mm*rr matrix, mapping structural innovations to state variables. Matrix R; //mm*rr matrix, mapping structural innovations to state variables.
Matrix Pstar; //mm*mm variance-covariance matrix of stationary variables Matrix Pstar; //mm*mm variance-covariance matrix of stationary variables
@ -70,9 +70,9 @@ private:
Matrix RQRt, Ptmp; //mm*mm variance-covariance matrix of variable disturbances Matrix RQRt, Ptmp; //mm*mm variance-covariance matrix of variable disturbances
Matrix F, Finv; // nob*nob F=ZPZt +H an inv(F) Matrix F, Finv; // nob*nob F=ZPZt +H an inv(F)
Matrix K, KFinv, oldKFinv; // mm*nobs K=PZt and K*Finv gain matrices Matrix K, KFinv, oldKFinv; // mm*nobs K=PZt and K*Finv gain matrices
Matrix a_init, a_new; // state vector Vector a_init, a_new; // state vector
Matrix vt; // current observation error vectors Vector vt; // current observation error vectors
Matrix vtFinv, vtFinvVt; // intermeiate observation error *Finv vector Vector vtFinv;// intermediate observation error *Finv vector
double riccati_tol; double riccati_tol;
InitializeKalmanFilter initKalmanFilter; //Initialise KF matrices InitializeKalmanFilter initKalmanFilter; //Initialise KF matrices
Vector FUTP; // F upper triangle packed as vector FUTP(i + (j-1)*j/2) = F(i,j) for 1<=i<=j; Vector FUTP; // F upper triangle packed as vector FUTP(i + (j-1)*j/2) = F(i,j) for 1<=i<=j;

View File

@ -41,7 +41,7 @@ LogPriorDensity::compute(const Vector &ep)
for (size_t i = 0; i < ep.getSize(); ++i) for (size_t i = 0; i < ep.getSize(); ++i)
{ {
logPriorDensity += log(((*(estParsDesc.estParams[i]).prior)).pdf(ep(i))); logPriorDensity += log(((*(estParsDesc.estParams[i]).prior)).pdf(ep(i)));
if (std::isinf(abs(logPriorDensity))) if (std::isinf(fabs(logPriorDensity)))
return logPriorDensity; return logPriorDensity;
} }
return logPriorDensity; return logPriorDensity;

View File

@ -27,6 +27,19 @@
namespace blas namespace blas
{ {
/* Level 1 */
//! dot product of two vectors
template<class Vec1, class Vec2>
inline double
dot(const Vec1 &A, Vec2 &B)
{
assert(A.getSize() == B.getSize());
blas_int n = A.getSize();
blas_int lda = A.getStride(), ldb = B.getStride();
return ddot(&n, A.getData(), &lda, B.getData(), &ldb);
}
/* Level 2 */ /* Level 2 */
//! Symmetric rank 1 operation: A = alpha*X*X' + A //! Symmetric rank 1 operation: A = alpha*X*X' + A
@ -41,6 +54,46 @@ namespace blas
dsyr(uplo, &n, &alpha, X.getData(), &incx, A.getData(), &lda); dsyr(uplo, &n, &alpha, X.getData(), &incx, A.getData(), &lda);
} }
//! General matrix * vector multiplication
// c = alpha*A*b + beta*c, or c := alpha*A'*b + beta*c,
// where alpha and beta are scalars, b and c are vectors and A is an
// m by n matrix.
template<class Mat1, class Vec2, class Vec3>
inline void
gemv(const char *transa, double alpha, const Mat1 &A,
const Vec2 &B, double beta, Vec3 &C)
{
blas_int m = A.getRows(), n = B.getSize(), k = A.getCols(), l = C.getSize();
if (*transa == 'T')
{
m = A.getCols();
k = A.getRows();
}
assert(m == l);
assert(k == n);
blas_int lda = A.getLd(), ldb = B.getStride(), ldc = C.getStride();
dgemv(transa, &m, &n, &alpha, A.getData(), &lda,
B.getData(), &ldb, &beta, C.getData(), &ldc);
}
//! Symmetric matrix * vector multiplication
// c = alpha*A*b + beta*c,
// where alpha and beta are scalars, b and c are vectors and A is a
// m by m symmetric matrix.
template<class Mat1, class Vec2, class Vec3>
inline void
symv(const char *uplo, double alpha, const Mat1 &A,
const Vec2 &B, double beta, Vec3 &C)
{
assert(A.getRows() == A.getCols());
blas_int n = A.getRows();
assert(A.getRows() == B.getSize());
assert(A.getRows() == C.getSize());
blas_int lda = A.getLd(), ldb = B.getStride(), ldc = C.getStride();
dsymv(uplo, &n, &alpha, A.getData(), &lda,
B.getData(), &ldb, &beta, C.getData(), &ldc);
}
/* Level 3 */ /* Level 3 */
//! General matrix multiplication //! General matrix multiplication
@ -90,7 +143,7 @@ namespace blas
B.getData(), &ldb, &beta, C.getData(), &ldc); B.getData(), &ldb, &beta, C.getData(), &ldc);
} }
//! Symmetric matrix multiplication //! Symmetric matrix A * (poss. rectangular) matrix B multiplication
template<class Mat1, class Mat2, class Mat3> template<class Mat1, class Mat2, class Mat3>
inline void inline void
symm(const char *side, const char *uplo, symm(const char *side, const char *uplo,
@ -98,14 +151,19 @@ namespace blas
double beta, Mat3 &C) double beta, Mat3 &C)
{ {
assert(A.getRows() == A.getCols()); assert(A.getRows() == A.getCols());
assert(A.getRows() == C.getRows()); assert(B.getRows() == C.getRows());
assert(A.getCols() == B.getRows());
assert(B.getCols() == C.getCols()); assert(B.getCols() == C.getCols());
blas_int m = A.getRows(), n = B.getCols(); if (*side == 'L' || *side == 'l')
assert(A.getCols() == B.getRows());
else if (*side == 'R' || *side == 'r')
assert(A.getRows() == B.getCols());
blas_int m = B.getRows(), n = B.getCols();
blas_int lda = A.getLd(), ldb = B.getLd(), ldc = C.getLd(); blas_int lda = A.getLd(), ldb = B.getLd(), ldc = C.getLd();
dsymm(side, uplo, &m, &n, &alpha, A.getData(), &lda, dsymm(side, uplo, &m, &n, &alpha, A.getData(), &lda,
B.getData(), &ldb, &beta, C.getData(), &ldc); B.getData(), &ldb, &beta, C.getData(), &ldc);
} }
} // End of namespace } // End of namespace
#endif #endif

View File

@ -207,14 +207,14 @@ namespace mat
inline VectorConstView inline VectorConstView
get_col(const Mat &M, size_t j) get_col(const Mat &M, size_t j)
{ {
return VectorView(M.getData()+j*M.getLd(), M.getRows(), 1); return VectorConstView(M.getData()+j*M.getLd(), M.getRows(), 1);
} }
template<class Mat> template<class Mat>
inline VectorConstView inline VectorConstView
get_row(const Mat &M, size_t i) get_row(const Mat &M, size_t i)
{ {
return VectorView(M.getData()+i, M.getCols(), M.getLd()); return VectorConstView(M.getData()+i, M.getCols(), M.getLd());
} }
template<class Mat1, class Mat2> template<class Mat1, class Mat2>