Dynare++ sylvester equation solver: various simplifications and improvements

In particular, the test binary now errors out in case of test failure.
time-shift
Sébastien Villemot 2019-01-25 15:27:20 +01:00
parent d15c998804
commit 3ce051d819
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
23 changed files with 510 additions and 686 deletions

View File

@ -5,12 +5,11 @@
#include "BlockDiagonal.hh" #include "BlockDiagonal.hh"
#include <iostream> #include <iostream>
#include <cstring>
#include <utility> #include <utility>
BlockDiagonal::BlockDiagonal(ConstVector d, int d_size) BlockDiagonal::BlockDiagonal(ConstVector d, int d_size)
: QuasiTriangular(std::move(d), d_size), : QuasiTriangular(std::move(d), d_size),
row_len(new int[d_size]), col_len(new int[d_size]) row_len(d_size), col_len(d_size)
{ {
for (int i = 0; i < d_size; i++) for (int i = 0; i < d_size; i++)
{ {
@ -21,7 +20,7 @@ BlockDiagonal::BlockDiagonal(ConstVector d, int d_size)
BlockDiagonal::BlockDiagonal(const QuasiTriangular &t) BlockDiagonal::BlockDiagonal(const QuasiTriangular &t)
: QuasiTriangular(t), : QuasiTriangular(t),
row_len(new int[t.numRows()]), col_len(new int[t.numRows()]) row_len(t.numRows()), col_len(t.numRows())
{ {
for (int i = 0; i < t.numRows(); i++) for (int i = 0; i < t.numRows(); i++)
{ {
@ -32,18 +31,8 @@ BlockDiagonal::BlockDiagonal(const QuasiTriangular &t)
BlockDiagonal::BlockDiagonal(int p, const BlockDiagonal &b) BlockDiagonal::BlockDiagonal(int p, const BlockDiagonal &b)
: QuasiTriangular(p, b), : QuasiTriangular(p, b),
row_len(new int[b.numRows()]), col_len(new int[b.numRows()]) row_len(b.row_len), col_len(b.col_len)
{ {
memcpy(row_len, b.row_len, b.numRows()*sizeof(int));
memcpy(col_len, b.col_len, b.numRows()*sizeof(int));
}
BlockDiagonal::BlockDiagonal(const BlockDiagonal &b)
: QuasiTriangular(b),
row_len(new int[b.numRows()]), col_len(new int[b.numRows()])
{
memcpy(row_len, b.row_len, b.numRows()*sizeof(int));
memcpy(col_len, b.col_len, b.numRows()*sizeof(int));
} }
/* put zeroes to right upper submatrix whose first column is defined /* put zeroes to right upper submatrix whose first column is defined
@ -135,9 +124,7 @@ BlockDiagonal::getNumZeros() const
{ {
int sum = 0; int sum = 0;
for (int i = 0; i < diagonal.getSize(); i++) for (int i = 0; i < diagonal.getSize(); i++)
{ sum += diagonal.getSize() - row_len[i];
sum += diagonal.getSize() - row_len[i];
}
return sum; return sum;
} }

View File

@ -6,28 +6,25 @@
#define BLOCK_DIAGONAL_H #define BLOCK_DIAGONAL_H
#include <memory> #include <memory>
#include <vector>
#include "QuasiTriangular.hh" #include "QuasiTriangular.hh"
class BlockDiagonal : public QuasiTriangular class BlockDiagonal : public QuasiTriangular
{ {
int *const row_len; std::vector<int> row_len, col_len;
int *const col_len;
public: public:
BlockDiagonal(ConstVector d, int d_size); BlockDiagonal(ConstVector d, int d_size);
BlockDiagonal(int p, const BlockDiagonal &b); BlockDiagonal(int p, const BlockDiagonal &b);
BlockDiagonal(const BlockDiagonal &b); BlockDiagonal(const BlockDiagonal &b) = default;
BlockDiagonal(const QuasiTriangular &t); BlockDiagonal(const QuasiTriangular &t);
BlockDiagonal & BlockDiagonal &operator=(const QuasiTriangular &t)
operator=(const QuasiTriangular &t)
{ {
GeneralMatrix::operator=(t); return *this; GeneralMatrix::operator=(t);
} return *this;
BlockDiagonal &operator=(const BlockDiagonal &b);
~BlockDiagonal() override
{
delete [] row_len; delete [] col_len;
} }
BlockDiagonal &operator=(const BlockDiagonal &b) = default;
~BlockDiagonal() override = default;
void setZeroBlockEdge(diag_iter edge); void setZeroBlockEdge(diag_iter edge);
int getNumZeros() const; int getNumZeros() const;
int getNumBlocks() const; int getNumBlocks() const;

View File

@ -10,7 +10,6 @@
#include <iostream> #include <iostream>
#include <iomanip> #include <iomanip>
#include <cstring>
#include <cstdlib> #include <cstdlib>
#include <cmath> #include <cmath>
#include <limits> #include <limits>
@ -28,7 +27,7 @@ GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &m)
copy(m); copy(m);
} }
GeneralMatrix::GeneralMatrix(const GeneralMatrix &m, const char *dummy) GeneralMatrix::GeneralMatrix(const GeneralMatrix &m, const std::string &dummy)
: data(m.rows*m.cols), rows(m.cols), cols(m.rows), ld(m.cols) : data(m.rows*m.cols), rows(m.cols), cols(m.rows), ld(m.cols)
{ {
for (int i = 0; i < m.rows; i++) for (int i = 0; i < m.rows; i++)
@ -36,7 +35,7 @@ GeneralMatrix::GeneralMatrix(const GeneralMatrix &m, const char *dummy)
get(j, i) = m.get(i, j); get(j, i) = m.get(i, j);
} }
GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &m, const char *dummy) GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &m, const std::string &dummy)
: data(m.rows*m.cols), rows(m.cols), cols(m.rows), ld(m.cols) : data(m.rows*m.cols), rows(m.cols), cols(m.rows), ld(m.cols)
{ {
for (int i = 0; i < m.rows; i++) for (int i = 0; i < m.rows; i++)
@ -61,20 +60,20 @@ GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &a, const ConstGeneralMatr
gemm("N", a, "N", b, 1.0, 0.0); gemm("N", a, "N", b, 1.0, 0.0);
} }
GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b, const char *dum) GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b, const std::string &dum)
: data(a.rows*b.rows), rows(a.rows), cols(b.rows), ld(a.rows) : data(a.rows*b.rows), rows(a.rows), cols(b.rows), ld(a.rows)
{ {
gemm("N", a, "T", b, 1.0, 0.0); gemm("N", a, "T", b, 1.0, 0.0);
} }
GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &a, const char *dum, const ConstGeneralMatrix &b) GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &a, const std::string &dum, const ConstGeneralMatrix &b)
: data(a.cols*b.cols), rows(a.cols), cols(b.cols), ld(a.cols) : data(a.cols*b.cols), rows(a.cols), cols(b.cols), ld(a.cols)
{ {
gemm("T", a, "N", b, 1.0, 0.0); gemm("T", a, "N", b, 1.0, 0.0);
} }
GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &a, const char *dum1, GeneralMatrix::GeneralMatrix(const ConstGeneralMatrix &a, const std::string &dum1,
const ConstGeneralMatrix &b, const char *dum2) const ConstGeneralMatrix &b, const std::string &dum2)
: data(a.cols*b.rows), rows(a.cols), cols(b.rows), ld(a.cols) : data(a.cols*b.rows), rows(a.cols), cols(b.rows), ld(a.cols)
{ {
gemm("T", a, "T", b, 1.0, 0.0); gemm("T", a, "T", b, 1.0, 0.0);
@ -140,21 +139,21 @@ GeneralMatrix::multAndAdd(const ConstGeneralMatrix &a, const ConstGeneralMatrix
void void
GeneralMatrix::multAndAdd(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b, GeneralMatrix::multAndAdd(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b,
const char *dum, double mult) const std::string &dum, double mult)
{ {
gemm("N", a, "T", b, mult, 1.0); gemm("N", a, "T", b, mult, 1.0);
} }
void void
GeneralMatrix::multAndAdd(const ConstGeneralMatrix &a, const char *dum, GeneralMatrix::multAndAdd(const ConstGeneralMatrix &a, const std::string &dum,
const ConstGeneralMatrix &b, double mult) const ConstGeneralMatrix &b, double mult)
{ {
gemm("T", a, "N", b, mult, 1.0); gemm("T", a, "N", b, mult, 1.0);
} }
void void
GeneralMatrix::multAndAdd(const ConstGeneralMatrix &a, const char *dum1, GeneralMatrix::multAndAdd(const ConstGeneralMatrix &a, const std::string &dum1,
const ConstGeneralMatrix &b, const char *dum2, double mult) const ConstGeneralMatrix &b, const std::string &dum2, double mult)
{ {
gemm("T", a, "T", b, mult, 1.0); gemm("T", a, "T", b, mult, 1.0);
} }
@ -269,7 +268,7 @@ GeneralMatrix::add(double a, const ConstGeneralMatrix &m)
} }
void void
GeneralMatrix::add(double a, const ConstGeneralMatrix &m, const char *dum) GeneralMatrix::add(double a, const ConstGeneralMatrix &m, const std::string &dum)
{ {
if (m.numRows() != cols || m.numCols() != rows) if (m.numRows() != cols || m.numCols() != rows)
throw SYLV_MES_EXCEPTION("Matrix has different size in GeneralMatrix::add."); throw SYLV_MES_EXCEPTION("Matrix has different size in GeneralMatrix::add.");
@ -288,20 +287,20 @@ GeneralMatrix::copy(const ConstGeneralMatrix &m, int ioff, int joff)
} }
void void
GeneralMatrix::gemm(const char *transa, const ConstGeneralMatrix &a, GeneralMatrix::gemm(const std::string &transa, const ConstGeneralMatrix &a,
const char *transb, const ConstGeneralMatrix &b, const std::string &transb, const ConstGeneralMatrix &b,
double alpha, double beta) double alpha, double beta)
{ {
int opa_rows = a.numRows(); int opa_rows = a.numRows();
int opa_cols = a.numCols(); int opa_cols = a.numCols();
if (!strcmp(transa, "T")) if (transa == "T")
{ {
opa_rows = a.numCols(); opa_rows = a.numCols();
opa_cols = a.numRows(); opa_cols = a.numRows();
} }
int opb_rows = b.numRows(); int opb_rows = b.numRows();
int opb_cols = b.numCols(); int opb_cols = b.numCols();
if (!strcmp(transb, "T")) if (transb == "T")
{ {
opb_rows = b.numCols(); opb_rows = b.numCols();
opb_cols = b.numRows(); opb_cols = b.numRows();
@ -322,7 +321,7 @@ GeneralMatrix::gemm(const char *transa, const ConstGeneralMatrix &a,
blas_int ldc = ld; blas_int ldc = ld;
if (lda > 0 && ldb > 0 && ldc > 0) if (lda > 0 && ldb > 0 && ldc > 0)
{ {
dgemm(transa, transb, &m, &n, &k, &alpha, a.data.base(), &lda, dgemm(transa.c_str(), transb.c_str(), &m, &n, &k, &alpha, a.data.base(), &lda,
b.data.base(), &ldb, &beta, data.base(), &ldc); b.data.base(), &ldb, &beta, data.base(), &ldc);
} }
else if (numRows()*numCols() > 0) else if (numRows()*numCols() > 0)
@ -335,7 +334,7 @@ GeneralMatrix::gemm(const char *transa, const ConstGeneralMatrix &a,
} }
void void
GeneralMatrix::gemm_partial_left(const char *trans, const ConstGeneralMatrix &m, GeneralMatrix::gemm_partial_left(const std::string &trans, const ConstGeneralMatrix &m,
double alpha, double beta) double alpha, double beta)
{ {
int icol; int icol;
@ -354,7 +353,7 @@ GeneralMatrix::gemm_partial_left(const char *trans, const ConstGeneralMatrix &m,
} }
void void
GeneralMatrix::gemm_partial_right(const char *trans, const ConstGeneralMatrix &m, GeneralMatrix::gemm_partial_right(const std::string &trans, const ConstGeneralMatrix &m,
double alpha, double beta) double alpha, double beta)
{ {
int irow; int irow;
@ -469,7 +468,7 @@ ConstGeneralMatrix::multVecTrans(double a, Vector &x, double b,
/* m = inv(this)*m */ /* m = inv(this)*m */
void void
ConstGeneralMatrix::multInvLeft(const char *trans, int mrows, int mcols, int mld, double *d) const ConstGeneralMatrix::multInvLeft(const std::string &trans, int mrows, int mcols, int mld, double *d) const
{ {
if (rows != cols) if (rows != cols)
throw SYLV_MES_EXCEPTION("The matrix is not square for inversion."); throw SYLV_MES_EXCEPTION("The matrix is not square for inversion.");
@ -483,7 +482,7 @@ ConstGeneralMatrix::multInvLeft(const char *trans, int mrows, int mcols, int mld
lapack_int info; lapack_int info;
lapack_int rows2 = rows, mcols2 = mcols, mld2 = mld, lda = inv.ld; lapack_int rows2 = rows, mcols2 = mcols, mld2 = mld, lda = inv.ld;
dgetrf(&rows2, &rows2, inv.getData().base(), &lda, ipiv.data(), &info); dgetrf(&rows2, &rows2, inv.getData().base(), &lda, ipiv.data(), &info);
dgetrs(trans, &rows2, &mcols2, inv.base(), &lda, ipiv.data(), d, dgetrs(trans.c_str(), &rows2, &mcols2, inv.base(), &lda, ipiv.data(), d,
&mld2, &info); &mld2, &info);
} }
} }

View File

@ -11,6 +11,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <string>
class GeneralMatrix; class GeneralMatrix;
@ -120,7 +121,7 @@ public:
virtual void print() const; virtual void print() const;
protected: protected:
void multInvLeft(const char *trans, int mrows, int mcols, int mld, double *d) const; void multInvLeft(const std::string &trans, int mrows, int mcols, int mld, double *d) const;
}; };
class GeneralMatrix class GeneralMatrix
@ -152,19 +153,19 @@ public:
explicit GeneralMatrix(const ConstGeneralMatrix &m); explicit GeneralMatrix(const ConstGeneralMatrix &m);
GeneralMatrix(GeneralMatrix &&m) = default; GeneralMatrix(GeneralMatrix &&m) = default;
GeneralMatrix(const GeneralMatrix &m, const char *dummy); // transpose GeneralMatrix(const GeneralMatrix &m, const std::string &dummy); // transpose
GeneralMatrix(const ConstGeneralMatrix &m, const char *dummy); // transpose GeneralMatrix(const ConstGeneralMatrix &m, const std::string &dummy); // transpose
GeneralMatrix(const GeneralMatrix &m, int i, int j, int nrows, int ncols); GeneralMatrix(const GeneralMatrix &m, int i, int j, int nrows, int ncols);
GeneralMatrix(GeneralMatrix &m, int i, int j, int nrows, int ncols); GeneralMatrix(GeneralMatrix &m, int i, int j, int nrows, int ncols);
/* this = a*b */ /* this = a*b */
GeneralMatrix(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b); GeneralMatrix(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b);
/* this = a*b' */ /* this = a*b' */
GeneralMatrix(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b, const char *dum); GeneralMatrix(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b, const std::string &dum);
/* this = a'*b */ /* this = a'*b */
GeneralMatrix(const ConstGeneralMatrix &a, const char *dum, const ConstGeneralMatrix &b); GeneralMatrix(const ConstGeneralMatrix &a, const std::string &dum, const ConstGeneralMatrix &b);
/* this = a'*b */ /* this = a'*b */
GeneralMatrix(const ConstGeneralMatrix &a, const char *dum1, GeneralMatrix(const ConstGeneralMatrix &a, const std::string &dum1,
const ConstGeneralMatrix &b, const char *dum2); const ConstGeneralMatrix &b, const std::string &dum2);
virtual ~GeneralMatrix() = default; virtual ~GeneralMatrix() = default;
GeneralMatrix &operator=(const GeneralMatrix &m) = default; GeneralMatrix &operator=(const GeneralMatrix &m) = default;
@ -260,30 +261,30 @@ public:
/* this = this + scalar*a*b' */ /* this = this + scalar*a*b' */
void multAndAdd(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b, void multAndAdd(const ConstGeneralMatrix &a, const ConstGeneralMatrix &b,
const char *dum, double mult = 1.0); const std::string &dum, double mult = 1.0);
void void
multAndAdd(const GeneralMatrix &a, const GeneralMatrix &b, multAndAdd(const GeneralMatrix &a, const GeneralMatrix &b,
const char *dum, double mult = 1.0) const std::string &dum, double mult = 1.0)
{ {
multAndAdd(ConstGeneralMatrix(a), ConstGeneralMatrix(b), dum, mult); multAndAdd(ConstGeneralMatrix(a), ConstGeneralMatrix(b), dum, mult);
} }
/* this = this + scalar*a'*b */ /* this = this + scalar*a'*b */
void multAndAdd(const ConstGeneralMatrix &a, const char *dum, const ConstGeneralMatrix &b, void multAndAdd(const ConstGeneralMatrix &a, const std::string &dum, const ConstGeneralMatrix &b,
double mult = 1.0); double mult = 1.0);
void void
multAndAdd(const GeneralMatrix &a, const char *dum, const GeneralMatrix &b, multAndAdd(const GeneralMatrix &a, const std::string &dum, const GeneralMatrix &b,
double mult = 1.0) double mult = 1.0)
{ {
multAndAdd(ConstGeneralMatrix(a), dum, ConstGeneralMatrix(b), mult); multAndAdd(ConstGeneralMatrix(a), dum, ConstGeneralMatrix(b), mult);
} }
/* this = this + scalar*a'*b' */ /* this = this + scalar*a'*b' */
void multAndAdd(const ConstGeneralMatrix &a, const char *dum1, void multAndAdd(const ConstGeneralMatrix &a, const std::string &dum1,
const ConstGeneralMatrix &b, const char *dum2, double mult = 1.0); const ConstGeneralMatrix &b, const std::string &dum2, double mult = 1.0);
void void
multAndAdd(const GeneralMatrix &a, const char *dum1, multAndAdd(const GeneralMatrix &a, const std::string &dum1,
const GeneralMatrix &b, const char *dum2, double mult = 1.0) const GeneralMatrix &b, const std::string &dum2, double mult = 1.0)
{ {
multAndAdd(ConstGeneralMatrix(a), dum1, ConstGeneralMatrix(b), dum2, mult); multAndAdd(ConstGeneralMatrix(a), dum1, ConstGeneralMatrix(b), dum2, mult);
} }
@ -394,9 +395,9 @@ public:
} }
/* this = this + scalar*m' */ /* this = this + scalar*m' */
void add(double a, const ConstGeneralMatrix &m, const char *dum); void add(double a, const ConstGeneralMatrix &m, const std::string &dum);
void void
add(double a, const GeneralMatrix &m, const char *dum) add(double a, const GeneralMatrix &m, const std::string &dum)
{ {
add(a, ConstGeneralMatrix(m), dum); add(a, ConstGeneralMatrix(m), dum);
} }
@ -426,12 +427,12 @@ private:
copy(ConstGeneralMatrix(m), ioff, joff); copy(ConstGeneralMatrix(m), ioff, joff);
} }
void gemm(const char *transa, const ConstGeneralMatrix &a, void gemm(const std::string &transa, const ConstGeneralMatrix &a,
const char *transb, const ConstGeneralMatrix &b, const std::string &transb, const ConstGeneralMatrix &b,
double alpha, double beta); double alpha, double beta);
void void
gemm(const char *transa, const GeneralMatrix &a, gemm(const std::string &transa, const GeneralMatrix &a,
const char *transb, const GeneralMatrix &b, const std::string &transb, const GeneralMatrix &b,
double alpha, double beta) double alpha, double beta)
{ {
gemm(transa, ConstGeneralMatrix(a), transb, ConstGeneralMatrix(b), gemm(transa, ConstGeneralMatrix(a), transb, ConstGeneralMatrix(b),
@ -439,20 +440,20 @@ private:
} }
/* this = this * op(m) (without whole copy of this) */ /* this = this * op(m) (without whole copy of this) */
void gemm_partial_right(const char *trans, const ConstGeneralMatrix &m, void gemm_partial_right(const std::string &trans, const ConstGeneralMatrix &m,
double alpha, double beta); double alpha, double beta);
void void
gemm_partial_right(const char *trans, const GeneralMatrix &m, gemm_partial_right(const std::string &trans, const GeneralMatrix &m,
double alpha, double beta) double alpha, double beta)
{ {
gemm_partial_right(trans, ConstGeneralMatrix(m), alpha, beta); gemm_partial_right(trans, ConstGeneralMatrix(m), alpha, beta);
} }
/* this = op(m) *this (without whole copy of this) */ /* this = op(m) *this (without whole copy of this) */
void gemm_partial_left(const char *trans, const ConstGeneralMatrix &m, void gemm_partial_left(const std::string &trans, const ConstGeneralMatrix &m,
double alpha, double beta); double alpha, double beta);
void void
gemm_partial_left(const char *trans, const GeneralMatrix &m, gemm_partial_left(const std::string &trans, const GeneralMatrix &m,
double alpha, double beta) double alpha, double beta)
{ {
gemm_partial_left(trans, ConstGeneralMatrix(m), alpha, beta); gemm_partial_left(trans, ConstGeneralMatrix(m), alpha, beta);

View File

@ -71,7 +71,7 @@ GeneralSylvester::init()
cdecomp = std::make_unique<SimilarityDecomp>(c.getData(), c.numRows(), *(pars.bs_norm)); cdecomp = std::make_unique<SimilarityDecomp>(c.getData(), c.numRows(), *(pars.bs_norm));
cdecomp->check(pars, c); cdecomp->check(pars, c);
cdecomp->infoToPars(pars); cdecomp->infoToPars(pars);
if (*(pars.method) == SylvParams::recurse) if (*(pars.method) == SylvParams::solve_method::recurse)
sylv = std::make_unique<TriangularSylvester>(*bdecomp, *cdecomp); sylv = std::make_unique<TriangularSylvester>(*bdecomp, *cdecomp);
else else
sylv = std::make_unique<IterativeSylvester>(*bdecomp, *cdecomp); sylv = std::make_unique<IterativeSylvester>(*bdecomp, *cdecomp);

View File

@ -27,9 +27,7 @@ KronUtils::multAtLevel(int level, const QuasiTriangular &t,
t.multVec(x, b); t.multVec(x, b);
} }
else // 0 < level == depth else // 0 < level == depth
{ t.multKron(x);
t.multKron(x);
}
} }
void void
@ -55,9 +53,7 @@ KronUtils::multAtLevelTrans(int level, const QuasiTriangular &t,
t.multVecTrans(x, b); t.multVecTrans(x, b);
} }
else // 0 < level == depth else // 0 < level == depth
{ t.multKronTrans(x);
t.multKronTrans(x);
}
} }
void void

