dynare/mex/sources/estimation/libmat/BlasBindings.hh

170 lines
5.2 KiB
C++

/*
* Copyright (C) 2010-2011 Dynare Team
*
* This file is part of Dynare.
*
* Dynare is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dynare is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef _BLAS_BINDINGS_HH
#define _BLAS_BINDINGS_HH
#include <dynblas.h>
#include "Vector.hh"
#include "Matrix.hh"
namespace blas
{
/* Level 1 */
//! dot product of two vectors
template<class Vec1, class Vec2>
inline double
dot(const Vec1 &A, Vec2 &B)
{
assert(A.getSize() == B.getSize());
blas_int n = A.getSize();
blas_int lda = A.getStride(), ldb = B.getStride();
return ddot(&n, A.getData(), &lda, B.getData(), &ldb);
}
/* Level 2 */
//! Symmetric rank 1 operation: A = alpha*X*X' + A
// uplo is forwarded unchanged to BLAS dsyr and selects which triangle of A
// is referenced and updated.
//
// X must have as many elements as A has columns, and A must be square;
// both conditions are checked with an assert only (compiled out under
// NDEBUG).
//
// X is only read, so it is taken by const reference (like the read-only
// operands of gemv and symv) instead of being copied on every call.
//
// NOTE(review): A is deliberately still taken by value, so that callers
// passing a temporary keep compiling. The update is therefore visible to
// the caller only if a copy of Mat shares its storage with the original
// (as a non-owning view type presumably does) -- confirm against the
// matrix classes before passing an owning matrix here.
template<class Mat, class Vec>
inline void
syr(const char *uplo, double alpha, const Vec &X, Mat A)
{
  assert(X.getSize() == A.getCols() && A.getCols() == A.getRows());
  blas_int n = X.getSize();
  blas_int incx = X.getStride();
  blas_int lda = A.getLd();
  dsyr(uplo, &n, &alpha, X.getData(), &incx, A.getData(), &lda);
}
//! General matrix * vector multiplication
// c = alpha*A*b + beta*c, or c := alpha*A'*b + beta*c,
// where alpha and beta are scalars, b and c are vectors and A is an
// m by n matrix.
template<class Mat1, class Vec2, class Vec3>
inline void
gemv(const char *transa, double alpha, const Mat1 &A,
const Vec2 &B, double beta, Vec3 &C)
{
blas_int m = A.getRows(), n = B.getSize(), k = A.getCols(), l = C.getSize();
if (*transa == 'T')
{
m = A.getCols();
k = A.getRows();
}
assert(m == l);
assert(k == n);
blas_int lda = A.getLd(), ldb = B.getStride(), ldc = C.getStride();
dgemv(transa, &m, &n, &alpha, A.getData(), &lda,
B.getData(), &ldb, &beta, C.getData(), &ldc);
}
//! Symmetric matrix * vector multiplication
// c = alpha*A*b + beta*c,
// where alpha and beta are scalars, b and c are vectors and A is a
// m by m symmetric matrix.
template<class Mat1, class Vec2, class Vec3>
inline void
symv(const char *uplo, double alpha, const Mat1 &A,
const Vec2 &B, double beta, Vec3 &C)
{
assert(A.getRows() == A.getCols());
blas_int n = A.getRows();
assert(A.getRows() == B.getSize());
assert(A.getRows() == C.getSize());
blas_int lda = A.getLd(), ldb = B.getStride(), ldc = C.getStride();
dsymv(uplo, &n, &alpha, A.getData(), &lda,
B.getData(), &ldb, &beta, C.getData(), &ldc);
}
/* Level 3 */
//! General matrix multiplication
template<class Mat1, class Mat2, class Mat3>
inline void
gemm(const char *transa, const char *transb,
double alpha, const Mat1 &A, const Mat2 &B,
double beta, Mat3 &C)
{
blas_int m = A.getRows(), n = B.getCols(), k = A.getCols();
if (*transa == 'N')
{
if (*transb == 'N')
{
assert(A.getRows() == C.getRows());
assert(A.getCols() == B.getRows());
assert(B.getCols() == C.getCols());
}
else if (*transb == 'T')
{
assert(A.getRows() == C.getRows());
assert(A.getCols() == B.getCols());
assert(B.getRows() == C.getCols());
n = B.getRows();
}
}
else if (*transa == 'T')
{
m = A.getCols();
k = A.getRows();
if (*transb == 'N')
{
assert(A.getCols() == C.getRows());
assert(A.getRows() == B.getRows());
assert(B.getCols() == C.getCols());
}
else if (*transb == 'T')
{
assert(A.getCols() == C.getRows());
assert(A.getRows() == B.getCols());
assert(B.getRows() == C.getCols());
n = B.getRows();
}
}
blas_int lda = A.getLd(), ldb = B.getLd(), ldc = C.getLd();
dgemm(transa, transb, &m, &n, &k, &alpha, A.getData(), &lda,
B.getData(), &ldb, &beta, C.getData(), &ldc);
}
//! Symmetric matrix A * (poss. rectangular) matrix B multiplication
template<class Mat1, class Mat2, class Mat3>
inline void
symm(const char *side, const char *uplo,
double alpha, const Mat1 &A, const Mat2 &B,
double beta, Mat3 &C)
{
assert(A.getRows() == A.getCols());
assert(B.getRows() == C.getRows());
assert(B.getCols() == C.getCols());
if (*side == 'L' || *side == 'l')
assert(A.getCols() == B.getRows());
else if (*side == 'R' || *side == 'r')
assert(A.getRows() == B.getCols());
blas_int m = B.getRows(), n = B.getCols();
blas_int lda = A.getLd(), ldb = B.getLd(), ldc = C.getLd();
dsymm(side, uplo, &m, &n, &alpha, A.getData(), &lda,
B.getData(), &ldb, &beta, C.getData(), &ldc);
}
} // End of namespace
#endif