4.1: fixing problems in gensylv for Matlab 7.8 64 bit

git-svn-id: https://www.dynare.org/svn/dynare/trunk@2599 ac1d8469-bf42-47a9-8791-bf33cf982152
time-shift
michel 2009-04-18 09:01:55 +00:00
parent c4e1b8ef80
commit a1ad1ed033
12 changed files with 140 additions and 103 deletions

View File

@ -271,15 +271,16 @@ void GeneralMatrix::gemm(const char* transa, const ConstGeneralMatrix& a,
throw SYLV_MES_EXCEPTION("Wrong dimensions for matrix multiplication.");
}
int m = opa_rows;
int n = opb_cols;
int k = opa_cols;
int lda = a.ld;
int ldb = b.ld;
int ldc = ld;
blas_int m = opa_rows;
blas_int n = opb_cols;
blas_int k = opa_cols;
blas_int lda = a.ld;
blas_int ldb = b.ld;
blas_int ldc = ld;
if (lda > 0 && ldb > 0 && ldc > 0) {
BLAS_dgemm(transa, transb, &m, &n, &k, &alpha, a.data.base(), &lda,
b.data.base(), &ldb, &beta, data.base(), &ldc);
BLAS_dgemm(const_cast<char*>(transa), const_cast<char*>(transb), &m, &n, &k, &alpha,
const_cast<double*>(a.data.base()), &lda,
const_cast<double*>(b.data.base()), &ldb, &beta, data.base(), &ldc);
} else if (numRows()*numCols() > 0) {
if (beta == 0.0)
zeros();
@ -366,15 +367,16 @@ void ConstGeneralMatrix::multVec(double a, Vector& x, double b, const ConstVecto
throw SYLV_MES_EXCEPTION("Wrong dimensions for vector multiply.");
}
if (rows > 0) {
int mm = rows;
int nn = cols;
blas_int mm = rows;
blas_int nn = cols;
double alpha = b;
int lda = ld;
int incx = d.skip();
blas_int lda = ld;
blas_int incx = d.skip();
double beta = a;
int incy = x.skip();
BLAS_dgemv("N", &mm, &nn, &alpha, data.base(), &lda, d.base(), &incx,
&beta, x.base(), &incy);
blas_int incy = x.skip();
BLAS_dgemv(const_cast<char*>("N"), &mm, &nn, &alpha, const_cast<double*>(data.base()),
&lda, const_cast<double*>(d.base()), &incx,
&beta, x.base(), &incy);
}
}
@ -386,20 +388,21 @@ void ConstGeneralMatrix::multVecTrans(double a, Vector& x, double b,
throw SYLV_MES_EXCEPTION("Wrong dimensions for vector multiply.");
}
if (rows > 0) {
int mm = rows;
int nn = cols;
blas_int mm = rows;
blas_int nn = cols;
double alpha = b;
int lda = rows;
int incx = d.skip();
blas_int lda = rows;
blas_int incx = d.skip();
double beta = a;
int incy = x.skip();
BLAS_dgemv("T", &mm, &nn, &alpha, data.base(), &lda, d.base(), &incx,
&beta, x.base(), &incy);
blas_int incy = x.skip();
BLAS_dgemv(const_cast<char*>("T"), &mm, &nn, &alpha, const_cast<double*>(data.base()),
&lda, const_cast<double*>(d.base()), &incx,
&beta, x.base(), &incy);
}
}
/* m = inv(this)*m */
void ConstGeneralMatrix::multInvLeft(const char* trans, int mrows, int mcols, int mld, double* d) const
void ConstGeneralMatrix::multInvLeft(const char* trans, lapack_int mrows, lapack_int mcols, lapack_int mld, double* d) const
{
if (rows != cols) {
throw SYLV_MES_EXCEPTION("The matrix is not square for inversion.");
@ -410,10 +413,11 @@ void ConstGeneralMatrix::multInvLeft(const char* trans, int mrows, int mcols, in
if (rows > 0) {
GeneralMatrix inv(*this);
int* ipiv = new int[rows];
int info;
LAPACK_dgetrf(&rows, &rows, inv.getData().base(), &rows, ipiv, &info);
LAPACK_dgetrs(trans, &rows, &mcols, inv.base(), &rows, ipiv, d,
lapack_int* ipiv = new lapack_int[rows];
lapack_int info;
lapack_int rows_arg = rows;
LAPACK_dgetrf(&rows_arg, &rows_arg, inv.getData().base(), &rows_arg, ipiv, &info);
LAPACK_dgetrs(const_cast<char*>(trans), &rows_arg, &mcols, inv.base(), &rows_arg, ipiv, d,
&mld, &info);
delete [] ipiv;
}

View File

@ -25,6 +25,8 @@
#define GENERAL_MATRIX_H
#include "Vector.h"
#include "cppblas.h"
#include "cpplapack.h"
class GeneralMatrix;
@ -82,7 +84,7 @@ public:
virtual void print() const;
protected:
void multInvLeft(const char* trans, int mrows, int mcols, int mld, double* d) const;
void multInvLeft(const char* trans, lapack_int mrows, lapack_int mcols, lapack_int mld, double* d) const;
};

View File

@ -415,11 +415,11 @@ QuasiTriangular::QuasiTriangular(int p, const QuasiTriangular& t)
: SqSylvMatrix(t.numRows()), diagonal(getData().base(), t.diagonal)
{
Vector aux(t.getData());
int d_size = diagonal.getSize();
blas_int d_size = diagonal.getSize();
double alpha = 1.0;
double beta = 0.0;
BLAS_dgemm("N", "N", &d_size, &d_size, &d_size, &alpha, aux.base(),
&d_size, t.getData().base(), &d_size, &beta, getData().base(), &d_size);
BLAS_dgemm(const_cast<char*>("N"), const_cast<char*>("N"), &d_size, &d_size, &d_size, &alpha, aux.base(),
&d_size, const_cast<double*>(t.getData().base()), &d_size, &beta, getData().base(), &d_size);
}
QuasiTriangular::QuasiTriangular(const SchurDecomp& decomp)
@ -548,10 +548,10 @@ void QuasiTriangular::solvePre(Vector& x, double& eig_min)
eig_min = eig_size;
}
int nn = diagonal.getSize();
int lda = diagonal.getSize();
int incx = x.skip();
BLAS_dtrsv("U", "N", "N", &nn, getData().base(), &lda, x.base(), &incx);
blas_int nn = diagonal.getSize();
blas_int lda = diagonal.getSize();
blas_int incx = x.skip();
BLAS_dtrsv(const_cast<char*>("U"), const_cast<char*>("N"), const_cast<char*>("N"), &nn, getData().base(), &lda, x.base(), &incx);
}
void QuasiTriangular::solvePreTrans(Vector& x, double& eig_min)
@ -569,10 +569,10 @@ void QuasiTriangular::solvePreTrans(Vector& x, double& eig_min)
eig_min = eig_size;
}
int nn = diagonal.getSize();
int lda = diagonal.getSize();
int incx = x.skip();
BLAS_dtrsv("U", "T", "N", &nn, getData().base(), &lda, x.base(), &incx);
blas_int nn = diagonal.getSize();
blas_int lda = diagonal.getSize();
blas_int incx = x.skip();
BLAS_dtrsv(const_cast<char*>("U"), const_cast<char*>("T"), const_cast<char*>("N"), &nn, getData().base(), &lda, x.base(), &incx);
}
@ -580,10 +580,11 @@ void QuasiTriangular::solvePreTrans(Vector& x, double& eig_min)
void QuasiTriangular::multVec(Vector& x, const ConstVector& b) const
{
x = b;
int nn = diagonal.getSize();
int lda = diagonal.getSize();
int incx = x.skip();
BLAS_dtrmv("U", "N", "N", &nn, getData().base(), &lda, x.base(), &incx);
blas_int nn = diagonal.getSize();
blas_int lda = diagonal.getSize();
blas_int incx = x.skip();
BLAS_dtrmv(const_cast<char*>("U"), const_cast<char*>("N"), const_cast<char*>("N"), &nn,
const_cast<double*>(getData().base()), &lda, x.base(), &incx);
for (const_diag_iter di = diag_begin(); di != diag_end(); ++di) {
if (!(*di).isReal()) {
int jbar = (*di).getIndex();
@ -596,10 +597,11 @@ void QuasiTriangular::multVec(Vector& x, const ConstVector& b) const
void QuasiTriangular::multVecTrans(Vector& x, const ConstVector& b) const
{
x = b;
int nn = diagonal.getSize();
int lda = diagonal.getSize();
int incx = x.skip();
BLAS_dtrmv("U", "T", "N", &nn, getData().base(), &lda, x.base(), &incx);
blas_int nn = diagonal.getSize();
blas_int lda = diagonal.getSize();
blas_int incx = x.skip();
BLAS_dtrmv(const_cast<char*>("U"), const_cast<char*>("T"), const_cast<char*>("N"), &nn,
const_cast<double*>(getData().base()), &lda, x.base(), &incx);
for (const_diag_iter di = diag_begin(); di != diag_end(); ++di) {
if (!(*di).isReal()) {
int jbar = (*di).getIndex();

View File

@ -28,16 +28,16 @@
SchurDecomp::SchurDecomp(const SqSylvMatrix& m)
: q_destroy(true), t_destroy(true)
{
int rows = m.numRows();
lapack_int rows = m.numRows();
q = new SqSylvMatrix(rows);
SqSylvMatrix auxt(m);
int sdim;
lapack_int sdim;
double* const wr = new double[rows];
double* const wi = new double[rows];
int lwork = 6*rows;
lapack_int lwork = 6*rows;
double* const work = new double[lwork];
int info;
LAPACK_dgees("V", "N", 0, &rows, auxt.base(), &rows, &sdim,
lapack_int info;
LAPACK_dgees(const_cast<char*>("V"), const_cast<char*>("N"), 0, &rows, auxt.base(), &rows, &sdim,
wr, wi, q->base(), &rows,
work, &lwork, 0, &info);
delete [] work;

View File

@ -64,12 +64,12 @@ bool SchurDecompEig::tryToSwap(diag_iter& it, diag_iter& itadd)
itadd = it;
--itadd;
int n = getDim();
int ifst = (*it).getIndex() + 1;
int ilst = (*itadd).getIndex() + 1;
lapack_int n = getDim();
lapack_int ifst = (*it).getIndex() + 1;
lapack_int ilst = (*itadd).getIndex() + 1;
double* work = new double[n];
int info;
LAPACK_dtrexc("V", &n, getT().base(), &n, getQ().base(), &n, &ifst, &ilst, work,
lapack_int info;
LAPACK_dtrexc(const_cast<char*>("V"), &n, getT().base(), &n, getQ().base(), &n, &ifst, &ilst, work,
&info);
delete [] work;
if (info < 0) {

View File

@ -72,12 +72,12 @@ bool SimilarityDecomp::solveX(diag_iter start, diag_iter end,
SqSylvMatrix B((const GeneralMatrix&)*b, ei, ei, X.numCols());
GeneralMatrix C((const GeneralMatrix&)*b, si, ei, X.numRows(), X.numCols());
int isgn = -1;
int m = A.numRows();
int n = B.numRows();
lapack_int isgn = -1;
lapack_int m = A.numRows();
lapack_int n = B.numRows();
double scale;
int info;
LAPACK_dtrsyl("N", "N", &isgn, &m, &n, A.base(), &m, B.base(), &n,
lapack_int info;
LAPACK_dtrsyl(const_cast<char*>("N"), const_cast<char*>("N"), &isgn, &m, &n, A.base(), &m, B.base(), &n,
C.base(), &m, &scale, &info);
if (info < -1)
throw SYLV_MES_EXCEPTION("Wrong parameter to LAPACK dtrsyl.");

View File

@ -63,16 +63,17 @@ void SylvMatrix::multLeft(int zero_cols, const GeneralMatrix& a, const GeneralMa
// another copy of (usually big) b (we are not able to do inplace
// submatrix of const GeneralMatrix)
if (a.getLD() > 0 && ld > 0) {
int mm = a.numRows();
int nn = cols;
int kk = a.numCols();
blas_int mm = a.numRows();
blas_int nn = cols;
blas_int kk = a.numCols();
double alpha = 1.0;
int lda = a.getLD();
int ldb = ld;
blas_int lda = a.getLD();
blas_int ldb = ld;
double beta = 0.0;
int ldc = ld;
BLAS_dgemm("N", "N", &mm, &nn, &kk, &alpha, a.getData().base(), &lda,
b.getData().base()+off, &ldb, &beta, data.base(), &ldc);
blas_int ldc = ld;
BLAS_dgemm(const_cast<char*>("N"), const_cast<char*>("N"), &mm, &nn, &kk, &alpha,
const_cast<double*>(a.getData().base()), &lda,
const_cast<double*>(b.getData().base()+off), &ldb, &beta, data.base(), &ldc);
}
}
@ -225,29 +226,30 @@ void SqSylvMatrix::multInvLeft2(GeneralMatrix& a, GeneralMatrix& b,
}
// PLU factorization
Vector inv(data);
int * const ipiv = new int[rows];
int info;
LAPACK_dgetrf(&rows, &rows, inv.base(), &rows, ipiv, &info);
lapack_int * const ipiv = new lapack_int[rows];
lapack_int info;
lapack_int rows_arg = rows;
LAPACK_dgetrf(&rows_arg, &rows_arg, inv.base(), &rows_arg, ipiv, &info);
// solve a
int acols = a.numCols();
lapack_int acols = a.numCols();
double* abase = a.base();
LAPACK_dgetrs("N", &rows, &acols, inv.base(), &rows, ipiv,
abase, &rows, &info);
LAPACK_dgetrs(const_cast<char*>("N"), &rows_arg, &acols, inv.base(), &rows_arg, ipiv,
abase, &rows_arg, &info);
// solve b
int bcols = b.numCols();
lapack_int bcols = b.numCols();
double* bbase = b.base();
LAPACK_dgetrs("N", &rows, &bcols, inv.base(), &rows, ipiv,
bbase, &rows, &info);
LAPACK_dgetrs(const_cast<char*>("N"), &rows_arg, &bcols, inv.base(), &rows_arg, ipiv,
bbase, &rows_arg, &info);
delete [] ipiv;
// condition numbers
double* const work = new double[4*rows];
int* const iwork = new int[rows];
lapack_int* const iwork = new lapack_int[rows];
double norm1 = getNorm1();
LAPACK_dgecon("1", &rows, inv.base(), &rows, &norm1, &rcond1,
LAPACK_dgecon(const_cast<char*>("1"), &rows_arg, inv.base(), &rows_arg, &norm1, &rcond1,
work, iwork, &info);
double norminf = getNormInf();
LAPACK_dgecon("I", &rows, inv.base(), &rows, &norminf, &rcondinf,
LAPACK_dgecon(const_cast<char*>("I"), &rows_arg, inv.base(), &rows_arg, &norminf, &rcondinf,
work, iwork, &info);
delete [] iwork;
delete [] work;

View File

@ -41,11 +41,11 @@ using namespace std;
ZeroPad zero_pad;
void Vector::copy(const double* d, int inc)
void Vector::copy(const double* d, lapack_int inc)
{
int n = length();
int incy = skip();
BLAS_dcopy(&n, d, &inc, base(), &incy);
blas_int n = length();
blas_int incy = skip();
BLAS_dcopy(&n, const_cast<double*>(d), &inc, const_cast<double*>(base()), &incy);
}
Vector::Vector(const Vector& v)
@ -190,10 +190,10 @@ void Vector::add(double r, const Vector& v)
void Vector::add(double r, const ConstVector& v)
{
int n = length();
int incx = v.skip();
int incy = skip();
BLAS_daxpy(&n, &r, v.base(), &incx, base(), &incy);
blas_int n = length();
blas_int incx = v.skip();
blas_int incy = skip();
BLAS_daxpy(&n, &r, const_cast<double*>(v.base()), &incx, base(), &incy);
}
void Vector::add(const double* z, const Vector& v)
@ -203,16 +203,16 @@ void Vector::add(const double* z, const Vector& v)
void Vector::add(const double* z, const ConstVector& v)
{
int n = length()/2;
int incx = v.skip();
int incy = skip();
BLAS_zaxpy(&n, z, v.base(), &incx, base(), &incy);
blas_int n = length()/2;
blas_int incx = v.skip();
blas_int incy = skip();
BLAS_zaxpy(&n, const_cast<double*>(z), const_cast<double*>(v.base()), &incx, base(), &incy);
}
void Vector::mult(double r)
{
int n = length();
int incx = skip();
blas_int n = length();
blas_int incx = skip();
BLAS_dscal(&n, &r, base(), &incx);
}
@ -334,10 +334,10 @@ double ConstVector::dot(const ConstVector& y) const
{
if (length() != y.length())
throw SYLV_MES_EXCEPTION("Vector has different length in ConstVector::dot.");
int n = length();
int incx = skip();
int incy = y.skip();
return BLAS_ddot(&n, base(), &incx, y.base(), &incy);
blas_int n = length();
blas_int incx = skip();
blas_int incy = y.skip();
return BLAS_ddot(&n, const_cast<double*>(base()), &incx, const_cast<double*>(y.base()), &incy);
}
bool ConstVector::isFinite() const

View File

@ -24,6 +24,12 @@
#ifndef VECTOR_H
#define VECTOR_H
#ifdef MATLAB
#include "mex.h"
#endif
#include "../../matlab_versions_compatibility.h"
/* NOTE! Vector and ConstVector have not common super class in order
* to avoid running virtual method invokation mechanism. Some
* members, and methods are thus duplicated */
@ -110,7 +116,7 @@ public:
const Vector& b1, const Vector& b2)
{mult2a(-alpha, -beta1, -beta2, x1, x2, b1, b2);}
private:
void copy(const double* d, int inc);
void copy(const double* d, lapack_int inc);
const Vector& operator=(int); // must not be used (not implemented)
const Vector& operator=(double); // must not be used (not implemented)
};

View File

@ -24,6 +24,12 @@
#ifndef CPPBLAS_H
#define CPPBLAS_H
#ifdef MATLAB
#include "mex.h"
#endif
#include "../../matlab_versions_compatibility.h"
#if defined(MATLAB) && !defined(__linux__) && !defined(OCTAVE)
#define BLAS_dgemm dgemm
#define BLAS_dgemv dgemv
@ -48,6 +54,7 @@
#define BLAS_ddot ddot_
#endif
#if defined NO_BLAS_H
#define BLCHAR const char*
#define CONST_BLINT const int*
#define CONST_BLDOU const double*
@ -78,7 +85,9 @@ extern "C" {
double BLAS_ddot(CONST_BLINT n, CONST_BLDOU x, CONST_BLINT incx, CONST_BLDOU y,
CONST_BLINT incy);
};
#else /* NO_BLAS_H isn't defined */
#include "blas.h"
#endif
#endif /* CPPBLAS_H */

View File

@ -24,6 +24,12 @@
#ifndef CPPLAPACK_H
#define CPPLAPACK_H
#ifdef MATLAB
#include "mex.h"
#endif
#include "../../matlab_versions_compatibility.h"
#if defined(MATLAB) && !defined(__linux__) && !defined(OCTAVE)
#define LAPACK_dgetrs dgetrs
#define LAPACK_dgetrf dgetrf
@ -46,6 +52,7 @@
#define LAPACK_dsyev dsyev_
#endif
#if defined NO_LAPACK_H
#define LACHAR const char*
#define CONST_LAINT const int*
#define LAINT int*
@ -83,7 +90,9 @@ extern "C" {
void LAPACK_dsyev(LACHAR jobz, LACHAR uplo, CONST_LAINT n, LADOU a, CONST_LAINT lda,
LADOU w, LADOU work, CONST_LAINT lwork, LAINT info);
};
#else /* NO_LAPACK_H isn't defined */
#include "lapack.h"
#endif
#endif /* CPPLAPACK_H */

View File

@ -1,10 +1,13 @@
#if !defined(MATLAB_VERSIONS_COMPATIBILITY_H)
#define MATLAB_VERSIONS_COMPATIBILITY_H
#if !defined(LAPACK_USE_MWSIGNEDINDEX) || defined(OCTAVE)
typedef int lapack_int;
typedef int blas_int;
#else
typedef mwSignedIndex lapack_int;
typedef mwSignedIndex blas_int;
#endif
#endif