View File

@ -37,18 +37,17 @@ KronVector::KronVector(KronVector &v, int i)
} }
KronVector::KronVector(const ConstKronVector &v) KronVector::KronVector(const ConstKronVector &v)
: Vector(v.length()), m(v.getM()), n(v.getN()), depth(v.getDepth()) : Vector(v), m(v.m), n(v.n), depth(v.depth)
{ {
Vector::operator=(v);
} }
KronVector & KronVector &
KronVector::operator=(const ConstKronVector &v) KronVector::operator=(const ConstKronVector &v)
{ {
Vector::operator=(v); Vector::operator=(v);
m = v.getM(); m = v.m;
n = v.getN(); n = v.n;
depth = v.getDepth(); depth = v.depth;
return *this; return *this;
} }
@ -62,13 +61,7 @@ KronVector::operator=(const Vector &v)
} }
ConstKronVector::ConstKronVector(const KronVector &v) ConstKronVector::ConstKronVector(const KronVector &v)
: ConstVector(v), m(v.getM()), n(v.getN()), depth(v.getDepth()) : ConstVector(v), m(v.m), n(v.n), depth(v.depth)
{
}
ConstKronVector::ConstKronVector(const ConstKronVector &v)
: ConstVector(power(v.getM(), v.getDepth())*v.getN()), m(v.getM()), n(v.getN()),
depth(v.getDepth())
{ {
} }

View File

