Dynare++ tensor library: modernization of Kronecker product classes

time-shift
Sébastien Villemot 2019-02-19 12:34:43 +01:00
parent 6e747b5dba
commit d08ca8ca7f
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
3 changed files with 61 additions and 84 deletions

View File

@ -3,7 +3,7 @@
#include "kron_prod.hh" #include "kron_prod.hh"
#include "tl_exception.hh" #include "tl_exception.hh"
#include <cstdio> #include <tuple>
/* Here we construct Kronecker product dimensions from Kronecker /* Here we construct Kronecker product dimensions from Kronecker
product dimensions by picking a given matrix and all other set to product dimensions by picking a given matrix and all other set to
@ -20,8 +20,8 @@
Then we fork according to |i|. */ Then we fork according to |i|. */
KronProdDimens::KronProdDimens(const KronProdDimens &kd, int i) KronProdDimens::KronProdDimens(const KronProdDimens &kd, int i)
: rows((i == 0 || i == kd.dimen()-1) ? (2) : (3)), : rows((i == 0 || i == kd.dimen()-1) ? 2 : 3),
cols((i == 0 || i == kd.dimen()-1) ? (2) : (3)) cols((i == 0 || i == kd.dimen()-1) ? 2 : 3)
{ {
TL_RAISE_IF(i < 0 || i >= kd.dimen(), TL_RAISE_IF(i < 0 || i >= kd.dimen(),
"Wrong index for pickup in KronProdDimens constructor"); "Wrong index for pickup in KronProdDimens constructor");
@ -72,9 +72,8 @@ KronProdDimens::KronProdDimens(const KronProdDimens &kd, int i)
void void
KronProd::checkDimForMult(const ConstTwoDMatrix &in, const TwoDMatrix &out) const KronProd::checkDimForMult(const ConstTwoDMatrix &in, const TwoDMatrix &out) const
{ {
int my_rows; int my_rows, my_cols;
int my_cols; std::tie(my_rows, my_cols) = kpd.getRC();
kpd.getRC(my_rows, my_cols);
TL_RAISE_IF(in.nrows() != out.nrows() || in.ncols() != my_rows, TL_RAISE_IF(in.nrows() != out.nrows() || in.ncols() != my_rows,
"Wrong dimensions for KronProd in KronProd::checkDimForMult"); "Wrong dimensions for KronProd in KronProd::checkDimForMult");
} }
@ -137,8 +136,8 @@ KronProdIA::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
for (int i = 0; i < id_cols; i++) for (int i = 0; i < id_cols; i++)
{ {
TwoDMatrix outi(out, i *a.ncols(), a.ncols()); TwoDMatrix outi(out, i * a.ncols(), a.ncols());
ConstTwoDMatrix ini(in, i *a.nrows(), a.nrows()); ConstTwoDMatrix ini(in, i * a.nrows(), a.nrows());
outi.mult(ini, a); outi.mult(ini, a);
} }
} }
@ -221,9 +220,8 @@ KronProdIAI::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
int id_cols = kpd.cols[0]; int id_cols = kpd.cols[0];
KronProdAI akronid(*this); KronProdAI akronid(*this);
int in_bl_width; int in_bl_width, out_bl_width;
int out_bl_width; std::tie(in_bl_width, out_bl_width) = akronid.kpd.getRC();
akronid.kpd.getRC(in_bl_width, out_bl_width);
for (int i = 0; i < id_cols; i++) for (int i = 0; i < id_cols; i++)
{ {
@ -279,7 +277,7 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
} }
int c; int c;
TwoDMatrix *last = nullptr; std::unique_ptr<TwoDMatrix> last;
// perform first multiplication AI // perform first multiplication AI
/* Here we have to construct $A_1\otimes I$, allocate intermediate /* Here we have to construct $A_1\otimes I$, allocate intermediate
@ -288,30 +286,25 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
{ {
KronProdAI akronid(*this); KronProdAI akronid(*this);
c = akronid.kpd.ncols(); c = akronid.kpd.ncols();
last = new TwoDMatrix(in.nrows(), c); last = std::make_unique<TwoDMatrix>(in.nrows(), c);
akronid.mult(in, *last); akronid.mult(in, *last);
} }
else else
{ last = std::make_unique<TwoDMatrix>(in.nrows(), in.ncols(), Vector{in.getData()});
last = new TwoDMatrix(in.nrows(), in.ncols(), Vector{in.getData()});
}
// perform intermediate multiplications IAI // perform intermediate multiplications IAI
/* Here we go through all $I\otimes A_i\otimes I$, construct the /* Here we go through all $I\otimes A_i\otimes I$, construct the
product, allocate new storage for result |newlast|, perform the product, allocate new storage for result |newlast|, perform the
multiplication, deallocate old |last|, and set |last| to |newlast|. */ multiplication, deallocate old |last|, and set |last| to |newlast|. */
for (int i = 1; i < dimen()-1; i++) for (int i = 1; i < dimen()-1; i++)
{ if (matlist[i])
if (matlist[i]) {
{ KronProdIAI interkron(*this, i);
KronProdIAI interkron(*this, i); c = interkron.kpd.ncols();
c = interkron.kpd.ncols(); auto newlast = std::make_unique<TwoDMatrix>(in.nrows(), c);
auto *newlast = new TwoDMatrix(in.nrows(), c); interkron.mult(*last, *newlast);
interkron.mult(*last, *newlast); last = std::move(newlast);
delete last; }
last = newlast;
}
}
// perform last multiplication IA // perform last multiplication IA
/* Here just construct $I\otimes A_n$ and perform multiplication and /* Here just construct $I\otimes A_n$ and perform multiplication and
@ -322,25 +315,22 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
idkrona.mult(*last, out); idkrona.mult(*last, out);
} }
else else
{ out = *last;
out = *last;
}
delete last;
} }
/* This calculates a Kornecker product of rows of matrices, the row /* This calculates a Kornecker product of rows of matrices, the row
indices are given by the integer sequence. The result is allocated and indices are given by the integer sequence. The result is allocated and
returned. The caller is repsonsible for its deallocation. */ returned. */
Vector * std::unique_ptr<Vector>
KronProdAll::multRows(const IntSequence &irows) const KronProdAll::multRows(const IntSequence &irows) const
{ {
TL_RAISE_IF(irows.size() != dimen(), TL_RAISE_IF(irows.size() != dimen(),
"Wrong length of row indices in KronProdAll::multRows"); "Wrong length of row indices in KronProdAll::multRows");
Vector *last = nullptr; std::unique_ptr<Vector> last;
ConstVector *row; std::unique_ptr<ConstVector> row;
std::vector<Vector *> to_delete; std::vector<std::unique_ptr<Vector>> to_delete;
for (int i = 0; i < dimen(); i++) for (int i = 0; i < dimen(); i++)
{ {
int j = dimen()-1-i; int j = dimen()-1-i;
@ -352,14 +342,14 @@ KronProdAll::multRows(const IntSequence &irows) const
the |row| as ConstVector of this vector, which sheduled for the |row| as ConstVector of this vector, which sheduled for
deallocation. */ deallocation. */
if (matlist[j]) if (matlist[j])
row = new ConstVector(matlist[j]->getRow(irows[j])); row = std::make_unique<ConstVector>(matlist[j]->getRow(irows[j]));
else else
{ {
auto *aux = new Vector(ncols(j)); auto aux = std::make_unique<Vector>(ncols(j));
aux->zeros(); aux->zeros();
(*aux)[irows[j]] = 1.0; (*aux)[irows[j]] = 1.0;
to_delete.push_back(aux); row = std::make_unique<ConstVector>(*aux);
row = new ConstVector(*aux); to_delete.emplace_back(std::move(aux));
} }
// set |last| to product of |row| and |last| // set |last| to product of |row| and |last|
@ -368,24 +358,15 @@ KronProdAll::multRows(const IntSequence &irows) const
then we only make |last| equal to |row|. */ then we only make |last| equal to |row|. */
if (last) if (last)
{ {
Vector *newlast; auto newlast = std::make_unique<Vector>(last->length()*row->length());
newlast = new Vector(last->length()*row->length());
kronMult(*row, ConstVector(*last), *newlast); kronMult(*row, ConstVector(*last), *newlast);
delete last; last = std::move(newlast);
last = newlast;
} }
else else
{ last = std::make_unique<Vector>(*row);
last = new Vector(*row);
}
delete row;
} }
for (auto & i : to_delete) return std::move(last);
delete i;
return last;
} }
/* This permutes the matrices so that the new ordering would minimize /* This permutes the matrices so that the new ordering would minimize
@ -402,7 +383,8 @@ KronProdAllOptim::optimizeOrder()
int swaps = 0; int swaps = 0;
for (int j = 0; j < dimen()-1; j++) for (int j = 0; j < dimen()-1; j++)
{ {
if (((double) kpd.rows[j])/kpd.cols[j] < ((double) kpd.rows[j+1])/kpd.cols[j+1]) if (static_cast<double>(kpd.rows[j])/kpd.cols[j]
< static_cast<double>(kpd.rows[j+1])/kpd.cols[j+1])
{ {
// swap dimensions and matrices at |j| and |j+1| // swap dimensions and matrices at |j| and |j+1|
int s = kpd.rows[j+1]; int s = kpd.rows[j+1];
@ -423,8 +405,6 @@ KronProdAllOptim::optimizeOrder()
} }
} }
if (swaps == 0) if (swaps == 0)
{ return;
return;
}
} }
} }

View File

@ -26,6 +26,10 @@
#ifndef KRON_PROD_H #ifndef KRON_PROD_H
#define KRON_PROD_H #define KRON_PROD_H
#include <utility>
#include <vector>
#include <memory>
#include "twod_matrix.hh" #include "twod_matrix.hh"
#include "permutation.hh" #include "permutation.hh"
#include "int_sequence.hh" #include "int_sequence.hh"
@ -65,14 +69,12 @@ public:
: rows(dim, 0), cols(dim, 0) : rows(dim, 0), cols(dim, 0)
{ {
} }
KronProdDimens(const KronProdDimens &kd) KronProdDimens(const KronProdDimens &kd) = default;
KronProdDimens(KronProdDimens &&kd) = default;
= default;
KronProdDimens(const KronProdDimens &kd, int i); KronProdDimens(const KronProdDimens &kd, int i);
KronProdDimens & KronProdDimens &operator=(const KronProdDimens &kd) = default;
operator=(const KronProdDimens &kd) KronProdDimens &operator=(KronProdDimens &&kd) = default;
= default;
bool bool
operator==(const KronProdDimens &kd) const operator==(const KronProdDimens &kd) const
{ {
@ -87,17 +89,18 @@ public:
void void
setRC(int i, int r, int c) setRC(int i, int r, int c)
{ {
rows[i] = r; cols[i] = c; rows[i] = r;
cols[i] = c;
} }
void std::pair<int, int>
getRC(int i, int &r, int &c) const getRC(int i) const
{ {
r = rows[i]; c = cols[i]; return { rows[i], cols[i] };
} }
void std::pair<int, int>
getRC(int &r, int &c) const getRC() const
{ {
r = rows.mult(); c = cols.mult(); return { rows.mult(), cols.mult() };
} }
int int
nrows() const nrows() const
@ -145,11 +148,9 @@ public:
: kpd(kd) : kpd(kd)
{ {
} }
KronProd(const KronProd &kp) KronProd(const KronProd &kp) = default;
KronProd(KronProd &&kp) = default;
= default; virtual ~KronProd() = default;
virtual ~KronProd()
= default;
int int
dimen() const dimen() const
@ -218,16 +219,13 @@ class KronProdAll : public KronProd
friend class KronProdIAI; friend class KronProdIAI;
friend class KronProdAI; friend class KronProdAI;
protected: protected:
const TwoDMatrix **const matlist; std::vector<const TwoDMatrix *> matlist;
public: public:
KronProdAll(int dim) KronProdAll(int dim)
: KronProd(dim), matlist(new const TwoDMatrix *[dim]) : KronProd(dim), matlist(dim)
{ {
} }
~KronProdAll() override ~KronProdAll() override = default;
{
delete [] matlist;
}
void setMat(int i, const TwoDMatrix &m); void setMat(int i, const TwoDMatrix &m);
void setUnit(int i, int n); void setUnit(int i, int n);
const TwoDMatrix & const TwoDMatrix &
@ -237,7 +235,7 @@ public:
} }
void mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const override; void mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const override;
Vector *multRows(const IntSequence &irows) const; std::unique_ptr<Vector> multRows(const IntSequence &irows) const;
private: private:
bool isUnit() const; bool isUnit() const;
}; };

View File

@ -386,14 +386,13 @@ FPSTensor::FPSTensor(const TensorDimens &td, const Equivalence &e, const Permuta
auto sl = a.getMap().lower_bound(c); auto sl = a.getMap().lower_bound(c);
if (sl != a.getMap().end()) if (sl != a.getMap().end())
{ {
Vector *row_prod = kp.multRows(run.getCoor()); auto row_prod = kp.multRows(run.getCoor());
auto su = a.getMap().upper_bound(c); auto su = a.getMap().upper_bound(c);
for (auto srun = sl; srun != su; ++srun) for (auto srun = sl; srun != su; ++srun)
{ {
Vector out_row{getRow((*srun).second.first)}; Vector out_row{getRow((*srun).second.first)};
out_row.add((*srun).second.second, *row_prod); out_row.add((*srun).second.second, *row_prod);
} }
delete row_prod;
} }
} }
} }