Adapt block_kalman_filter for Dynare's BLAS/LAPACK framework

time-shift
Sébastien Villemot 2011-09-23 18:15:18 +02:00
parent 16e9c36eb4
commit b21a99d9d2
2 changed files with 12 additions and 59 deletions

View File

@ -33,16 +33,8 @@ using namespace std;
# include "mex_interface.hh"
#endif
#ifdef ATLAS
# include <cblas.h>
#else
# ifdef MKL
# include <mkl_blas.h>
typedef ptrdiff_t blas_int;
# else
# include <dynblas.h>
# endif
#endif
#include <dynblas.h>
#include <dynlapack.h>
#define BLOCK
@ -194,33 +186,6 @@ LIK = sum(lik(start:end)); % Minus the log-likelihood.*/
#if defined(MATLAB_MEX_FILE) && defined(_WIN32)
extern "C" int dgecon(const char *norm, const int *n, double *a, const int *lda, const double *anorm, double *rcond, double *work, int *iwork, int *info);
extern "C" int dgetrf(const int *m, const int *n, double *a, const int *lda, int *lpiv, int *info);
extern "C" double dlange(const char *norm, const int *m, const int *n, const double *a, const int *lda, double *work);
extern "C" int dgetri(const int *n, double* a, const int *lda, const int *lpiv, double* work, const int *lwork, int *info);
extern "C" void dgemm(const char *transa, const char *transb, const int *m, const int *n,
const int *k, const double *alpha, const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc);
extern "C" void dsymm(const char *side, const char *uplo, const int *m, const int *n,
const double *alpha, const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc);
#else
extern "C" int dgecon_(const char *norm, const int *n, double *a, const int *lda, const double *anorm, double *rcond, double *work, int *iwork, int *info);
extern "C" int dgetrf_(const int *m, const int *n, double *a, const int *lda, int *lpiv, int *info);
extern "C" double dlange_(const char *norm, const int *m, const int *n, const double *a, const int *lda, double *work);
extern "C" int dgetri_(const int *n, double* a, const int *lda, const int *lpiv, double* work, const int *lwork, int *info);
extern "C" void dgemm_(const char *transa, const char *transb, const int *m, const int *n,
const int *k, const double *alpha, const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc);
extern "C" void dsymm_(const char *side, const char *uplo, const int *m, const int *n,
const double *alpha, const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc);
#endif
bool
not_all_abs_F_bellow_crit(double* F, int size, double crit)
{
@ -289,7 +254,7 @@ mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
/*Defining the initials values*/
int smpl = mxGetN(pY); // Sample size. ;
int n = mxGetN(pT); // Number of state variables.
int pp = mxGetM(pY); // Maximum number of observed variables.
lapack_int pp = mxGetM(pY); // Maximum number of observed variables.
int n_state = n - pp;
@ -393,11 +358,11 @@ mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
double* F = mxGetPr(pF);
mxArray* piF = mxCreateDoubleMatrix(pp, pp, mxREAL);
double* iF = mxGetPr(piF);
int lw = pp * 4;
lapack_int lw = pp * 4;
double* w = (double*)mxMalloc(lw * sizeof(double));
int* iw = (int*)mxMalloc(pp * sizeof(int));
int* ipiv = (int*)mxMalloc(pp * sizeof(int));
int info = 0;
lapack_int* iw = (lapack_int*)mxMalloc(pp * sizeof(lapack_int));
lapack_int* ipiv = (lapack_int*)mxMalloc(pp * sizeof(lapack_int));
lapack_int info = 0;
double anorm, rcond;
#ifdef BLAS
mxArray* p_P_t_t1 = mxCreateDoubleMatrix(n, n, mxREAL);
@ -430,26 +395,14 @@ mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
iF[i + j * pp] = F[i + j * pp] = P[mf[i] + mf[j] * n] + H[i + j * pp];
/* Computes the norm of x */
#if defined(MATLAB_MEX_FILE) && defined(_WIN32)
double anorm = dlange("1", &pp, &pp, iF, &pp, w);
#else
double anorm = dlange_("1", &pp, &pp, iF, &pp, w);
#endif
/* Modifies F in place with a LU decomposition */
#if defined(MATLAB_MEX_FILE) && defined(_WIN32)
dgetrf(&pp, &pp, iF, &pp, ipiv, &info);
#else
dgetrf_(&pp, &pp, iF, &pp, ipiv, &info);
#endif
if (info != 0) fprintf(stderr, "failure with error %d\n", info);
if (info != 0) fprintf(stderr, "failure with error %d\n", (int) info);
/* Computes the reciprocal norm */
#if defined(MATLAB_MEX_FILE) && defined(_WIN32)
dgecon("1", &pp, iF, &pp, &anorm, &rcond, w, iw, &info);
#else
dgecon_("1", &pp, iF, &pp, &anorm, &rcond, w, iw, &info);
#endif
if (rcond < kalman_tol)
if (not_all_abs_F_bellow_crit(F, pp * pp, kalman_tol)) //~all(abs(F(:))<kalman_tol)
@ -554,11 +507,7 @@ mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
//iF = inv(F);
//int lwork = 4/*2*/* pp;
#if defined(MATLAB_MEX_FILE) && defined(_WIN32)
dgetri(&pp, iF, &pp, ipiv, w, &lw, &info);
#else
dgetri_(&pp, iF, &pp, ipiv, w, &lw, &info);
#endif
//lik(t) = log(dF)+transpose(v)*iF*v;
for (int i = 0; i < pp; i++)

View File

@ -235,6 +235,10 @@ extern "C" {
void dgeqp3(CONST_LAINT m, CONST_LAINT n, LADOU a, CONST_LAINT lda, LAINT jpvt, LADOU tau,
LADOU work, CONST_LAINT lwork, LAINT info);
#define dlange FORTRAN_WRAPPER(dlange)
double dlange(LACHAR norm, CONST_LAINT m, CONST_LAINT n, CONST_LADOU a, CONST_LAINT lda,
LADOU work);
#ifdef __cplusplus
} /* extern "C" */
#endif