@ -11,18 +11,22 @@ class ConstKronVector;
class KronVector : public Vector class KronVector : public Vector
{ {
friend class ConstKronVector;
protected: protected:
int m{0}; int m{0};
int n{0}; int n{0};
int depth{0}; int depth{0};
public: public:
KronVector() = default; KronVector() = default;
KronVector(const KronVector &v) = default;
KronVector(KronVector &&v) = default;
KronVector(int mm, int nn, int dp); // new instance KronVector(int mm, int nn, int dp); // new instance
KronVector(Vector &v, int mm, int nn, int dp); // conversion KronVector(Vector &v, int mm, int nn, int dp); // conversion
KronVector(KronVector &, int i); // picks i-th subvector KronVector(KronVector &, int i); // picks i-th subvector
// We don't want implict conversion from ConstKronVector, since it's expensive // We don't want implict conversion from ConstKronVector, since it's expensive
explicit KronVector(const ConstKronVector &v); // new instance and copy explicit KronVector(const ConstKronVector &v); // new instance and copy
KronVector &operator=(const KronVector &v) = default; KronVector &operator=(const KronVector &v) = default;
KronVector &operator=(KronVector &&v) = default;
KronVector &operator=(const ConstKronVector &v); KronVector &operator=(const ConstKronVector &v);
KronVector &operator=(const Vector &v); KronVector &operator=(const Vector &v);
int int
@ -44,6 +48,7 @@ public:
class ConstKronVector : public ConstVector class ConstKronVector : public ConstVector
{ {
friend class KronVector;
protected: protected:
int m; int m;
int n; int n;
@ -51,11 +56,14 @@ protected:
public: public:
// Implicit conversion from KronVector is ok, since it's cheap // Implicit conversion from KronVector is ok, since it's cheap
ConstKronVector(const KronVector &v); ConstKronVector(const KronVector &v);
ConstKronVector(const ConstKronVector &v); ConstKronVector(const ConstKronVector &v) = default;
ConstKronVector(ConstKronVector &&v) = default;
ConstKronVector(const Vector &v, int mm, int nn, int dp); ConstKronVector(const Vector &v, int mm, int nn, int dp);
ConstKronVector(ConstVector v, int mm, int nn, int dp); ConstKronVector(ConstVector v, int mm, int nn, int dp);
ConstKronVector(const KronVector &v, int i); ConstKronVector(const KronVector &v, int i);
ConstKronVector(const ConstKronVector &v, int i); ConstKronVector(const ConstKronVector &v, int i);
ConstKronVector &operator=(const ConstKronVector &v) = delete;
ConstKronVector &operator=(ConstKronVector &&v) = delete;
int int
getM() const getM() const
{ {

View File

@ -12,8 +12,6 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
using namespace std;
double double
DiagonalBlock::getDeterminant() const DiagonalBlock::getDeterminant() const
{ {
@ -30,9 +28,9 @@ double
DiagonalBlock::getSize() const DiagonalBlock::getSize() const
{ {
if (real) if (real)
return abs(*alpha); return std::abs(*alpha);
else else
return sqrt(getDeterminant()); return std::sqrt(getDeterminant());
} }
// this function makes Diagonal inconsistent, it should only be used // this function makes Diagonal inconsistent, it should only be used
@ -119,28 +117,18 @@ Diagonal::Diagonal(double *data, const Diagonal &d)
} }
} }
void
Diagonal::copy(const Diagonal &d)
{
num_all = d.num_all;
num_real = d.num_real;
blocks = d.blocks;
}
int int
Diagonal::getNumComplex(const double *data, int d_size) Diagonal::getNumComplex(const double *data, int d_size)
{ {
int num_complex = 0; int num_complex = 0;
int in = 1; int in = 1;
for (int i = 0; i < d_size-1; i++, in = in + d_size + 1) for (int i = 0; i < d_size-1; i++, in = in + d_size + 1)
{ if (!isZero(data[in]))
if (!isZero(data[in])) {
{ num_complex++;
num_complex++; if (in < d_size - 2 && !isZero(data[in + d_size +1]))
if (in < d_size - 2 && !isZero(data[in + d_size +1])) throw SYLV_MES_EXCEPTION("Matrix is not quasi-triangular");
throw SYLV_MES_EXCEPTION("Matrix is not quasi-triangular"); }
}
}
return num_complex; return num_complex;
} }
@ -174,7 +162,7 @@ Diagonal::getEigenValues(Vector &eig) const
int d_size = getSize(); int d_size = getSize();
if (eig.length() != 2*d_size) if (eig.length() != 2*d_size)
{ {
ostringstream mes; std::ostringstream mes;
mes << "Wrong length of vector for eigenvalues len=" << eig.length() mes << "Wrong length of vector for eigenvalues len=" << eig.length()
<< ", should be=" << 2*d_size << '.' << std::endl; << ", should be=" << 2*d_size << '.' << std::endl;
throw SYLV_MES_EXCEPTION(mes.str()); throw SYLV_MES_EXCEPTION(mes.str());
@ -187,7 +175,7 @@ Diagonal::getEigenValues(Vector &eig) const
eig[2*ind+1] = 0.0; eig[2*ind+1] = 0.0;
else else
{ {
double beta = sqrt(b.getSBeta()); double beta = std::sqrt(b.getSBeta());
eig[2*ind+1] = beta; eig[2*ind+1] = beta;
eig[2*ind+2] = eig[2*ind]; eig[2*ind+2] = eig[2*ind];
eig[2*ind+3] = -beta; eig[2*ind+3] = -beta;
@ -204,28 +192,28 @@ Diagonal::swapLogically(diag_iter it)
diag_iter itp = it; diag_iter itp = it;
++itp; ++itp;
if ((*it).isReal() && !(*itp).isReal()) if (it->isReal() && !itp->isReal())
{ {
// first is real, second is complex // first is real, second is complex
double *d1 = (*it).alpha.a1; double *d1 = it->alpha.a1;
double *d2 = (*itp).alpha.a1; double *d2 = itp->alpha.a1;
double *d3 = (*itp).alpha.a2; double *d3 = itp->alpha.a2;
// swap // swap
DiagonalBlock new_it((*it).jbar, d1, d2); DiagonalBlock new_it(it->jbar, d1, d2);
*it = new_it; *it = new_it;
DiagonalBlock new_itp((*itp).jbar+1, d3); DiagonalBlock new_itp(itp->jbar+1, d3);
*itp = new_itp; *itp = new_itp;
} }
else if (!(*it).isReal() && (*itp).isReal()) else if (!it->isReal() && itp->isReal())
{ {
// first is complex, second is real // first is complex, second is real
double *d1 = (*it).alpha.a1; double *d1 = it->alpha.a1;
double *d2 = (*it).alpha.a2; double *d2 = it->alpha.a2;
double *d3 = (*itp).alpha.a1; double *d3 = itp->alpha.a1;
// swap // swap
DiagonalBlock new_it((*it).jbar, d1); DiagonalBlock new_it(it->jbar, d1);
*it = new_it; *it = new_it;
DiagonalBlock new_itp((*itp).jbar-1, d2, d3); DiagonalBlock new_itp(itp->jbar-1, d2, d3);
*itp = new_itp; *itp = new_itp;
} }
} }
@ -233,17 +221,17 @@ Diagonal::swapLogically(diag_iter it)
void void
Diagonal::checkConsistency(diag_iter it) Diagonal::checkConsistency(diag_iter it)
{ {
if (!(*it).isReal() && isZero((*it).getBeta2())) if (!it->isReal() && isZero(it->getBeta2()))
{ {
(*it).getBeta2() = 0.0; // put exact zero it->getBeta2() = 0.0; // put exact zero
int jbar = (*it).getIndex(); int jbar = it->getIndex();
double *d2 = (*it).alpha.a2; double *d2 = it->alpha.a2;
(*it).alpha.a2 = (*it).alpha.a1; it->alpha.a2 = it->alpha.a1;
(*it).real = true; it->real = true;
(*it).beta1 = nullptr; it->beta1 = nullptr;
(*it).beta2 = nullptr; it->beta2 = nullptr;
DiagonalBlock b(jbar+1, d2); DiagonalBlock b(jbar+1, d2);
blocks.insert((++it).iter(), b); blocks.insert(++it, b);
num_real += 2; num_real += 2;
num_all++; num_all++;
} }
@ -257,7 +245,7 @@ Diagonal::getAverageSize(diag_iter start, diag_iter end)
for (diag_iter run = start; run != end; ++run) for (diag_iter run = start; run != end; ++run)
{ {
num++; num++;
res += (*run).getSize(); res += run->getSize();
} }
if (num > 0) if (num > 0)
res = res/num; res = res/num;
@ -271,7 +259,7 @@ Diagonal::findClosestBlock(diag_iter start, diag_iter end, double a)
double minim = 1.0e100; double minim = 1.0e100;
for (diag_iter run = start; run != end; ++run) for (diag_iter run = start; run != end; ++run)
{ {
double dist = abs(a - (*run).getSize()); double dist = std::abs(a - run->getSize());
if (dist < minim) if (dist < minim)
{ {
minim = dist; minim = dist;
@ -288,7 +276,7 @@ Diagonal::findNextLargerBlock(diag_iter start, diag_iter end, double a)
double minim = 1.0e100; double minim = 1.0e100;
for (diag_iter run = start; run != end; ++run) for (diag_iter run = start; run != end; ++run)
{ {
double dist = (*run).getSize() - a; double dist = run->getSize() - a;
if ((0 <= dist) && (dist < minim)) if ((0 <= dist) && (dist < minim))
{ {
minim = dist; minim = dist;
@ -318,7 +306,7 @@ Diagonal::print() const
bool bool
Diagonal::isZero(double p) Diagonal::isZero(double p)
{ {
return (abs(p) < EPS); return (std::abs(p) < EPS);
} }
QuasiTriangular::const_col_iter QuasiTriangular::const_col_iter
@ -453,21 +441,15 @@ QuasiTriangular::QuasiTriangular(const SchurDecompZero &decomp)
zv.zeros(); zv.zeros();
// fill right upper part with decomp.getRU() // fill right upper part with decomp.getRU()
for (int i = 0; i < decomp.getRU().numRows(); i++) for (int i = 0; i < decomp.getRU().numRows(); i++)
{ for (int j = 0; j < decomp.getRU().numCols(); j++)
for (int j = 0; j < decomp.getRU().numCols(); j++) getData()[(j+decomp.getZeroCols())*decomp.getDim()+i] = decomp.getRU().get(i, j);
{
getData()[(j+decomp.getZeroCols())*decomp.getDim()+i] = decomp.getRU().get(i, j);
}
}
// fill right lower part with decomp.getT() // fill right lower part with decomp.getT()
for (int i = 0; i < decomp.getT().numRows(); i++) for (int i = 0; i < decomp.getT().numRows(); i++)
{ for (int j = 0; j < decomp.getT().numCols(); j++)
for (int j = 0; j < decomp.getT().numCols(); j++) getData()[(j+decomp.getZeroCols())*decomp.getDim()+decomp.getZeroCols()+i]
{ = decomp.getT().get(i, j);
getData()[(j+decomp.getZeroCols())*decomp.getDim()+decomp.getZeroCols()+i]
= decomp.getT().get(i, j);
}
}
// construct diagonal // construct diagonal
diagonal = Diagonal{getData().base(), decomp.getDim()}; diagonal = Diagonal{getData().base(), decomp.getDim()};
} }
@ -487,26 +469,22 @@ QuasiTriangular::setMatrixViaIter(double r, const QuasiTriangular &t)
const_diag_iter dir = t.diag_begin(); const_diag_iter dir = t.diag_begin();
for (; dil != diag_end(); ++dil, ++dir) for (; dil != diag_end(); ++dil, ++dir)
{ {
(*dil).getAlpha() = rr*(*(*dir).getAlpha()); dil->getAlpha() = rr*(*dir->getAlpha());
if (!(*dil).isReal()) if (!dil->isReal())
{ {
(*dil).getBeta1() = rr*(*dir).getBeta1(); dil->getBeta1() = rr*dir->getBeta1();
(*dil).getBeta2() = rr*(*dir).getBeta2(); dil->getBeta2() = rr*dir->getBeta2();
} }
col_iter cil = col_begin(*dil); col_iter cil = col_begin(*dil);
const_col_iter cir = t.col_begin(*dir); const_col_iter cir = t.col_begin(*dir);
for (; cil != col_end(*dil); ++cil, ++cir) for (; cil != col_end(*dil); ++cil, ++cir)
{ if (dil->isReal())
if ((*dil).isReal()) *cil = rr*(*cir);
{ else
*cil = rr*(*cir); {
} cil.a() = rr*cir.a();
else cil.b() = rr*cir.b();
{ }
cil.a() = rr*cir.a();
cil.b() = rr*cir.b();
}
}
} }
} }
@ -524,26 +502,22 @@ QuasiTriangular::addMatrixViaIter(double r, const QuasiTriangular &t)
const_diag_iter dir = t.diag_begin(); const_diag_iter dir = t.diag_begin();
for (; dil != diag_end(); ++dil, ++dir) for (; dil != diag_end(); ++dil, ++dir)
{ {
(*dil).getAlpha() = (*(*dil).getAlpha()) + rr*(*(*dir).getAlpha()); dil->getAlpha() = (*dil->getAlpha()) + rr*(*dir->getAlpha());
if (!(*dil).isReal()) if (!dil->isReal())
{ {
(*dil).getBeta1() += rr*(*dir).getBeta1(); dil->getBeta1() += rr*dir->getBeta1();
(*dil).getBeta2() += rr*(*dir).getBeta2(); dil->getBeta2() += rr*dir->getBeta2();
} }
col_iter cil = col_begin(*dil); col_iter cil = col_begin(*dil);
const_col_iter cir = t.col_begin(*dir); const_col_iter cir = t.col_begin(*dir);
for (; cil != col_end(*dil); ++cil, ++cir) for (; cil != col_end(*dil); ++cil, ++cir)
{ if (dil->isReal())
if ((*dil).isReal()) *cil += rr*(*cir);
{ else
*cil += rr*(*cir); {
} cil.a() += rr*cir.a();
else cil.b() += rr*cir.b();
{ }
cil.a() += rr*cir.a();
cil.b() += rr*cir.b();
}
}
} }
} }
@ -551,9 +525,7 @@ void
QuasiTriangular::addUnit() QuasiTriangular::addUnit()
{ {
for (diag_iter di = diag_begin(); di != diag_end(); ++di) for (diag_iter di = diag_begin(); di != diag_end(); ++di)
{ di->getAlpha() = *(di->getAlpha()) + 1.0;
(*di).getAlpha() = *((*di).getAlpha()) + 1.0;
}
} }
void void
@ -577,15 +549,13 @@ QuasiTriangular::solvePre(Vector &x, double &eig_min)
for (diag_iter di = diag_begin(); di != diag_end(); ++di) for (diag_iter di = diag_begin(); di != diag_end(); ++di)
{ {
double eig_size; double eig_size;
if (!(*di).isReal()) if (!di->isReal())
{ {
eig_size = (*di).getDeterminant(); eig_size = di->getDeterminant();
eliminateLeft((*di).getIndex()+1, (*di).getIndex(), x); eliminateLeft(di->getIndex()+1, di->getIndex(), x);
} }
else else
{ eig_size = *di->getAlpha()*(*di->getAlpha());
eig_size = *(*di).getAlpha()*(*(*di).getAlpha());
}
if (eig_size < eig_min) if (eig_size < eig_min)
eig_min = eig_size; eig_min = eig_size;
} }
@ -603,15 +573,13 @@ QuasiTriangular::solvePreTrans(Vector &x, double &eig_min)
for (diag_iter di = diag_begin(); di != diag_end(); ++di) for (diag_iter di = diag_begin(); di != diag_end(); ++di)
{ {
double eig_size; double eig_size;
if (!(*di).isReal()) if (!di->isReal())
{ {
eig_size = (*di).getDeterminant(); eig_size = di->getDeterminant();
eliminateRight((*di).getIndex()+1, (*di).getIndex(), x); eliminateRight(di->getIndex()+1, di->getIndex(), x);
} }
else else
{ eig_size = *di->getAlpha()*(*di->getAlpha());
eig_size = *(*di).getAlpha()*(*(*di).getAlpha());
}
if (eig_size < eig_min) if (eig_size < eig_min)
eig_min = eig_size; eig_min = eig_size;
} }
@ -632,13 +600,11 @@ QuasiTriangular::multVec(Vector &x, const ConstVector &b) const
blas_int incx = x.skip(); blas_int incx = x.skip();
dtrmv("U", "N", "N", &nn, getData().base(), &lda, x.base(), &incx); dtrmv("U", "N", "N", &nn, getData().base(), &lda, x.base(), &incx);
for (const_diag_iter di = diag_begin(); di != diag_end(); ++di) for (const_diag_iter di = diag_begin(); di != diag_end(); ++di)
{ if (!di->isReal())
if (!(*di).isReal()) {
{ int jbar = di->getIndex();
int jbar = (*di).getIndex(); x[jbar+1] += di->getBeta2()*(b[jbar]);
x[jbar+1] += (*di).getBeta2()*(b[jbar]); }
}
}
} }
void void
@ -650,13 +616,11 @@ QuasiTriangular::multVecTrans(Vector &x, const ConstVector &b) const
blas_int incx = x.skip(); blas_int incx = x.skip();
dtrmv("U", "T", "N", &nn, getData().base(), &lda, x.base(), &incx); dtrmv("U", "T", "N", &nn, getData().base(), &lda, x.base(), &incx);
for (const_diag_iter di = diag_begin(); di != diag_end(); ++di) for (const_diag_iter di = diag_begin(); di != diag_end(); ++di)
{ if (!di->isReal())
if (!(*di).isReal()) {
{ int jbar = di->getIndex();
int jbar = (*di).getIndex(); x[jbar] += di->getBeta2()*b[jbar+1];
x[jbar] += (*di).getBeta2()*b[jbar+1]; }
}
}
} }
void void

View File

@ -12,8 +12,6 @@
#include <list> #include <list>
#include <memory> #include <memory>
using namespace std;
class DiagonalBlock; class DiagonalBlock;
class Diagonal; class Diagonal;
class DiagPair class DiagPair
@ -23,19 +21,17 @@ private:
double *a2; double *a2;
public: public:
DiagPair() = default; DiagPair() = default;
DiagPair(double *aa1, double *aa2) DiagPair(double *aa1, double *aa2) : a1{aa1}, a2{aa2}
{ {
a1 = aa1; a2 = aa2;
}
DiagPair(const DiagPair &p)
{
a1 = p.a1; a2 = p.a2;
} }
DiagPair(const DiagPair &p) = default;
DiagPair &operator=(const DiagPair &p) = default; DiagPair &operator=(const DiagPair &p) = default;
DiagPair & DiagPair &
operator=(double v) operator=(double v)
{ {
*a1 = v; *a2 = v; return *this; *a1 = v;
*a2 = v;
return *this;
} }
const double & const double &
operator*() const operator*() const
@ -48,6 +44,9 @@ public:
friend class DiagonalBlock; friend class DiagonalBlock;
}; };
// Stores a diagonal block: either a scalar, or a 2x2 block
/* alpha points to the diagonal element(s); beta1 and beta2 point to the
off-diagonal elements of the 2x2 block */
class DiagonalBlock class DiagonalBlock
{ {
private: private:
@ -57,54 +56,25 @@ private:
double *beta1; double *beta1;
double *beta2; double *beta2;
void
copy(const DiagonalBlock &b)
{
jbar = b.jbar;
real = b.real;
alpha = b.alpha;
beta1 = b.beta1;
beta2 = b.beta2;
}
public: public:
DiagonalBlock() = default; DiagonalBlock() = default;
DiagonalBlock(int jb, bool r, double *a1, double *a2, DiagonalBlock(int jb, bool r, double *a1, double *a2,
double *b1, double *b2) double *b1, double *b2)
: alpha(a1, a2) : jbar{jb}, real{r}, alpha{a1, a2}, beta1{b1}, beta2{b2}
{ {
jbar = jb;
real = r;
beta1 = b1;
beta2 = b2;
} }
// construct complex block // construct complex block
DiagonalBlock(int jb, double *a1, double *a2) DiagonalBlock(int jb, double *a1, double *a2)
: alpha(a1, a2) : jbar{jb}, real{false}, alpha{a1, a2}, beta1{a2-1}, beta2{a1+1}
{ {
jbar = jb;
real = false;
beta1 = a2 - 1;
beta2 = a1 + 1;
} }
// construct real block // construct real block
DiagonalBlock(int jb, double *a1) DiagonalBlock(int jb, double *a1)
: alpha(a1, a1) : jbar{jb}, real{true}, alpha{a1, a1}, beta1{nullptr}, beta2{nullptr}
{ {
jbar = jb;
real = true;
beta1 = nullptr;
beta2 = nullptr;
}
DiagonalBlock(const DiagonalBlock &b)
{
copy(b);
}
DiagonalBlock &
operator=(const DiagonalBlock &b)
{
copy(b); return *this;
} }
DiagonalBlock(const DiagonalBlock &b) = default;
DiagonalBlock &operator=(const DiagonalBlock &b) = default;
int int
getIndex() const getIndex() const
{ {
@ -144,76 +114,21 @@ public:
friend class Diagonal; friend class Diagonal;
}; };
template <class _Tdiag, class _Tblock, class _Titer>
struct _diag_iter
{
using _Self = _diag_iter<_Tdiag, _Tblock, _Titer>;
_Tdiag diag;
_Titer it;
public:
_diag_iter(_Tdiag d, _Titer iter) : diag(d), it(iter)
{
}
_Tblock
operator*() const
{
return *it;
}
_Self &
operator++()
{
++it; return *this;
}
_Self &
operator--()
{
--it; return *this;
}
bool
operator==(const _Self &x) const
{
return x.it == it;
}
bool
operator!=(const _Self &x) const
{
return x.it != it;
}
_Self &
operator=(const _Self &x)
{
it = x.it; return *this;
}
_Titer
iter() const
{
return it;
}
};
class Diagonal class Diagonal
{ {
public: public:
using const_diag_iter = _diag_iter<const Diagonal &, const DiagonalBlock &, list<DiagonalBlock>::const_iterator>; using const_diag_iter = std::list<DiagonalBlock>::const_iterator;
using diag_iter = _diag_iter<Diagonal &, DiagonalBlock &, list<DiagonalBlock>::iterator>; using diag_iter = std::list<DiagonalBlock>::iterator;
private: private:
int num_all{0}; int num_all{0};
list<DiagonalBlock> blocks; std::list<DiagonalBlock> blocks;
int num_real{0}; int num_real{0};
void copy(const Diagonal &);
public: public:
Diagonal() = default; Diagonal() = default;
Diagonal(double *data, int d_size); Diagonal(double *data, int d_size);
Diagonal(double *data, const Diagonal &d); Diagonal(double *data, const Diagonal &d);
Diagonal(const Diagonal &d) Diagonal(const Diagonal &d) = default;
{ Diagonal &operator=(const Diagonal &d) = default;
copy(d);
}
Diagonal &
operator=(const Diagonal &d)
{
copy(d); return *this;
}
virtual ~Diagonal() = default; virtual ~Diagonal() = default;
int int
@ -247,22 +162,22 @@ public:
diag_iter diag_iter
begin() begin()
{ {
return diag_iter(*this, blocks.begin()); return blocks.begin();
} }
const_diag_iter const_diag_iter
begin() const begin() const
{ {
return const_diag_iter(*this, blocks.begin()); return blocks.begin();
} }
diag_iter diag_iter
end() end()
{ {
return diag_iter(*this, blocks.end()); return blocks.end();
} }
const_diag_iter const_diag_iter
end() const end() const
{ {
return const_diag_iter(*this, blocks.end()); return blocks.end();
} }
/* redefine pointers as data start at p */ /* redefine pointers as data start at p */
@ -283,14 +198,11 @@ struct _matrix_iter
public: public:
_matrix_iter(_TPtr base, int ds, bool r) _matrix_iter(_TPtr base, int ds, bool r)
{ {
ptr = base; d_size = ds; real = r; ptr = base;
d_size = ds;
real = r;
} }
virtual ~_matrix_iter() = default; virtual ~_matrix_iter() = default;
_Self &
operator=(const _Self &it)
{
ptr = it.ptr; d_size = it.d_size; real = it.real; return *this;
}
bool bool
operator==(const _Self &it) const operator==(const _Self &it) const
{ {
@ -328,19 +240,17 @@ public:
_Self & _Self &
operator++() override operator++() override
{ {
_Tparent::ptr++; row++; return *this; _Tparent::ptr++;
row++;
return *this;
} }
_TRef _TRef
b() const b() const
{ {
if (_Tparent::real) if (_Tparent::real)
{ return *(_Tparent::ptr);
return *(_Tparent::ptr);
}
else else
{ return *(_Tparent::ptr+_Tparent::d_size);
return *(_Tparent::ptr+_Tparent::d_size);
}
} }
int int
getRow() const getRow() const
@ -363,19 +273,17 @@ public:
_Self & _Self &
operator++() override operator++() override
{ {
_Tparent::ptr += _Tparent::d_size; col++; return *this; _Tparent::ptr += _Tparent::d_size;
col++;
return *this;
} }
virtual _TRef virtual _TRef
b() const b() const
{ {
if (_Tparent::real) if (_Tparent::real)
{ return *(_Tparent::ptr);
return *(_Tparent::ptr);
}
else else
{ return *(_Tparent::ptr+1);
return *(_Tparent::ptr+1);
}
} }
int int
getCol() const getCol() const
@ -387,6 +295,12 @@ public:
class SchurDecomp; class SchurDecomp;
class SchurDecompZero; class SchurDecompZero;
/* Represents an upper quasi-triangular matrix.
All the elements are stored in the SqSylvMatrix super-class.
Additionally, a list of the diagonal blocks (1x1 or 2x2), is stored in the
"diagonal" member, in order to optimize some operations (where the matrix is
seen as an upper-triangular matrix, plus sub-diagonal elements of the 2x2
diagonal blocks) */
class QuasiTriangular : public SqSylvMatrix class QuasiTriangular : public SqSylvMatrix
{ {
public: public:

View File

@ -19,9 +19,7 @@ SchurDecompEig::bubbleEigen(diag_iter from, diag_iter to)
{ {
diag_iter runm = run; diag_iter runm = run;
if (!tryToSwap(run, runm) && runm == to) if (!tryToSwap(run, runm) && runm == to)
{ ++to;
++to;
}
else else
{ {
// bubble all eigenvalues from runm(incl.) to run(excl.), // bubble all eigenvalues from runm(incl.) to run(excl.),

View File

@ -120,9 +120,7 @@ SimilarityDecomp::diagonalize(double norm)
++end; ++end;
} }
else else
{ bringGuiltyBlock(start, end); // moves with end
bringGuiltyBlock(start, end); // moves with end
}
} }
} }

View File

@ -8,7 +8,6 @@
#include <dynblas.h> #include <dynblas.h>
#include <dynlapack.h> #include <dynlapack.h>
#include <cstring>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
@ -102,9 +101,7 @@ SylvMatrix::eliminateLeft(int row, int col, Vector &x)
get(row, col) = 0.0; get(row, col) = 0.0;
double mult = e/d; double mult = e/d;
for (int i = col + 1; i < numCols(); i++) for (int i = col + 1; i < numCols(); i++)
{ get(row, i) = get(row, i) - mult*get(col, i);
get(row, i) = get(row, i) - mult*get(col, i);
}
x[row] = x[row] - mult*x[col]; x[row] = x[row] - mult*x[col];
} }
else if (std::abs(e) > std::abs(d)) else if (std::abs(e) > std::abs(d))
@ -137,9 +134,7 @@ SylvMatrix::eliminateRight(int row, int col, Vector &x)
get(row, col) = 0.0; get(row, col) = 0.0;
double mult = e/d; double mult = e/d;
for (int i = 0; i < row; i++) for (int i = 0; i < row; i++)
{ get(i, col) = get(i, col) - mult*get(i, row);
get(i, col) = get(i, col) - mult*get(i, row);
}
x[col] = x[col] - mult*x[row]; x[col] = x[col] - mult*x[row];
} }
else if (std::abs(e) > std::abs(d)) else if (std::abs(e) > std::abs(d))

View File

@ -39,6 +39,10 @@ public:
: GeneralMatrix(a, b) : GeneralMatrix(a, b)
{ {
} }
SylvMatrix(const SylvMatrix &m) = default;
SylvMatrix(SylvMatrix &&m) = default;
SylvMatrix &operator=(const SylvMatrix &m) = default;
SylvMatrix &operator=(SylvMatrix &&m) = default;
/* this = |I 0|* this /* this = |I 0|* this
|0 m| */ |0 m| */
@ -70,6 +74,7 @@ public:
{ {
} }
SqSylvMatrix(const SqSylvMatrix &m) = default; SqSylvMatrix(const SqSylvMatrix &m) = default;
SqSylvMatrix(SqSylvMatrix &&m) = default;
SqSylvMatrix(const GeneralMatrix &m, int i, int j, int nrows) SqSylvMatrix(const GeneralMatrix &m, int i, int j, int nrows)
: SylvMatrix(m, i, j, nrows, nrows) : SylvMatrix(m, i, j, nrows, nrows)
{ {
@ -79,12 +84,8 @@ public:
{ {
} }
SqSylvMatrix(const GeneralMatrix &a, const GeneralMatrix &b); SqSylvMatrix(const GeneralMatrix &a, const GeneralMatrix &b);
SqSylvMatrix & SqSylvMatrix &operator=(const SqSylvMatrix &m) = default;
operator=(const SqSylvMatrix &m) SqSylvMatrix &operator=(SqSylvMatrix &&m) = default;
{
GeneralMatrix::operator=(m);
return *this;
}
/* x = (this \otimes this..\otimes this)*d */ /* x = (this \otimes this..\otimes this)*d */
void multVecKron(KronVector &x, const ConstKronVector &d) const; void multVecKron(KronVector &x, const ConstKronVector &d) const;
/* x = (this' \otimes this'..\otimes this')*d */ /* x = (this' \otimes this'..\otimes this')*d */

View File

@ -29,7 +29,7 @@ SylvParams::print(std::ostream &fdesc, const std::string &prefix) const
f_largest.print(fdesc, prefix, "largest block in F "); f_largest.print(fdesc, prefix, "largest block in F ");
f_zeros.print(fdesc, prefix, "num zeros in F "); f_zeros.print(fdesc, prefix, "num zeros in F ");
f_offdiag.print(fdesc, prefix, "num offdiag in F "); f_offdiag.print(fdesc, prefix, "num offdiag in F ");
if (*method == iter) if (*method == solve_method::iter)
{ {
converged.print(fdesc, prefix, "converged "); converged.print(fdesc, prefix, "converged ");
convergence_tol.print(fdesc, prefix, "convergence tol. "); convergence_tol.print(fdesc, prefix, "convergence tol. ");
@ -52,57 +52,57 @@ void
SylvParams::setArrayNames(int &num, const char **names) const SylvParams::setArrayNames(int &num, const char **names) const
{ {
num = 0; num = 0;
if (method.getStatus() != undef) if (method.getStatus() != status::undef)
names[num++] = "method"; names[num++] = "method";
if (convergence_tol.getStatus() != undef) if (convergence_tol.getStatus() != status::undef)
names[num++] = "convergence_tol"; names[num++] = "convergence_tol";
if (max_num_iter.getStatus() != undef) if (max_num_iter.getStatus() != status::undef)
names[num++] = "max_num_iter"; names[num++] = "max_num_iter";
if (bs_norm.getStatus() != undef) if (bs_norm.getStatus() != status::undef)
names[num++] = "bs_norm"; names[num++] = "bs_norm";
if (converged.getStatus() != undef) if (converged.getStatus() != status::undef)
names[num++] = "converged"; names[num++] = "converged";
if (iter_last_norm.getStatus() != undef) if (iter_last_norm.getStatus() != status::undef)
names[num++] = "iter_last_norm"; names[num++] = "iter_last_norm";
if (num_iter.getStatus() != undef) if (num_iter.getStatus() != status::undef)
names[num++] = "num_iter"; names[num++] = "num_iter";
if (f_err1.getStatus() != undef) if (f_err1.getStatus() != status::undef)
names[num++] = "f_err1"; names[num++] = "f_err1";
if (f_errI.getStatus() != undef) if (f_errI.getStatus() != status::undef)
names[num++] = "f_errI"; names[num++] = "f_errI";
if (viv_err1.getStatus() != undef) if (viv_err1.getStatus() != status::undef)
names[num++] = "viv_err1"; names[num++] = "viv_err1";
if (viv_errI.getStatus() != undef) if (viv_errI.getStatus() != status::undef)
names[num++] = "viv_errI"; names[num++] = "viv_errI";
if (ivv_err1.getStatus() != undef) if (ivv_err1.getStatus() != status::undef)
names[num++] = "ivv_err1"; names[num++] = "ivv_err1";
if (ivv_errI.getStatus() != undef) if (ivv_errI.getStatus() != status::undef)
names[num++] = "ivv_errI"; names[num++] = "ivv_errI";
if (f_blocks.getStatus() != undef) if (f_blocks.getStatus() != status::undef)
names[num++] = "f_blocks"; names[num++] = "f_blocks";
if (f_largest.getStatus() != undef) if (f_largest.getStatus() != status::undef)
names[num++] = "f_largest"; names[num++] = "f_largest";
if (f_zeros.getStatus() != undef) if (f_zeros.getStatus() != status::undef)
names[num++] = "f_zeros"; names[num++] = "f_zeros";
if (f_offdiag.getStatus() != undef) if (f_offdiag.getStatus() != status::undef)
names[num++] = "f_offdiag"; names[num++] = "f_offdiag";
if (rcondA1.getStatus() != undef) if (rcondA1.getStatus() != status::undef)
names[num++] = "rcondA1"; names[num++] = "rcondA1";
if (rcondAI.getStatus() != undef) if (rcondAI.getStatus() != status::undef)
names[num++] = "rcondAI"; names[num++] = "rcondAI";
if (eig_min.getStatus() != undef) if (eig_min.getStatus() != status::undef)
names[num++] = "eig_min"; names[num++] = "eig_min";
if (mat_err1.getStatus() != undef) if (mat_err1.getStatus() != status::undef)
names[num++] = "mat_err1"; names[num++] = "mat_err1";
if (mat_errI.getStatus() != undef) if (mat_errI.getStatus() != status::undef)
names[num++] = "mat_errI"; names[num++] = "mat_errI";
if (mat_errF.getStatus() != undef) if (mat_errF.getStatus() != status::undef)
names[num++] = "mat_errF"; names[num++] = "mat_errF";
if (vec_err1.getStatus() != undef) if (vec_err1.getStatus() != status::undef)
names[num++] = "vec_err1"; names[num++] = "vec_err1";
if (vec_errI.getStatus() != undef) if (vec_errI.getStatus() != status::undef)
names[num++] = "vec_errI"; names[num++] = "vec_errI";
if (cpu_time.getStatus() != undef) if (cpu_time.getStatus() != status::undef)
names[num++] = "cpu_time"; names[num++] = "cpu_time";
} }
@ -133,7 +133,7 @@ SylvParams::BoolParamItem::createMatlabArray() const
mxArray * mxArray *
SylvParams::MethodParamItem::createMatlabArray() const SylvParams::MethodParamItem::createMatlabArray() const
{ {
if (value == iter) if (value == solve_method::iter)
return mxCreateString("iterative"); return mxCreateString("iterative");
else else
return mxCreateString("recursive"); return mxCreateString("recursive");
@ -149,57 +149,57 @@ SylvParams::createStructArray() const
mxArray *const res = mxCreateStructArray(2, dims, num, names); mxArray *const res = mxCreateStructArray(2, dims, num, names);
int i = 0; int i = 0;
if (method.getStatus() != undef) if (method.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, method.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, method.createMatlabArray());
if (convergence_tol.getStatus() != undef) if (convergence_tol.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, convergence_tol.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, convergence_tol.createMatlabArray());
if (max_num_iter.getStatus() != undef) if (max_num_iter.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, max_num_iter.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, max_num_iter.createMatlabArray());
if (bs_norm.getStatus() != undef) if (bs_norm.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, bs_norm.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, bs_norm.createMatlabArray());
if (converged.getStatus() != undef) if (converged.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, converged.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, converged.createMatlabArray());
if (iter_last_norm.getStatus() != undef) if (iter_last_norm.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, iter_last_norm.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, iter_last_norm.createMatlabArray());
if (num_iter.getStatus() != undef) if (num_iter.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, num_iter.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, num_iter.createMatlabArray());
if (f_err1.getStatus() != undef) if (f_err1.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, f_err1.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, f_err1.createMatlabArray());
if (f_errI.getStatus() != undef) if (f_errI.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, f_errI.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, f_errI.createMatlabArray());
if (viv_err1.getStatus() != undef) if (viv_err1.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, viv_err1.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, viv_err1.createMatlabArray());
if (viv_errI.getStatus() != undef) if (viv_errI.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, viv_errI.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, viv_errI.createMatlabArray());
if (ivv_err1.getStatus() != undef) if (ivv_err1.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, ivv_err1.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, ivv_err1.createMatlabArray());
if (ivv_errI.getStatus() != undef) if (ivv_errI.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, ivv_errI.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, ivv_errI.createMatlabArray());
if (f_blocks.getStatus() != undef) if (f_blocks.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, f_blocks.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, f_blocks.createMatlabArray());
if (f_largest.getStatus() != undef) if (f_largest.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, f_largest.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, f_largest.createMatlabArray());
if (f_zeros.getStatus() != undef) if (f_zeros.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, f_zeros.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, f_zeros.createMatlabArray());
if (f_offdiag.getStatus() != undef) if (f_offdiag.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, f_offdiag.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, f_offdiag.createMatlabArray());
if (rcondA1.getStatus() != undef) if (rcondA1.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, rcondA1.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, rcondA1.createMatlabArray());
if (rcondAI.getStatus() != undef) if (rcondAI.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, rcondAI.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, rcondAI.createMatlabArray());
if (eig_min.getStatus() != undef) if (eig_min.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, eig_min.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, eig_min.createMatlabArray());
if (mat_err1.getStatus() != undef) if (mat_err1.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, mat_err1.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, mat_err1.createMatlabArray());
if (mat_errI.getStatus() != undef) if (mat_errI.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, mat_errI.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, mat_errI.createMatlabArray());
if (mat_errF.getStatus() != undef) if (mat_errF.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, mat_errF.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, mat_errF.createMatlabArray());
if (vec_err1.getStatus() != undef) if (vec_err1.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, vec_err1.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, vec_err1.createMatlabArray());
if (vec_errI.getStatus() != undef) if (vec_errI.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, vec_errI.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, vec_errI.createMatlabArray());
if (cpu_time.getStatus() != undef) if (cpu_time.getStatus() != status::undef)
mxSetFieldByNumber(res, 0, i++, cpu_time.createMatlabArray()); mxSetFieldByNumber(res, 0, i++, cpu_time.createMatlabArray());
return res; return res;

View File

@ -12,7 +12,7 @@
# include <dynmex.h> # include <dynmex.h>
#endif #endif
typedef enum {def, changed, undef} status; enum class status {def, changed, undef};
template <class _Type> template <class _Type>
struct ParamItem struct ParamItem
@ -24,25 +24,21 @@ protected:
public: public:
ParamItem() ParamItem()
{ {
s = undef; s = status::undef;
} }
ParamItem(_Type val) ParamItem(_Type val)
{ {
value = val; s = def; value = val;
} s = status::def;
ParamItem(const _Self &item)
{
value = item.value; s = item.s;
}
_Self &
operator=(const _Self &item)
{
value = item.value; s = item.s; return *this;
} }
ParamItem(const _Self &item) = default;
_Self &operator=(const _Self &item) = default;
_Self & _Self &
operator=(const _Type &val) operator=(const _Type &val)
{ {
value = val; s = changed; return *this; value = val;
s = status::changed;
return *this;
} }
_Type _Type
operator*() const operator*() const
@ -57,10 +53,10 @@ public:
void void
print(std::ostream &out, const std::string &prefix, const std::string &str) const print(std::ostream &out, const std::string &prefix, const std::string &str) const
{ {
if (s == undef) if (s == status::undef)
return; return;
out << prefix << str << "= " << value; out << prefix << str << "= " << value;
if (s == def) if (s == status::def)
out << " <default>"; out << " <default>";
out << std::endl; out << std::endl;
} }
@ -69,7 +65,7 @@ public:
class SylvParams class SylvParams
{ {
public: public:
using solve_method = enum {iter, recurse}; enum class solve_method {iter, recurse};
protected: protected:
class DoubleParamItem : public ParamItem<double> class DoubleParamItem : public ParamItem<double>
@ -184,7 +180,7 @@ public:
DoubleParamItem cpu_time; // time of the job in CPU seconds DoubleParamItem cpu_time; // time of the job in CPU seconds
SylvParams(bool wc = false) SylvParams(bool wc = false)
: method(recurse), convergence_tol(1.e-30), max_num_iter(15), : method(solve_method::recurse), convergence_tol(1.e-30), max_num_iter(15),
bs_norm(1.3), want_check(wc) bs_norm(1.3), want_check(wc)
{ {
} }
@ -206,10 +202,10 @@ operator<<(std::ostream &out, SylvParams::solve_method m)
{ {
switch (m) switch (m)
{ {
case SylvParams::iter: case SylvParams::solve_method::iter:
out << "iterative"; out << "iterative";
break; break;
case SylvParams::recurse: case SylvParams::solve_method::recurse:
out << "recurse (a.k.a. triangular)"; out << "recurse (a.k.a. triangular)";
break; break;
} }

View File

@ -19,9 +19,6 @@ SymSchurDecomp::SymSchurDecomp(const ConstGeneralMatrix &mata)
throw SYLV_MES_EXCEPTION("Matrix is not square in SymSchurDecomp constructor"); throw SYLV_MES_EXCEPTION("Matrix is not square in SymSchurDecomp constructor");
// prepare for dsyevr // prepare for dsyevr
const char *jobz = "V";
const char *range = "A";
const char *uplo = "U";
lapack_int n = mata.numRows(); lapack_int n = mata.numRows();
GeneralMatrix tmpa(mata); GeneralMatrix tmpa(mata);
double *a = tmpa.base(); double *a = tmpa.base();
@ -45,7 +42,7 @@ SymSchurDecomp::SymSchurDecomp(const ConstGeneralMatrix &mata)
lapack_int info; lapack_int info;
// query for lwork and liwork // query for lwork and liwork
dsyevr(jobz, range, uplo, &n, a, &lda, vl, vu, il, iu, &abstol, dsyevr("V", "A", "U", &n, a, &lda, vl, vu, il, iu, &abstol,
&m, w, z, &ldz, isuppz.data(), &tmpwork, &lwork, &tmpiwork, &liwork, &info); &m, w, z, &ldz, isuppz.data(), &tmpwork, &lwork, &tmpiwork, &liwork, &info);
lwork = (int) tmpwork; lwork = (int) tmpwork;
liwork = tmpiwork; liwork = tmpiwork;
@ -54,7 +51,7 @@ SymSchurDecomp::SymSchurDecomp(const ConstGeneralMatrix &mata)
std::vector<lapack_int> iwork(liwork); std::vector<lapack_int> iwork(liwork);
// do the calculation // do the calculation
dsyevr(jobz, range, uplo, &n, a, &lda, vl, vu, il, iu, &abstol, dsyevr("V", "A", "U", &n, a, &lda, vl, vu, il, iu, &abstol,
&m, w, z, &ldz, isuppz.data(), work.data(), &lwork, iwork.data(), &liwork, &info); &m, w, z, &ldz, isuppz.data(), work.data(), &lwork, iwork.data(), &liwork, &info);
if (info < 0) if (info < 0)

View File

@ -48,7 +48,7 @@ TriangularSylvester::solve(SylvParams &pars, KronVector &d) const
{ {
double eig_min = 1e30; double eig_min = 1e30;
solvi(1., d, eig_min); solvi(1., d, eig_min);
pars.eig_min = sqrt(eig_min); pars.eig_min = std::sqrt(eig_min);
} }
void void
@ -122,7 +122,7 @@ TriangularSylvester::solviRealAndEliminate(double r, const_diag_iter di,
double f = *((*di).getAlpha()); double f = *((*di).getAlpha());
KronVector dj(d, jbar); KronVector dj(d, jbar);
// solve system // solve system
if (abs(r*f) > diag_zero) if (std::abs(r*f) > diag_zero)
solvi(r*f, dj, eig_min); solvi(r*f, dj, eig_min);
// calculate y // calculate y
KronVector y((const KronVector &)dj); KronVector y((const KronVector &)dj);
@ -277,8 +277,8 @@ TriangularSylvester::solviipComplex(double alpha, double betas, double gamma,
KronVector d2tmp(d2); KronVector d2tmp(d2);
quaEval(alpha, betas, gamma, delta1, delta2, quaEval(alpha, betas, gamma, delta1, delta2,
d1, d2, d1tmp, d2tmp); d1, d2, d1tmp, d2tmp);
double delta = sqrt(delta1*delta2); double delta = std::sqrt(delta1*delta2);
double beta = sqrt(betas); double beta = std::sqrt(betas);
double a1 = alpha*gamma - beta*delta; double a1 = alpha*gamma - beta*delta;
double b1 = alpha*delta + gamma*beta; double b1 = alpha*delta + gamma*beta;
double a2 = alpha*gamma + beta*delta; double a2 = alpha*gamma + beta*delta;

View File

@ -14,8 +14,6 @@
#include <iostream> #include <iostream>
#include <iomanip> #include <iomanip>
using namespace std;
ZeroPad zero_pad; ZeroPad zero_pad;
Vector::Vector(const Vector &v) Vector::Vector(const Vector &v)
@ -25,9 +23,9 @@ Vector::Vector(const Vector &v)
} }
Vector::Vector(const ConstVector &v) Vector::Vector(const ConstVector &v)
: len(v.length()), data{new double[len], [](double *arr) { delete[] arr; }} : len(v.len), data{new double[len], [](double *arr) { delete[] arr; }}
{ {
copy(v.base(), v.skip()); copy(v.base(), v.s);
} }
Vector & Vector &
@ -38,17 +36,13 @@ Vector::operator=(const Vector &v)
if (v.len != len) if (v.len != len)
throw SYLV_MES_EXCEPTION("Attempt to assign vectors with different lengths."); throw SYLV_MES_EXCEPTION("Attempt to assign vectors with different lengths.");
/*
if (s == v.s if (s == v.s
&& (data <= v.data && v.data < data+len*s && (base() <= v.base() && v.base() < base()+len*s
|| v.data <= data && data < v.data+v.len*v.s) || v.base() <= base() && base() < v.base()+v.len*v.s)
&& (data-v.data) % s == 0) && (base()-v.base()) % s == 0)
{ throw SYLV_MES_EXCEPTION("Attempt to assign overlapping vectors.");
std::cout << "this destroy=" << destroy << ", v destroy=" << v.destroy
<< ", data-v.data=" << (unsigned long) (data-v.data)
<< ", len=" << len << std::endl;
throw SYLV_MES_EXCEPTION("Attempt to assign overlapping vectors.");
} */
copy(v.base(), v.s); copy(v.base(), v.s);
return *this; return *this;
} }
@ -65,16 +59,15 @@ Vector::operator=(Vector &&v)
Vector & Vector &
Vector::operator=(const ConstVector &v) Vector::operator=(const ConstVector &v)
{ {
if (v.length() != len) if (v.len != len)
throw SYLV_MES_EXCEPTION("Attempt to assign vectors with different lengths."); throw SYLV_MES_EXCEPTION("Attempt to assign vectors with different lengths.");
/* if (s == v.s
if (v.skip() == 1 && skip() == 1 && ( && (base() <= v.base() && v.base() < base()+len*s
(base() < v.base() + v.length() && base() >= v.base()) || v.base() <= base() && base() < v.base()+v.len*v.s)
|| (base() + length() < v.base() + v.length() && (base()-v.base()) % s == 0)
&& base() + length() > v.base())))
throw SYLV_MES_EXCEPTION("Attempt to assign overlapping vectors."); throw SYLV_MES_EXCEPTION("Attempt to assign overlapping vectors.");
*/
copy(v.base(), v.skip()); copy(v.base(), v.s);
return *this; return *this;
} }
@ -210,7 +203,7 @@ void
Vector::add(double r, const ConstVector &v) Vector::add(double r, const ConstVector &v)
{ {
blas_int n = len; blas_int n = len;
blas_int incx = v.skip(); blas_int incx = v.s;
blas_int incy = s; blas_int incy = s;
daxpy(&n, &r, v.base(), &incx, base(), &incy); daxpy(&n, &r, v.base(), &incx, base(), &incy);
} }
@ -225,7 +218,7 @@ void
Vector::addComplex(const std::complex<double> &z, const ConstVector &v) Vector::addComplex(const std::complex<double> &z, const ConstVector &v)
{ {
blas_int n = len/2; blas_int n = len/2;
blas_int incx = v.skip(); blas_int incx = v.s;
blas_int incy = s; blas_int incy = s;
zaxpy(&n, reinterpret_cast<const double(&)[2]>(z), v.base(), &incx, base(), &incy); zaxpy(&n, reinterpret_cast<const double(&)[2]>(z), v.base(), &incx, base(), &incy);
} }
@ -374,8 +367,8 @@ ConstVector::getMax() const
{ {
double r = 0; double r = 0;
for (int i = 0; i < len; i++) for (int i = 0; i < len; i++)
if (abs(operator[](i)) > r) if (std::abs(operator[](i)) > r)
r = abs(operator[](i)); r = std::abs(operator[](i));
return r; return r;
} }
@ -384,7 +377,7 @@ ConstVector::getNorm1() const
{ {
double norm = 0.0; double norm = 0.0;
for (int i = 0; i < len; i++) for (int i = 0; i < len; i++)
norm += abs(operator[](i)); norm += std::abs(operator[](i));
return norm; return norm;
} }
@ -403,7 +396,7 @@ bool
ConstVector::isFinite() const ConstVector::isFinite() const
{ {
int i = 0; int i = 0;
while (i < len && isfinite(operator[](i))) while (i < len && std::isfinite(operator[](i)))
i++; i++;
return i == len; return i == len;
} }

View File

@ -140,6 +140,7 @@ class ConstGeneralMatrix;
class ConstVector class ConstVector
{ {
friend class Vector;
protected: protected:
int len; int len;
int off{0}; // offset to double* pointer int off{0}; // offset to double* pointer

View File

@ -4,70 +4,52 @@
#include "MMMatrix.hh" #include "MMMatrix.hh"
#include <cstdio> #include <fstream>
#include <cstring> #include <iomanip>
MMMatrixIn::MMMatrixIn(const char *fname) MMMatrixIn::MMMatrixIn(const std::string &fname)
{ {
FILE *fd; std::ifstream fd{fname};
if (nullptr == (fd = fopen(fname, "r"))) if (fd.fail())
throw MMException(string("Cannot open file ")+fname+" for reading\n"); throw MMException("Cannot open file "+fname+" for reading\n");
char buffer[1000];
// jump over initial comments // jump over initial comments
while (fgets(buffer, 1000, fd) && strncmp(buffer, "%%", 2)) while (fd.peek() == '%')
{ fd.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
}
// read in number of rows and cols // read in number of rows and cols
if (!fgets(buffer, 1000, fd)) fd >> rows >> cols;
throw MMException(string("Cannot read rows and cols while reading ")+fname+"\n"); if (fd.fail())
if (2 != sscanf(buffer, "%d %d", &rows, &cols))
throw MMException("Couldn't parse rows and cols\n"); throw MMException("Couldn't parse rows and cols\n");
// read in data // read in data
data = std::shared_ptr<double>(static_cast<double *>(operator new[](rows*cols*sizeof(double))), [](double *arr) { operator delete[](static_cast<void *>(arr)); }); data = std::shared_ptr<double>(static_cast<double *>(operator new[](rows*cols*sizeof(double))), [](double *arr) { operator delete[](static_cast<void *>(arr)); });
int len = rows*cols; int len = rows*cols;
int i = 0; int i = 0;
while (fgets(buffer, 1000, fd) && i < len) while (!fd.eof() && i < len)
{ {
if (1 != sscanf(buffer, "%lf", const_cast<double *>(data.get())+i)) fd >> data.get()[i];
throw MMException(string("Couldn't parse float number ")+buffer+"\n"); if (fd.fail())
throw MMException("Couldn't parse float number\n");
i++; i++;
} }
if (i < len) if (i < len)
{ throw MMException("Couldn't read all " + std::to_string(len) + " elements, read "
char mes[1000]; + std::to_string(i) + " so far\n");
sprintf(mes, "Couldn't read all %d lines, read %d so far\n", len, i); fd.close();
throw MMException(mes);
}
fclose(fd);
} }
void void
MMMatrixOut::write(const char *fname, int rows, int cols, const double *data) MMMatrixOut::write(const std::string &fname, const GeneralMatrix &m)
{ {
FILE *fd; std::ofstream fd{fname, std::ios::out | std::ios::trunc};
if (nullptr == (fd = fopen(fname, "w"))) if (fd.fail())
throw MMException(string("Cannot open file ")+fname+" for writing\n"); throw MMException("Cannot open file "+fname+" for writing\n");
if (0 > fprintf(fd, "%%%%MatrixMarket matrix array real general\n")) fd << "%%%%MatrixMarket matrix array real general" << std::endl
throw MMException(string("Output error when writing file ")+fname); << m.numRows() << ' ' << m.numCols() << std::endl
if (0 > fprintf(fd, "%d %d\n", rows, cols)) << std::setprecision(35);
throw MMException(string("Output error when writing file ")+fname); for (int i = 0; i < m.numCols(); i++)
int running = 0; for (int j = 0; j < m.numRows(); j++)
for (int i = 0; i < cols; i++) fd << std::setw(40) << m.get(i, j);
{ fd.close();
for (int j = 0; j < rows; j++)
{
if (0 > fprintf(fd, "%40.35g\n", data[running]))
throw MMException(string("Output error when writing file ")+fname);
running++;
}
}
fclose(fd);
}
void
MMMatrixOut::write(const char *fname, const GeneralMatrix &m)
{
write(fname, m.numRows(), m.numCols(), m.base());
} }

View File

@ -12,22 +12,17 @@
#include <utility> #include <utility>
#include <memory> #include <memory>
using namespace std;
class MMException : public MallocAllocator class MMException : public MallocAllocator
{ {
string message; std::string message;
public: public:
MMException(string mes) : message(std::move(mes)) MMException(std::string mes) : message(std::move(mes))
{ {
} }
MMException(const char *mes) : message(mes) std::string
{
}
const char *
getMessage() const getMessage() const
{ {
return message.data(); return message;
} }
}; };
@ -37,7 +32,7 @@ class MMMatrixIn : public MallocAllocator
int rows; int rows;
int cols; int cols;
public: public:
MMMatrixIn(const char *fname); MMMatrixIn(const std::string &fname);
~MMMatrixIn() = default; ~MMMatrixIn() = default;
Vector Vector
getData() const getData() const
@ -64,8 +59,7 @@ public:
class MMMatrixOut : public MallocAllocator class MMMatrixOut : public MallocAllocator
{ {
public: public:
static void write(const char *fname, int rows, int cols, const double *data); static void write(const std::string &fname, const GeneralMatrix &m);
static void write(const char *fname, const GeneralMatrix &m);
}; };
#endif /* MM_MATRIX_H */ #endif /* MM_MATRIX_H */

View File

@ -18,76 +18,65 @@
#include "MMMatrix.hh" #include "MMMatrix.hh"
#include <cstdio>
#include <cstring>
#include <ctime> #include <ctime>
#include <cmath> #include <cmath>
#include <string>
#include <utility>
#include <iostream>
#include <iomanip>
#include <memory>
class TestRunnable : public MallocAllocator class TestRunnable : public MallocAllocator
{ {
char name[100];
static double eps_norm;
public: public:
TestRunnable(const char *n) const std::string name;
static constexpr double eps_norm = 1.0e-10;
TestRunnable(std::string n) : name(std::move(n))
{ {
strncpy(name, n, 100);
} }
virtual ~TestRunnable() = default; virtual ~TestRunnable() = default;
bool test() const; bool test() const;
virtual bool run() const = 0; virtual bool run() const = 0;
const char *
getName() const
{
return name;
}
protected: protected:
// declaration of auxiliary static methods // declaration of auxiliary static methods
static bool quasi_solve(bool trans, const char *mname, const char *vname); static bool quasi_solve(bool trans, const std::string &mname, const std::string &vname);
static bool mult_kron(bool trans, const char *mname, const char *vname, static bool mult_kron(bool trans, const std::string &mname, const std::string &vname,
const char *cname, int m, int n, int depth); const std::string &cname, int m, int n, int depth);
static bool level_kron(bool trans, const char *mname, const char *vname, static bool level_kron(bool trans, const std::string &mname, const std::string &vname,
const char *cname, int level, int m, int n, int depth); const std::string &cname, int level, int m, int n, int depth);
static bool kron_power(const char *m1name, const char *m2name, const char *vname, static bool kron_power(const std::string &m1name, const std::string &m2name, const std::string &vname,
const char *cname, int m, int n, int depth); const std::string &cname, int m, int n, int depth);
static bool lin_eval(const char *m1name, const char *m2name, const char *vname, static bool lin_eval(const std::string &m1name, const std::string &m2name, const std::string &vname,
const char *cname, int m, int n, int depth, const std::string &cname, int m, int n, int depth,
double alpha, double beta1, double beta2); double alpha, double beta1, double beta2);
static bool qua_eval(const char *m1name, const char *m2name, const char *vname, static bool qua_eval(const std::string &m1name, const std::string &m2name, const std::string &vname,
const char *cname, int m, int n, int depth, const std::string &cname, int m, int n, int depth,
double alpha, double betas, double gamma, double alpha, double betas, double gamma,
double delta1, double delta2); double delta1, double delta2);
static bool tri_sylv(const char *m1name, const char *m2name, const char *vname, static bool tri_sylv(const std::string &m1name, const std::string &m2name, const std::string &vname,
int m, int n, int depth); int m, int n, int depth);
static bool gen_sylv(const char *aname, const char *bname, const char *cname, static bool gen_sylv(const std::string &aname, const std::string &bname, const std::string &cname,
const char *dname, int m, int n, int order); const std::string &dname, int m, int n, int order);
static bool eig_bubble(const char *aname, int from, int to); static bool eig_bubble(const std::string &aname, int from, int to);
static bool block_diag(const char *aname, double log10norm = 3.0); static bool block_diag(const std::string &aname, double log10norm = 3.0);
static bool iter_sylv(const char *m1name, const char *m2name, const char *vname, static bool iter_sylv(const std::string &m1name, const std::string &m2name, const std::string &vname,
int m, int n, int depth); int m, int n, int depth);
}; };
double TestRunnable::eps_norm = 1.0e-10;
bool bool
TestRunnable::test() const TestRunnable::test() const
{ {
printf("Running test <%s>\n", name); std::cout << "Running test <" << name << '>' << std::endl;
clock_t start = clock(); clock_t start = clock();
bool passed = run(); bool passed = run();
clock_t end = clock(); clock_t end = clock();
printf("CPU time %8.4g (CPU seconds)..................", std::cout << "CPU time " << ((double) (end-start))/CLOCKS_PER_SEC << " (CPU seconds)..................";
((double) (end-start))/CLOCKS_PER_SEC);
if (passed) if (passed)
{ std::cout << "passed";
printf("passed\n\n");
return passed;
}
else else
{ std::cout << "FAILED";
printf("FAILED\n\n"); std::cout << std::endl << std::endl;
return passed; return passed;
}
} }
/**********************************************************/ /**********************************************************/
@ -95,27 +84,27 @@ TestRunnable::test() const
/**********************************************************/ /**********************************************************/
bool bool
TestRunnable::quasi_solve(bool trans, const char *mname, const char *vname) TestRunnable::quasi_solve(bool trans, const std::string &mname, const std::string &vname)
{ {
MMMatrixIn mmt(mname); MMMatrixIn mmt(mname);
MMMatrixIn mmv(vname); MMMatrixIn mmv(vname);
SylvMemoryDriver memdriver(1, mmt.row(), mmt.row(), 1); SylvMemoryDriver memdriver(1, mmt.row(), mmt.row(), 1);
QuasiTriangular *t; std::unique_ptr<QuasiTriangular> t;
QuasiTriangular *tsave; std::unique_ptr<QuasiTriangular> tsave;
if (mmt.row() == mmt.col()) if (mmt.row() == mmt.col())
{ {
t = new QuasiTriangular(mmt.getData(), mmt.row()); t = std::make_unique<QuasiTriangular>(mmt.getData(), mmt.row());
tsave = new QuasiTriangular(*t); tsave = std::make_unique<QuasiTriangular>(*t);
} }
else if (mmt.row() > mmt.col()) else if (mmt.row() > mmt.col())
{ {
t = new QuasiTriangularZero(mmt.row()-mmt.col(), mmt.getData(), mmt.col()); t = std::make_unique<QuasiTriangularZero>(mmt.row()-mmt.col(), mmt.getData(), mmt.col());
tsave = new QuasiTriangularZero((const QuasiTriangularZero &) *t); tsave = std::make_unique<QuasiTriangularZero>((const QuasiTriangularZero &) *t);
} }
else else
{ {
printf(" Wrong quasi triangular dimensions, rows must be >= cols.\n"); std::cout << " Wrong quasi triangular dimensions, rows must be >= cols.\n";
return false; return false;
} }
ConstVector v{mmv.getData()}; ConstVector v{mmv.getData()};
@ -125,24 +114,22 @@ TestRunnable::quasi_solve(bool trans, const char *mname, const char *vname)
t->solveTrans(x, v, eig_min); t->solveTrans(x, v, eig_min);
else else
t->solve(x, v, eig_min); t->solve(x, v, eig_min);
printf("eig_min = %8.4g\n", eig_min); std::cout << "eig_min = " << eig_min << std::endl;
Vector xx(v.length()); Vector xx(v.length());
if (trans) if (trans)
tsave->multVecTrans(xx, ConstVector(x)); tsave->multVecTrans(xx, ConstVector(x));
else else
tsave->multVec(xx, ConstVector(x)); tsave->multVec(xx, ConstVector(x));
delete tsave;
delete t;
xx.add(-1.0, v); xx.add(-1.0, v);
xx.add(1.0, x); xx.add(1.0, x);
double norm = xx.getNorm(); double norm = xx.getNorm();
printf("\terror norm = %8.4g\n", norm); std::cout << "\terror norm = " << norm << std::endl;
return (norm < eps_norm); return (norm < eps_norm);
} }
bool bool
TestRunnable::mult_kron(bool trans, const char *mname, const char *vname, TestRunnable::mult_kron(bool trans, const std::string &mname, const std::string &vname,
const char *cname, int m, int n, int depth) const std::string &cname, int m, int n, int depth)
{ {
MMMatrixIn mmt(mname); MMMatrixIn mmt(mname);
MMMatrixIn mmv(vname); MMMatrixIn mmv(vname);
@ -153,7 +140,10 @@ TestRunnable::mult_kron(bool trans, const char *mname, const char *vname,
|| mmv.row() != length || mmv.row() != length
|| mmc.row() != length) || mmc.row() != length)
{ {
printf(" Incompatible sizes for krom mult action, len=%d, matrow=%d, m=%d, vrow=%d, crow=%d \n", length, mmt.row(), m, mmv.row(), mmc.row()); std::cout << " Incompatible sizes for kron mult action, len=" << length
<< ", matrow=" << mmt.row() << ", m=" << m
<< ", vrow=" << mmv.row() << ", crow=" << mmc.row()
<< std::endl;
return false; return false;
} }
@ -169,13 +159,13 @@ TestRunnable::mult_kron(bool trans, const char *mname, const char *vname,
t.multKron(v); t.multKron(v);
c.add(-1.0, v); c.add(-1.0, v);
double norm = c.getNorm(); double norm = c.getNorm();
printf("\terror norm = %8.4g\n", norm); std::cout << "\terror norm = " << norm << std::endl;
return (norm < eps_norm); return (norm < eps_norm);
} }
bool bool
TestRunnable::level_kron(bool trans, const char *mname, const char *vname, TestRunnable::level_kron(bool trans, const std::string &mname, const std::string &vname,
const char *cname, int level, int m, int n, int depth) const std::string &cname, int level, int m, int n, int depth)
{ {
MMMatrixIn mmt(mname); MMMatrixIn mmt(mname);
MMMatrixIn mmv(vname); MMMatrixIn mmv(vname);
@ -187,7 +177,10 @@ TestRunnable::level_kron(bool trans, const char *mname, const char *vname,
|| mmv.row() != length || mmv.row() != length
|| mmc.row() != length) || mmc.row() != length)
{ {
printf(" Incompatible sizes for krom mult action, len=%d, matrow=%d, m=%d, n=%d, vrow=%d, crow=%d \n", length, mmt.row(), m, n, mmv.row(), mmc.row()); std::cout << " Incompatible sizes for kron mult action, len=" << length
<< ", matrow=" << mmt.row() << ", m=" << m << ", n=" << n
<< ", vrow=" << mmv.row() << ", crow=" << mmc.row()
<< std::endl;
return false; return false;
} }
@ -204,13 +197,13 @@ TestRunnable::level_kron(bool trans, const char *mname, const char *vname,
KronUtils::multAtLevel(level, t, x); KronUtils::multAtLevel(level, t, x);
x.add(-1, c); x.add(-1, c);
double norm = x.getNorm(); double norm = x.getNorm();
printf("\terror norm = %8.4g\n", norm); std::cout << "\terror norm = " << norm << std::endl;
return (norm < eps_norm); return (norm < eps_norm);
} }
bool bool
TestRunnable::kron_power(const char *m1name, const char *m2name, const char *vname, TestRunnable::kron_power(const std::string &m1name, const std::string &m2name, const std::string &vname,
const char *cname, int m, int n, int depth) const std::string &cname, int m, int n, int depth)
{ {
MMMatrixIn mmt1(m1name); MMMatrixIn mmt1(m1name);
MMMatrixIn mmt2(m2name); MMMatrixIn mmt2(m2name);
@ -223,7 +216,11 @@ TestRunnable::kron_power(const char *m1name, const char *m2name, const char *vna
|| mmv.row() != length || mmv.row() != length
|| mmc.row() != length) || mmc.row() != length)
{ {
printf(" Incompatible sizes for krom power mult action, len=%d, row1=%d, row2=%d, m=%d, n=%d, vrow=%d, crow=%d \n", length, mmt1.row(), mmt2.row(), m, n, mmv.row(), mmc.row()); std::cout << " Incompatible sizes for kron power mult action, len=" << length
<< ", row1=" << mmt1.row() << ", row2=" << mmt2.row()
<< ", m=" << m << ", n=" << n
<< ", vrow=" << mmv.row() << ", crow=" << mmc.row()
<< std::endl;
return false; return false;
} }
@ -240,13 +237,13 @@ TestRunnable::kron_power(const char *m1name, const char *m2name, const char *vna
memdriver.setStackMode(false); memdriver.setStackMode(false);
x.add(-1, c); x.add(-1, c);
double norm = x.getNorm(); double norm = x.getNorm();
printf("\terror norm = %8.4g\n", norm); std::cout << "\terror norm = " << norm << std::endl;
return (norm < eps_norm); return (norm < eps_norm);
} }
bool bool
TestRunnable::lin_eval(const char *m1name, const char *m2name, const char *vname, TestRunnable::lin_eval(const std::string &m1name, const std::string &m2name, const std::string &vname,
const char *cname, int m, int n, int depth, const std::string &cname, int m, int n, int depth,
double alpha, double beta1, double beta2) double alpha, double beta1, double beta2)
{ {
MMMatrixIn mmt1(m1name); MMMatrixIn mmt1(m1name);
@ -260,7 +257,11 @@ TestRunnable::lin_eval(const char *m1name, const char *m2name, const char *vname
|| mmv.row() != 2*length || mmv.row() != 2*length
|| mmc.row() != 2*length) || mmc.row() != 2*length)
{ {
printf(" Incompatible sizes for lin eval action, len=%d, row1=%d, row2=%d, m=%d, n=%d, vrow=%d, crow=%d \n", length, mmt1.row(), mmt2.row(), m, n, mmv.row(), mmc.row()); std::cout << " Incompatible sizes for lin eval action, len=" << length
<< ", row1=" << mmt1.row() << ", row2=" << mmt2.row()
<< ", m=" << m << ", n=" << n
<< ", vrow=" << mmv.row() << ", crow=" << mmc.row()
<< std::endl;
return false; return false;
} }
@ -285,13 +286,13 @@ TestRunnable::lin_eval(const char *m1name, const char *m2name, const char *vname
x2.add(-1, c2); x2.add(-1, c2);
double norm1 = x1.getNorm(); double norm1 = x1.getNorm();
double norm2 = x2.getNorm(); double norm2 = x2.getNorm();
printf("\terror norm1 = %8.4g\n\terror norm2 = %8.4g\n", norm1, norm2); std::cout << "\terror norm1 = " << norm1 << "\n\terror norm2 = " << norm2 << '\n';
return (norm1*norm1+norm2*norm2 < eps_norm*eps_norm); return (norm1*norm1+norm2*norm2 < eps_norm*eps_norm);
} }
bool bool
TestRunnable::qua_eval(const char *m1name, const char *m2name, const char *vname, TestRunnable::qua_eval(const std::string &m1name, const std::string &m2name, const std::string &vname,
const char *cname, int m, int n, int depth, const std::string &cname, int m, int n, int depth,
double alpha, double betas, double gamma, double alpha, double betas, double gamma,
double delta1, double delta2) double delta1, double delta2)
{ {
@ -306,7 +307,11 @@ TestRunnable::qua_eval(const char *m1name, const char *m2name, const char *vname
|| mmv.row() != 2*length || mmv.row() != 2*length
|| mmc.row() != 2*length) || mmc.row() != 2*length)
{ {
printf(" Incompatible sizes for qua eval action, len=%d, row1=%d, row2=%d, m=%d, n=%d, vrow=%d, crow=%d \n", length, mmt1.row(), mmt2.row(), m, n, mmv.row(), mmc.row()); std::cout << " Incompatible sizes for qua eval action, len=" << length
<< ", row1=" << mmt1.row() << ", row2=" << mmt2.row()
<< ", m=" << m << ", n=" << n
<< ", vrow=" << mmv.row() << ", crow=" << mmc.row()
<< std::endl;
return false; return false;
} }
@ -331,12 +336,12 @@ TestRunnable::qua_eval(const char *m1name, const char *m2name, const char *vname
x2.add(-1, c2); x2.add(-1, c2);
double norm1 = x1.getNorm(); double norm1 = x1.getNorm();
double norm2 = x2.getNorm(); double norm2 = x2.getNorm();
printf("\terror norm1 = %8.4g\n\terror norm2 = %8.4g\n", norm1, norm2); std::cout << "\terror norm1 = " << norm1 << "\n\terror norm2 = " << norm2 << std::endl;
return (norm1*norm1+norm2*norm2 < 100*eps_norm*eps_norm); // relax norm return (norm1*norm1+norm2*norm2 < 100*eps_norm*eps_norm); // relax norm
} }
bool bool
TestRunnable::tri_sylv(const char *m1name, const char *m2name, const char *vname, TestRunnable::tri_sylv(const std::string &m1name, const std::string &m2name, const std::string &vname,
int m, int n, int depth) int m, int n, int depth)
{ {
MMMatrixIn mmt1(m1name); MMMatrixIn mmt1(m1name);
@ -348,7 +353,11 @@ TestRunnable::tri_sylv(const char *m1name, const char *m2name, const char *vname
|| mmt2.row() != n || mmt2.row() != n
|| mmv.row() != length) || mmv.row() != length)
{ {
printf(" Incompatible sizes for triangular sylvester action, len=%d, row1=%d, row2=%d, m=%d, n=%d, vrow=%d\n", length, mmt1.row(), mmt2.row(), m, n, mmv.row()); std::cout << " Incompatible sizes for triangular sylvester action, len=" << length
<< ", row1=" << mmt1.row() << ", row2=" << mmt2.row()
<< ", m=" << m << ", n=" << n
<< ", vrow=" << mmv.row()
<< std::endl;
return false; return false;
} }
@ -369,17 +378,17 @@ TestRunnable::tri_sylv(const char *m1name, const char *m2name, const char *vname
dcheck.add(-1.0, v); dcheck.add(-1.0, v);
double norm = dcheck.getNorm(); double norm = dcheck.getNorm();
double xnorm = v.getNorm(); double xnorm = v.getNorm();
printf("\trel. error norm = %8.4g\n", norm/xnorm); std::cout << "\trel. error norm = " << norm/xnorm << std::endl;
double max = dcheck.getMax(); double max = dcheck.getMax();
double xmax = v.getMax(); double xmax = v.getMax();
printf("\trel. error max = %8.4g\n", max/xmax); std::cout << "\trel. error max = " << max/xmax << std::endl;
memdriver.setStackMode(false); memdriver.setStackMode(false);
return (norm < xnorm*eps_norm); return (norm < xnorm*eps_norm);
} }
bool bool
TestRunnable::gen_sylv(const char *aname, const char *bname, const char *cname, TestRunnable::gen_sylv(const std::string &aname, const std::string &bname, const std::string &cname,
const char *dname, int m, int n, int order) const std::string &dname, int m, int n, int order)
{ {
MMMatrixIn mma(aname); MMMatrixIn mma(aname);
MMMatrixIn mmb(bname); MMMatrixIn mmb(bname);
@ -391,7 +400,7 @@ TestRunnable::gen_sylv(const char *aname, const char *bname, const char *cname,
|| n != mmb.row() || n < mmb.col() || n != mmb.row() || n < mmb.col()
|| n != mmd.row() || power(m, order) != mmd.col()) || n != mmd.row() || power(m, order) != mmd.col())
{ {
printf(" Incompatible sizes for gen_sylv.\n"); std::cout << " Incompatible sizes for gen_sylv.\n";
return false; return false;
} }
@ -410,13 +419,13 @@ TestRunnable::gen_sylv(const char *aname, const char *bname, const char *cname,
} }
bool bool
TestRunnable::eig_bubble(const char *aname, int from, int to) TestRunnable::eig_bubble(const std::string &aname, int from, int to)
{ {
MMMatrixIn mma(aname); MMMatrixIn mma(aname);
if (mma.row() != mma.col()) if (mma.row() != mma.col())
{ {
printf(" Matrix is not square\n"); std::cout << " Matrix is not square\n";
return false; return false;
} }
@ -438,21 +447,21 @@ TestRunnable::eig_bubble(const char *aname, int from, int to)
double normInf = check.getNormInf(); double normInf = check.getNormInf();
double onorm1 = orig.getNorm1(); double onorm1 = orig.getNorm1();
double onormInf = orig.getNormInf(); double onormInf = orig.getNormInf();
printf("\tabs. error1 = %8.4g\n", norm1); std:: cout << "\tabs. error1 = " << norm1 << std::endl
printf("\tabs. errorI = %8.4g\n", normInf); << "\tabs. errorI = " << normInf << std::endl
printf("\trel. error1 = %8.4g\n", norm1/onorm1); << "\trel. error1 = " << norm1/onorm1 << std::endl
printf("\trel. errorI = %8.4g\n", normInf/onormInf); << "\trel. errorI = " << normInf/onormInf << std::endl;
return (norm1 < eps_norm*onorm1 && normInf < eps_norm*onormInf); return (norm1 < eps_norm*onorm1 && normInf < eps_norm*onormInf);
} }
bool bool
TestRunnable::block_diag(const char *aname, double log10norm) TestRunnable::block_diag(const std::string &aname, double log10norm)
{ {
MMMatrixIn mma(aname); MMMatrixIn mma(aname);
if (mma.row() != mma.col()) if (mma.row() != mma.col())
{ {
printf(" Matrix is not square\n"); std::cout << " Matrix is not square\n";
return false; return false;
} }
@ -468,25 +477,25 @@ TestRunnable::block_diag(const char *aname, double log10norm)
double normInf = check.getNormInf(); double normInf = check.getNormInf();
double onorm1 = orig.getNorm1(); double onorm1 = orig.getNorm1();
double onormInf = orig.getNormInf(); double onormInf = orig.getNormInf();
printf("\terror Q*B*invQ:\n"); std::cout << "\terror Q*B*invQ:" << std::endl
printf("\tabs. error1 = %8.4g\n", norm1); << "\tabs. error1 = " << norm1 << std::endl
printf("\tabs. errorI = %8.4g\n", normInf); << "\tabs. errorI = " << normInf << std::endl
printf("\trel. error1 = %8.4g\n", norm1/onorm1); << "\trel. error1 = " << norm1/onorm1 << std::endl
printf("\trel. errorI = %8.4g\n", normInf/onormInf); << "\trel. errorI = " << normInf/onormInf << std::endl;
SqSylvMatrix check2(dec.getQ(), dec.getInvQ()); SqSylvMatrix check2(dec.getQ(), dec.getInvQ());
SqSylvMatrix in(n); SqSylvMatrix in(n);
in.setUnit(); in.setUnit();
check2.add(-1, in); check2.add(-1, in);
double nor1 = check2.getNorm1(); double nor1 = check2.getNorm1();
double norInf = check2.getNormInf(); double norInf = check2.getNormInf();
printf("\terror Q*invQ:\n"); std::cout << "\terror Q*invQ:" << std::endl
printf("\tabs. error1 = %8.4g\n", nor1); << "\tabs. error1 = " << nor1 << std::endl
printf("\tabs. errorI = %8.4g\n", norInf); << "\tabs. errorI = " << norInf << std::endl;
return (norm1 < eps_norm*pow(10, log10norm)*onorm1); return (norm1 < eps_norm*pow(10, log10norm)*onorm1);
} }
bool bool
TestRunnable::iter_sylv(const char *m1name, const char *m2name, const char *vname, TestRunnable::iter_sylv(const std::string &m1name, const std::string &m2name, const std::string &vname,
int m, int n, int depth) int m, int n, int depth)
{ {
MMMatrixIn mmt1(m1name); MMMatrixIn mmt1(m1name);
@ -498,7 +507,11 @@ TestRunnable::iter_sylv(const char *m1name, const char *m2name, const char *vnam
|| mmt2.row() != n || mmt2.row() != n
|| mmv.row() != length) || mmv.row() != length)
{ {
printf(" Incompatible sizes for triangular sylvester iteration, len=%d, row1=%d, row2=%d, m=%d, n=%d, vrow=%d\n", length, mmt1.row(), mmt2.row(), m, n, mmv.row()); std::cout << " Incompatible sizes for triangular sylvester iteration, len=" << length
<< ", row1=" << mmt1.row() << ", row2=" << mmt2.row()
<< ", m=" << m << ", n=" << n
<< ", vrow=" << mmv.row()
<< std::endl;
return false; return false;
} }
@ -511,7 +524,7 @@ TestRunnable::iter_sylv(const char *m1name, const char *m2name, const char *vnam
ConstKronVector v(vraw, m, n, depth); ConstKronVector v(vraw, m, n, depth);
KronVector d(v); // copy of v KronVector d(v); // copy of v
SylvParams pars; SylvParams pars;
pars.method = SylvParams::iter; pars.method = SylvParams::solve_method::iter;
is.solve(pars, d); is.solve(pars, d);
pars.print("\t"); pars.print("\t");
KronVector dcheck((const KronVector &)d); KronVector dcheck((const KronVector &)d);
@ -520,10 +533,10 @@ TestRunnable::iter_sylv(const char *m1name, const char *m2name, const char *vnam
dcheck.add(-1.0, v); dcheck.add(-1.0, v);
double cnorm = dcheck.getNorm(); double cnorm = dcheck.getNorm();
double xnorm = v.getNorm(); double xnorm = v.getNorm();
printf("\trel. error norm = %8.4g\n", cnorm/xnorm); std::cout << "\trel. error norm = " << cnorm/xnorm << std::endl;
double max = dcheck.getMax(); double max = dcheck.getMax();
double xmax = v.getMax(); double xmax = v.getMax();
printf("\trel. error max = %8.4g\n", max/xmax); std::cout << "\trel. error max = " << max/xmax << std::endl;
memdriver.setStackMode(false); memdriver.setStackMode(false);
return (cnorm < xnorm*eps_norm); return (cnorm < xnorm*eps_norm);
} }
@ -1149,79 +1162,76 @@ BlockDiagBigTest::run() const
int int
main() main()
{ {
TestRunnable *all_tests[50]; std::vector<std::unique_ptr<TestRunnable>> all_tests;
// fill in vector of all tests // fill in vector of all tests
int num_tests = 0; all_tests.push_back(std::make_unique<PureTriangTest>());
all_tests[num_tests++] = new PureTriangTest(); all_tests.push_back(std::make_unique<PureTriangTransTest>());
all_tests[num_tests++] = new PureTriangTransTest(); all_tests.push_back(std::make_unique<PureTrLargeTest>());
all_tests[num_tests++] = new PureTrLargeTest(); all_tests.push_back(std::make_unique<PureTrLargeTransTest>());
all_tests[num_tests++] = new PureTrLargeTransTest(); all_tests.push_back(std::make_unique<QuasiTriangTest>());
all_tests[num_tests++] = new QuasiTriangTest(); all_tests.push_back(std::make_unique<QuasiTriangTransTest>());
all_tests[num_tests++] = new QuasiTriangTransTest(); all_tests.push_back(std::make_unique<QuasiTrLargeTest>());
all_tests[num_tests++] = new QuasiTrLargeTest(); all_tests.push_back(std::make_unique<QuasiTrLargeTransTest>());
all_tests[num_tests++] = new QuasiTrLargeTransTest(); all_tests.push_back(std::make_unique<QuasiZeroSmallTest>());
all_tests[num_tests++] = new QuasiZeroSmallTest(); all_tests.push_back(std::make_unique<MultKronSmallTest>());
all_tests[num_tests++] = new MultKronSmallTest(); all_tests.push_back(std::make_unique<MultKronTest>());
all_tests[num_tests++] = new MultKronTest(); all_tests.push_back(std::make_unique<MultKronSmallTransTest>());
all_tests[num_tests++] = new MultKronSmallTransTest(); all_tests.push_back(std::make_unique<MultKronTransTest>());
all_tests[num_tests++] = new MultKronTransTest(); all_tests.push_back(std::make_unique<LevelKronTest>());
all_tests[num_tests++] = new LevelKronTest(); all_tests.push_back(std::make_unique<LevelKronTransTest>());
all_tests[num_tests++] = new LevelKronTransTest(); all_tests.push_back(std::make_unique<LevelZeroKronTest>());
all_tests[num_tests++] = new LevelZeroKronTest(); all_tests.push_back(std::make_unique<LevelZeroKronTransTest>());
all_tests[num_tests++] = new LevelZeroKronTransTest(); all_tests.push_back(std::make_unique<KronPowerTest>());
all_tests[num_tests++] = new KronPowerTest(); all_tests.push_back(std::make_unique<SmallLinEvalTest>());
all_tests[num_tests++] = new SmallLinEvalTest(); all_tests.push_back(std::make_unique<LinEvalTest>());
all_tests[num_tests++] = new LinEvalTest(); all_tests.push_back(std::make_unique<SmallQuaEvalTest>());
all_tests[num_tests++] = new SmallQuaEvalTest(); all_tests.push_back(std::make_unique<QuaEvalTest>());
all_tests[num_tests++] = new QuaEvalTest(); all_tests.push_back(std::make_unique<EigBubFrankTest>());
all_tests[num_tests++] = new EigBubFrankTest(); all_tests.push_back(std::make_unique<EigBubSplitTest>());
all_tests[num_tests++] = new EigBubSplitTest(); all_tests.push_back(std::make_unique<EigBubSameTest>());
all_tests[num_tests++] = new EigBubSameTest(); all_tests.push_back(std::make_unique<BlockDiagSmallTest>());
all_tests[num_tests++] = new BlockDiagSmallTest(); all_tests.push_back(std::make_unique<BlockDiagFrankTest>());
all_tests[num_tests++] = new BlockDiagFrankTest(); all_tests.push_back(std::make_unique<BlockDiagIllCondTest>());
all_tests[num_tests++] = new BlockDiagIllCondTest(); all_tests.push_back(std::make_unique<BlockDiagBigTest>());
all_tests[num_tests++] = new BlockDiagBigTest(); all_tests.push_back(std::make_unique<TriSylvSmallRealTest>());
all_tests[num_tests++] = new TriSylvSmallRealTest(); all_tests.push_back(std::make_unique<TriSylvSmallComplexTest>());
all_tests[num_tests++] = new TriSylvSmallComplexTest(); all_tests.push_back(std::make_unique<TriSylvTest>());
all_tests[num_tests++] = new TriSylvTest(); all_tests.push_back(std::make_unique<TriSylvBigTest>());
all_tests[num_tests++] = new TriSylvBigTest(); all_tests.push_back(std::make_unique<TriSylvLargeTest>());
all_tests[num_tests++] = new TriSylvLargeTest(); all_tests.push_back(std::make_unique<IterSylvTest>());
all_tests[num_tests++] = new IterSylvTest(); all_tests.push_back(std::make_unique<IterSylvLargeTest>());
all_tests[num_tests++] = new IterSylvLargeTest(); all_tests.push_back(std::make_unique<GenSylvSmallTest>());
all_tests[num_tests++] = new GenSylvSmallTest(); all_tests.push_back(std::make_unique<GenSylvTest>());
all_tests[num_tests++] = new GenSylvTest(); all_tests.push_back(std::make_unique<GenSylvSingTest>());
all_tests[num_tests++] = new GenSylvSingTest(); all_tests.push_back(std::make_unique<GenSylvLargeTest>());
all_tests[num_tests++] = new GenSylvLargeTest();
// launch the tests // launch the tests
std::cout << std::setprecision(4);
int success = 0; int success = 0;
for (int i = 0; i < num_tests; i++) for (const auto &test : all_tests)
{ {
try try
{ {
if (all_tests[i]->test()) if (test->test())
success++; success++;
} }
catch (const MMException &e) catch (const MMException &e)
{ {
printf("Caugth MM exception in <%s>:\n%s", all_tests[i]->getName(), std::cout << "Caught MM exception in <" << test->name << ">:\n" << e.getMessage();
e.getMessage());
} }
catch (SylvException &e) catch (SylvException &e)
{ {
printf("Caught Sylv exception in %s:\n", all_tests[i]->getName()); std::cout << "Caught Sylv exception in " << test->name << ":\n";
e.printMessage(); e.printMessage();
} }
} }
printf("There were %d tests that failed out of %d tests run.\n", int nfailed = all_tests.size() - success;
num_tests - success, num_tests); std::cout << "There were " << nfailed << " tests that failed out of "
<< all_tests.size() << " tests run." << std::endl;
// destroy if (nfailed)
for (int i = 0; i < num_tests; i++) return EXIT_FAILURE;
{ else
delete all_tests[i]; return EXIT_SUCCESS;
}
return 0;
} }