diff --git a/mex/sources/block_kalman_filter/block_kalman_filter.cc b/mex/sources/block_kalman_filter/block_kalman_filter.cc index 207f9e596..0c44b0b6a 100644 --- a/mex/sources/block_kalman_filter/block_kalman_filter.cc +++ b/mex/sources/block_kalman_filter/block_kalman_filter.cc @@ -26,10 +26,13 @@ #endif #include "block_kalman_filter.h" using namespace std; -//#define BLAS -#define DIRECT - +#define BLAS +//#define CUBLAS +#ifdef CUBLAS + #include + #include +#endif void mexDisp(mxArray* P) { @@ -157,7 +160,7 @@ BlockKalmanFilter::BlockKalmanFilter(int nlhs, mxArray *plhs[], int nrhs, const if (missing_observations) { if (! mxIsCell (prhs[0])) - DYN_MEX_FUNC_ERR_MSG_TXT("the first input argument of block_missing_observations_kalman_filter must be a Call Array."); + DYN_MEX_FUNC_ERR_MSG_TXT("the first input argument of block_missing_observations_kalman_filter must be a Cell Array."); pdata_index = prhs[0]; if (! mxIsDouble (prhs[1])) DYN_MEX_FUNC_ERR_MSG_TXT("the second input argument of block_missing_observations_kalman_filter must be a scalar."); @@ -234,14 +237,13 @@ BlockKalmanFilter::BlockKalmanFilter(int nlhs, mxArray *plhs[], int nrhs, const *a = mxGetPr(pa); tmp_a = (double*)mxMalloc(n * sizeof(double)); dF = 0.0; // det(F). - p_tmp = mxCreateDoubleMatrix(n, n_state, mxREAL); - *tmp = mxGetPr(p_tmp); + p_tmp1 = mxCreateDoubleMatrix(n, n_shocks, mxREAL); tmp1 = mxGetPr(p_tmp1); t = 0; // Initialization of the time index. plik = mxCreateDoubleMatrix(smpl, 1, mxREAL); lik = mxGetPr(plik); - Inf = mxGetInf(); + Inf = mxGetInf(); LIK = 0.0; // Default value of the log likelihood. notsteady = true; // Steady state flag. 
F_singular = true; @@ -287,6 +289,22 @@ BlockKalmanFilter::BlockKalmanFilter(int nlhs, mxArray *plhs[], int nrhs, const iw = (lapack_int*)mxMalloc(pp * sizeof(lapack_int)); ipiv = (lapack_int*)mxMalloc(pp * sizeof(lapack_int)); info = 0; +#if defined(BLAS) || defined(CUBLAS) /* #ifdef takes a single macro name, so the "|| CUBLAS" part was never tested; defined() checks both, giving a CUBLAS-only build the same n x n buffers as BLAS */ + p_tmp = mxCreateDoubleMatrix(n, n, mxREAL); + *tmp = mxGetPr(p_tmp); + p_P_t_t1 = mxCreateDoubleMatrix(n, n, mxREAL); + *P_t_t1 = mxGetPr(p_P_t_t1); + pK = mxCreateDoubleMatrix(n, n, mxREAL); + *K = mxGetPr(pK); + p_K_P = mxCreateDoubleMatrix(n, n, mxREAL); + *K_P = mxGetPr(p_K_P); + oldK = (double*)mxMalloc(n * n * sizeof(double)); + *P_mf = (double*)mxMalloc(n * n * sizeof(double)); + for (int i = 0; i < n * n; i++) + oldK[i] = Inf; +#else + p_tmp = mxCreateDoubleMatrix(n, n_state, mxREAL); + *tmp = mxGetPr(p_tmp); p_P_t_t1 = mxCreateDoubleMatrix(n_state, n_state, mxREAL); *P_t_t1 = mxGetPr(p_P_t_t1); pK = mxCreateDoubleMatrix(n, pp, mxREAL); @@ -297,6 +315,7 @@ BlockKalmanFilter::BlockKalmanFilter(int nlhs, mxArray *plhs[], int nrhs, const *P_mf = (double*)mxMalloc(n * pp * sizeof(double)); for (int i = 0; i < n * pp; i++) oldK[i] = Inf; +#endif } void @@ -424,17 +443,17 @@ BlockKalmanFilter::block_kalman_filter(int nlhs, mxArray *plhs[], double *P_mf, } - /* Computes the norm of iF */ - double anorm = dlange("1", &size_d_index, &size_d_index, iF, &size_d_index, w); + /* Computes the norm of iF */ + double anorm = dlange("1", &size_d_index, &size_d_index, iF, &size_d_index, w); //mexPrintf("anorm = %f\n",anorm); - /* Modifies F in place with a LU decomposition */ - dgetrf(&size_d_index, &size_d_index, iF, &size_d_index, ipiv, &info); - if (info != 0) fprintf(stderr, "dgetrf failure with error %d\n", (int) info); + /* Modifies F in place with a LU decomposition */ + dgetrf(&size_d_index, &size_d_index, iF, &size_d_index, ipiv, &info); + if (info != 0) mexPrintf("dgetrf failure with error %d\n", (int) info); - /* Computes the reciprocal norm */ - dgecon("1", &size_d_index, iF, &size_d_index, &anorm, 
&rcond, w, iw, &info); - if (info != 0) fprintf(stderr, "dgecon failure with error %d\n", (int) info); + /* Computes the reciprocal norm */ + dgecon("1", &size_d_index, iF, &size_d_index, &anorm, &rcond, w, iw, &info); + if (info != 0) mexPrintf("dgecon failure with error %d\n", (int) info); if (rcond < kalman_tol) if (not_all_abs_F_bellow_crit(F, size_d_index * size_d_index, kalman_tol)) //~all(abs(F(:))= no_more_missing_observations) { double max_abs = 0.0;