diff --git a/mex/sources/gensylv/cc/GeneralMatrix.cpp b/mex/sources/gensylv/cc/GeneralMatrix.cpp index 314be2e53..5aef204d9 100644 --- a/mex/sources/gensylv/cc/GeneralMatrix.cpp +++ b/mex/sources/gensylv/cc/GeneralMatrix.cpp @@ -271,15 +271,16 @@ void GeneralMatrix::gemm(const char* transa, const ConstGeneralMatrix& a, throw SYLV_MES_EXCEPTION("Wrong dimensions for matrix multiplication."); } - int m = opa_rows; - int n = opb_cols; - int k = opa_cols; - int lda = a.ld; - int ldb = b.ld; - int ldc = ld; + blas_int m = opa_rows; + blas_int n = opb_cols; + blas_int k = opa_cols; + blas_int lda = a.ld; + blas_int ldb = b.ld; + blas_int ldc = ld; if (lda > 0 && ldb > 0 && ldc > 0) { - BLAS_dgemm(transa, transb, &m, &n, &k, &alpha, a.data.base(), &lda, - b.data.base(), &ldb, &beta, data.base(), &ldc); + BLAS_dgemm(const_cast(transa), const_cast(transb), &m, &n, &k, &alpha, + const_cast(a.data.base()), &lda, + const_cast(b.data.base()), &ldb, &beta, data.base(), &ldc); } else if (numRows()*numCols() > 0) { if (beta == 0.0) zeros(); @@ -366,15 +367,16 @@ void ConstGeneralMatrix::multVec(double a, Vector& x, double b, const ConstVecto throw SYLV_MES_EXCEPTION("Wrong dimensions for vector multiply."); } if (rows > 0) { - int mm = rows; - int nn = cols; + blas_int mm = rows; + blas_int nn = cols; double alpha = b; - int lda = ld; - int incx = d.skip(); + blas_int lda = ld; + blas_int incx = d.skip(); double beta = a; - int incy = x.skip(); - BLAS_dgemv("N", &mm, &nn, &alpha, data.base(), &lda, d.base(), &incx, - &beta, x.base(), &incy); + blas_int incy = x.skip(); + BLAS_dgemv(const_cast("N"), &mm, &nn, &alpha, const_cast(data.base()), + &lda, const_cast(d.base()), &incx, + &beta, x.base(), &incy); } } @@ -386,20 +388,21 @@ void ConstGeneralMatrix::multVecTrans(double a, Vector& x, double b, throw SYLV_MES_EXCEPTION("Wrong dimensions for vector multiply."); } if (rows > 0) { - int mm = rows; - int nn = cols; + blas_int mm = rows; + blas_int nn = cols; double alpha = b; - int lda = rows; - int incx = d.skip(); + blas_int lda = rows; + blas_int incx = d.skip(); double beta = a; - int incy = x.skip(); - BLAS_dgemv("T", &mm, &nn, &alpha, data.base(), &lda, d.base(), &incx, - &beta, x.base(), &incy); + blas_int incy = x.skip(); + BLAS_dgemv(const_cast("T"), &mm, &nn, &alpha, const_cast(data.base()), + &lda, const_cast(d.base()), &incx, + &beta, x.base(), &incy); } } /* m = inv(this)*m */ -void ConstGeneralMatrix::multInvLeft(const char* trans, int mrows, int mcols, int mld, double* d) const +void ConstGeneralMatrix::multInvLeft(const char* trans, lapack_int mrows, lapack_int mcols, lapack_int mld, double* d) const { if (rows != cols) { throw SYLV_MES_EXCEPTION("The matrix is not square for inversion."); @@ -410,10 +413,11 @@ void ConstGeneralMatrix::multInvLeft(const char* trans, int mrows, int mcols, in if (rows > 0) { GeneralMatrix inv(*this); - int* ipiv = new int[rows]; - int info; - LAPACK_dgetrf(&rows, &rows, inv.getData().base(), &rows, ipiv, &info); - LAPACK_dgetrs(trans, &rows, &mcols, inv.base(), &rows, ipiv, d, + lapack_int* ipiv = new lapack_int[rows]; + lapack_int info; + lapack_int rows_arg = rows; + LAPACK_dgetrf(&rows_arg, &rows_arg, inv.getData().base(), &rows_arg, ipiv, &info); + LAPACK_dgetrs(const_cast(trans), &rows_arg, &mcols, inv.base(), &rows_arg, ipiv, d, &mld, &info); delete [] ipiv; } diff --git a/mex/sources/gensylv/cc/GeneralMatrix.h b/mex/sources/gensylv/cc/GeneralMatrix.h index 9d749d848..6fa337d35 100644 --- a/mex/sources/gensylv/cc/GeneralMatrix.h +++ b/mex/sources/gensylv/cc/GeneralMatrix.h @@ -25,6 +25,8 @@ #define GENERAL_MATRIX_H #include "Vector.h" +#include "cppblas.h" +#include "cpplapack.h" class GeneralMatrix; @@ -82,7 +84,7 @@ public: virtual void print() const; protected: - void multInvLeft(const char* trans, int mrows, int mcols, int mld, double* d) const; + void multInvLeft(const char* trans, lapack_int mrows, lapack_int mcols, lapack_int mld, double* d) const; }; diff --git a/mex/sources/gensylv/cc/QuasiTriangular.cpp b/mex/sources/gensylv/cc/QuasiTriangular.cpp index 3e1f40779..09fe08a40 100644 --- a/mex/sources/gensylv/cc/QuasiTriangular.cpp +++ b/mex/sources/gensylv/cc/QuasiTriangular.cpp @@ -415,11 +415,11 @@ QuasiTriangular::QuasiTriangular(int p, const QuasiTriangular& t) : SqSylvMatrix(t.numRows()), diagonal(getData().base(), t.diagonal) { Vector aux(t.getData()); - int d_size = diagonal.getSize(); + blas_int d_size = diagonal.getSize(); double alpha = 1.0; double beta = 0.0; - BLAS_dgemm("N", "N", &d_size, &d_size, &d_size, &alpha, aux.base(), - &d_size, t.getData().base(), &d_size, &beta, getData().base(), &d_size); + BLAS_dgemm(const_cast("N"), const_cast("N"), &d_size, &d_size, &d_size, &alpha, aux.base(), + &d_size, const_cast(t.getData().base()), &d_size, &beta, getData().base(), &d_size); } QuasiTriangular::QuasiTriangular(const SchurDecomp& decomp) @@ -548,10 +548,10 @@ void QuasiTriangular::solvePre(Vector& x, double& eig_min) eig_min = eig_size; } - int nn = diagonal.getSize(); - int lda = diagonal.getSize(); - int incx = x.skip(); - BLAS_dtrsv("U", "N", "N", &nn, getData().base(), &lda, x.base(), &incx); + blas_int nn = diagonal.getSize(); + blas_int lda = diagonal.getSize(); + blas_int incx = x.skip(); + BLAS_dtrsv(const_cast("U"), const_cast("N"), const_cast("N"), &nn, getData().base(), &lda, x.base(), &incx); } void QuasiTriangular::solvePreTrans(Vector& x, double& eig_min) @@ -569,10 +569,10 @@ void QuasiTriangular::solvePreTrans(Vector& x, double& eig_min) eig_min = eig_size; } - int nn = diagonal.getSize(); - int lda = diagonal.getSize(); - int incx = x.skip(); - BLAS_dtrsv("U", "T", "N", &nn, getData().base(), &lda, x.base(), &incx); + blas_int nn = diagonal.getSize(); + blas_int lda = diagonal.getSize(); + blas_int incx = x.skip(); + BLAS_dtrsv(const_cast("U"), const_cast("T"), const_cast("N"), &nn, getData().base(), &lda, x.base(), &incx); } @@ -580,10 +580,11 @@ void QuasiTriangular::solvePreTrans(Vector& x, double& eig_min) void QuasiTriangular::multVec(Vector& x, const ConstVector& b) const { x = b; - int nn = diagonal.getSize(); - int lda = diagonal.getSize(); - int incx = x.skip(); - BLAS_dtrmv("U", "N", "N", &nn, getData().base(), &lda, x.base(), &incx); + blas_int nn = diagonal.getSize(); + blas_int lda = diagonal.getSize(); + blas_int incx = x.skip(); + BLAS_dtrmv(const_cast("U"), const_cast("N"), const_cast("N"), &nn, + const_cast(getData().base()), &lda, x.base(), &incx); for (const_diag_iter di = diag_begin(); di != diag_end(); ++di) { if (!(*di).isReal()) { int jbar = (*di).getIndex(); @@ -596,10 +597,11 @@ void QuasiTriangular::multVec(Vector& x, const ConstVector& b) const void QuasiTriangular::multVecTrans(Vector& x, const ConstVector& b) const { x = b; - int nn = diagonal.getSize(); - int lda = diagonal.getSize(); - int incx = x.skip(); - BLAS_dtrmv("U", "T", "N", &nn, getData().base(), &lda, x.base(), &incx); + blas_int nn = diagonal.getSize(); + blas_int lda = diagonal.getSize(); + blas_int incx = x.skip(); + BLAS_dtrmv(const_cast("U"), const_cast("T"), const_cast("N"), &nn, + const_cast(getData().base()), &lda, x.base(), &incx); for (const_diag_iter di = diag_begin(); di != diag_end(); ++di) { if (!(*di).isReal()) { int jbar = (*di).getIndex(); diff --git a/mex/sources/gensylv/cc/SchurDecomp.cpp b/mex/sources/gensylv/cc/SchurDecomp.cpp index 1d20efc2b..3387afc4c 100644 --- a/mex/sources/gensylv/cc/SchurDecomp.cpp +++ b/mex/sources/gensylv/cc/SchurDecomp.cpp @@ -28,16 +28,16 @@ SchurDecomp::SchurDecomp(const SqSylvMatrix& m) : q_destroy(true), t_destroy(true) { - int rows = m.numRows(); + lapack_int rows = m.numRows(); q = new SqSylvMatrix(rows); SqSylvMatrix auxt(m); - int sdim; + lapack_int sdim; double* const wr = new double[rows]; double* const wi = new double[rows]; - int lwork = 6*rows; + lapack_int lwork = 6*rows; double* const work = new double[lwork]; - int info; - LAPACK_dgees("V", "N", 0, &rows, auxt.base(), &rows, &sdim, + lapack_int info; + LAPACK_dgees(const_cast("V"), const_cast("N"), 0, &rows, auxt.base(), &rows, &sdim, wr, wi, q->base(), &rows, work, &lwork, 0, &info); delete [] work; diff --git a/mex/sources/gensylv/cc/SchurDecompEig.cpp b/mex/sources/gensylv/cc/SchurDecompEig.cpp index 6dd8a180f..0f5e970b4 100644 --- a/mex/sources/gensylv/cc/SchurDecompEig.cpp +++ b/mex/sources/gensylv/cc/SchurDecompEig.cpp @@ -64,12 +64,12 @@ bool SchurDecompEig::tryToSwap(diag_iter& it, diag_iter& itadd) itadd = it; --itadd; - int n = getDim(); - int ifst = (*it).getIndex() + 1; - int ilst = (*itadd).getIndex() + 1; + lapack_int n = getDim(); + lapack_int ifst = (*it).getIndex() + 1; + lapack_int ilst = (*itadd).getIndex() + 1; double* work = new double[n]; - int info; - LAPACK_dtrexc("V", &n, getT().base(), &n, getQ().base(), &n, &ifst, &ilst, work, + lapack_int info; + LAPACK_dtrexc(const_cast("V"), &n, getT().base(), &n, getQ().base(), &n, &ifst, &ilst, work, &info); delete [] work; if (info < 0) { diff --git a/mex/sources/gensylv/cc/SimilarityDecomp.cpp b/mex/sources/gensylv/cc/SimilarityDecomp.cpp index e5631737c..b7b3753bd 100644 --- a/mex/sources/gensylv/cc/SimilarityDecomp.cpp +++ b/mex/sources/gensylv/cc/SimilarityDecomp.cpp @@ -72,12 +72,12 @@ bool SimilarityDecomp::solveX(diag_iter start, diag_iter end, SqSylvMatrix B((const GeneralMatrix&)*b, ei, ei, X.numCols()); GeneralMatrix C((const GeneralMatrix&)*b, si, ei, X.numRows(), X.numCols()); - int isgn = -1; - int m = A.numRows(); - int n = B.numRows(); + lapack_int isgn = -1; + lapack_int m = A.numRows(); + lapack_int n = B.numRows(); double scale; - int info; - LAPACK_dtrsyl("N", "N", &isgn, &m, &n, A.base(), &m, B.base(), &n, + lapack_int info; + LAPACK_dtrsyl(const_cast("N"), const_cast("N"), &isgn, &m, &n, A.base(), &m, B.base(), &n, C.base(), &m, &scale, &info); if (info < -1) throw SYLV_MES_EXCEPTION("Wrong parameter to LAPACK dtrsyl."); diff --git a/mex/sources/gensylv/cc/SylvMatrix.cpp b/mex/sources/gensylv/cc/SylvMatrix.cpp index ccada689b..ab62e64a2 100644 --- a/mex/sources/gensylv/cc/SylvMatrix.cpp +++ b/mex/sources/gensylv/cc/SylvMatrix.cpp @@ -63,16 +63,17 @@ void SylvMatrix::multLeft(int zero_cols, const GeneralMatrix& a, const GeneralMa // another copy of (usually big) b (we are not able to do inplace // submatrix of const GeneralMatrix) if (a.getLD() > 0 && ld > 0) { - int mm = a.numRows(); - int nn = cols; - int kk = a.numCols(); + blas_int mm = a.numRows(); + blas_int nn = cols; + blas_int kk = a.numCols(); double alpha = 1.0; - int lda = a.getLD(); - int ldb = ld; + blas_int lda = a.getLD(); + blas_int ldb = ld; double beta = 0.0; - int ldc = ld; - BLAS_dgemm("N", "N", &mm, &nn, &kk, &alpha, a.getData().base(), &lda, - b.getData().base()+off, &ldb, &beta, data.base(), &ldc); + blas_int ldc = ld; + BLAS_dgemm(const_cast("N"), const_cast("N"), &mm, &nn, &kk, &alpha, + const_cast(a.getData().base()), &lda, + const_cast(b.getData().base()+off), &ldb, &beta, data.base(), &ldc); } } @@ -225,29 +226,30 @@ void SqSylvMatrix::multInvLeft2(GeneralMatrix& a, GeneralMatrix& b, } // PLU factorization Vector inv(data); - int * const ipiv = new int[rows]; - int info; - LAPACK_dgetrf(&rows, &rows, inv.base(), &rows, ipiv, &info); + lapack_int * const ipiv = new lapack_int[rows]; + lapack_int info; + lapack_int rows_arg = rows; + LAPACK_dgetrf(&rows_arg, &rows_arg, inv.base(), &rows_arg, ipiv, &info); // solve a - int acols = a.numCols(); + lapack_int acols = a.numCols(); double* abase = a.base(); - LAPACK_dgetrs("N", &rows, &acols, inv.base(), &rows, ipiv, - abase, &rows, &info); + LAPACK_dgetrs(const_cast("N"), &rows_arg, &acols, inv.base(), &rows_arg, ipiv, + abase, &rows_arg, &info); // solve b - int bcols = b.numCols(); + lapack_int bcols = b.numCols(); double* bbase = b.base(); - LAPACK_dgetrs("N", &rows, &bcols, inv.base(), &rows, ipiv, - bbase, &rows, &info); + LAPACK_dgetrs(const_cast("N"), &rows_arg, &bcols, inv.base(), &rows_arg, ipiv, + bbase, &rows_arg, &info); delete [] ipiv; // condition numbers double* const work = new double[4*rows]; - int* const iwork = new int[rows]; + lapack_int* const iwork = new lapack_int[rows]; double norm1 = getNorm1(); - LAPACK_dgecon("1", &rows, inv.base(), &rows, &norm1, &rcond1, + LAPACK_dgecon(const_cast("1"), &rows_arg, inv.base(), &rows_arg, &norm1, &rcond1, work, iwork, &info); double norminf = getNormInf(); - LAPACK_dgecon("I", &rows, inv.base(), &rows, &norminf, &rcondinf, + LAPACK_dgecon(const_cast("I"), &rows_arg, inv.base(), &rows_arg, &norminf, &rcondinf, work, iwork, &info); delete [] iwork; delete [] work; diff --git a/mex/sources/gensylv/cc/Vector.cpp b/mex/sources/gensylv/cc/Vector.cpp index 128e8e375..ae47b27e6 100644 --- a/mex/sources/gensylv/cc/Vector.cpp +++ b/mex/sources/gensylv/cc/Vector.cpp @@ -41,11 +41,11 @@ using namespace std; ZeroPad zero_pad; -void Vector::copy(const double* d, int inc) +void Vector::copy(const double* d, lapack_int inc) { - int n = length(); - int incy = skip(); - BLAS_dcopy(&n, d, &inc, base(), &incy); + blas_int n = length(); + blas_int incy = skip(); + BLAS_dcopy(&n, const_cast(d), &inc, const_cast(base()), &incy); } Vector::Vector(const Vector& v) @@ -190,10 +190,10 @@ void Vector::add(double r, const Vector& v) void Vector::add(double r, const ConstVector& v) { - int n = length(); - int incx = v.skip(); - int incy = skip(); - BLAS_daxpy(&n, &r, v.base(), &incx, base(), &incy); + blas_int n = length(); + blas_int incx = v.skip(); + blas_int incy = skip(); + BLAS_daxpy(&n, &r, const_cast(v.base()), &incx, base(), &incy); } void Vector::add(const double* z, const Vector& v) @@ -203,16 +203,16 @@ void Vector::add(const double* z, const Vector& v) void Vector::add(const double* z, const ConstVector& v) { - int n = length()/2; - int incx = v.skip(); - int incy = skip(); - BLAS_zaxpy(&n, z, v.base(), &incx, base(), &incy); + blas_int n = length()/2; + blas_int incx = v.skip(); + blas_int incy = skip(); + BLAS_zaxpy(&n, const_cast(z), const_cast(v.base()), &incx, base(), &incy); } void Vector::mult(double r) { - int n = length(); - int incx = skip(); + blas_int n = length(); + blas_int incx = skip(); BLAS_dscal(&n, &r, base(), &incx); } @@ -334,10 +334,10 @@ double ConstVector::dot(const ConstVector& y) const { if (length() != y.length()) throw SYLV_MES_EXCEPTION("Vector has different length in ConstVector::dot."); - int n = length(); - int incx = skip(); - int incy = y.skip(); - return BLAS_ddot(&n, base(), &incx, y.base(), &incy); + blas_int n = length(); + blas_int incx = skip(); + blas_int incy = y.skip(); + return BLAS_ddot(&n, const_cast(base()), &incx, const_cast(y.base()), &incy); } bool ConstVector::isFinite() const diff --git a/mex/sources/gensylv/cc/Vector.h b/mex/sources/gensylv/cc/Vector.h index 14eb1ad64..9d9f583bb 100644 --- a/mex/sources/gensylv/cc/Vector.h +++ b/mex/sources/gensylv/cc/Vector.h @@ -24,6 +24,12 @@ #ifndef VECTOR_H #define VECTOR_H +#ifdef MATLAB +#include "mex.h" +#endif + +#include "../../matlab_versions_compatibility.h" + /* NOTE! Vector and ConstVector have not common super class in order * to avoid running virtual method invokation mechanism. Some * members, and methods are thus duplicated */ @@ -110,7 +116,7 @@ public: const Vector& b1, const Vector& b2) {mult2a(-alpha, -beta1, -beta2, x1, x2, b1, b2);} private: - void copy(const double* d, int inc); + void copy(const double* d, lapack_int inc); const Vector& operator=(int); // must not be used (not implemented) const Vector& operator=(double); // must not be used (not implemented) }; diff --git a/mex/sources/gensylv/cc/cppblas.h b/mex/sources/gensylv/cc/cppblas.h index d326491eb..94d8eb68d 100644 --- a/mex/sources/gensylv/cc/cppblas.h +++ b/mex/sources/gensylv/cc/cppblas.h @@ -24,6 +24,12 @@ #ifndef CPPBLAS_H #define CPPBLAS_H +#ifdef MATLAB +#include "mex.h" +#endif + +#include "../../matlab_versions_compatibility.h" + #if defined(MATLAB) && !defined(__linux__) && !defined(OCTAVE) #define BLAS_dgemm dgemm #define BLAS_dgemv dgemv @@ -48,6 +54,7 @@ #define BLAS_ddot ddot_ #endif +#if defined NO_BLAS_H #define BLCHAR const char* #define CONST_BLINT const int* #define CONST_BLDOU const double* @@ -78,7 +85,9 @@ extern "C" { double BLAS_ddot(CONST_BLINT n, CONST_BLDOU x, CONST_BLINT incx, CONST_BLDOU y, CONST_BLINT incy); }; - +#else /* NO_BLAS_H isn't defined */ +#include "blas.h" +#endif #endif /* CPPBLAS_H */ diff --git a/mex/sources/gensylv/cc/cpplapack.h b/mex/sources/gensylv/cc/cpplapack.h index b02c19faa..1fc0e9915 100644 --- a/mex/sources/gensylv/cc/cpplapack.h +++ b/mex/sources/gensylv/cc/cpplapack.h @@ -24,6 +24,12 @@ #ifndef CPPLAPACK_H #define CPPLAPACK_H +#ifdef MATLAB +#include "mex.h" +#endif + +#include "../../matlab_versions_compatibility.h" + #if defined(MATLAB) && !defined(__linux__) && !defined(OCTAVE) #define LAPACK_dgetrs dgetrs #define LAPACK_dgetrf dgetrf @@ -46,6 +52,7 @@ #define LAPACK_dsyev dsyev_ #endif +#if defined NO_LAPACK_H #define LACHAR const char* #define CONST_LAINT const int* #define LAINT int* @@ -83,7 +90,9 @@ extern "C" { void LAPACK_dsyev(LACHAR jobz, LACHAR uplo, CONST_LAINT n, LADOU a, CONST_LAINT lda, LADOU w, LADOU work, CONST_LAINT lwork, LAINT info); }; - +#else /* NO_LAPACK_H isn't defined */ +#include "lapack.h" +#endif #endif /* CPPLAPACK_H */ diff --git a/mex/sources/matlab_versions_compatibility.h b/mex/sources/matlab_versions_compatibility.h index 6ac9c3569..46174a9ab 100644 --- a/mex/sources/matlab_versions_compatibility.h +++ b/mex/sources/matlab_versions_compatibility.h @@ -1,10 +1,13 @@ #if !defined(MATLAB_VERSIONS_COMPATIBILITY_H) #define MATLAB_VERSIONS_COMPATIBILITY_H + #if !defined(LAPACK_USE_MWSIGNEDINDEX) || defined(OCTAVE) typedef int lapack_int; +typedef int blas_int; #else typedef mwSignedIndex lapack_int; +typedef mwSignedIndex blas_int; #endif #endif