Dynare++ tensor library: modernization of Kronecker product classes

time-shift
Sébastien Villemot 2019-02-19 12:34:43 +01:00
parent 6e747b5dba
commit d08ca8ca7f
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
3 changed files with 61 additions and 84 deletions

View File

@ -3,7 +3,7 @@
#include "kron_prod.hh"
#include "tl_exception.hh"
#include <cstdio>
#include <tuple>
/* Here we construct Kronecker product dimensions from Kronecker
product dimensions by picking a given matrix and all other set to
@ -20,8 +20,8 @@
Then we fork according to |i|. */
KronProdDimens::KronProdDimens(const KronProdDimens &kd, int i)
: rows((i == 0 || i == kd.dimen()-1) ? (2) : (3)),
cols((i == 0 || i == kd.dimen()-1) ? (2) : (3))
: rows((i == 0 || i == kd.dimen()-1) ? 2 : 3),
cols((i == 0 || i == kd.dimen()-1) ? 2 : 3)
{
TL_RAISE_IF(i < 0 || i >= kd.dimen(),
"Wrong index for pickup in KronProdDimens constructor");
@ -72,9 +72,8 @@ KronProdDimens::KronProdDimens(const KronProdDimens &kd, int i)
void
KronProd::checkDimForMult(const ConstTwoDMatrix &in, const TwoDMatrix &out) const
{
int my_rows;
int my_cols;
kpd.getRC(my_rows, my_cols);
int my_rows, my_cols;
std::tie(my_rows, my_cols) = kpd.getRC();
TL_RAISE_IF(in.nrows() != out.nrows() || in.ncols() != my_rows,
"Wrong dimensions for KronProd in KronProd::checkDimForMult");
}
@ -137,8 +136,8 @@ KronProdIA::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
for (int i = 0; i < id_cols; i++)
{
TwoDMatrix outi(out, i *a.ncols(), a.ncols());
ConstTwoDMatrix ini(in, i *a.nrows(), a.nrows());
TwoDMatrix outi(out, i * a.ncols(), a.ncols());
ConstTwoDMatrix ini(in, i * a.nrows(), a.nrows());
outi.mult(ini, a);
}
}
@ -221,9 +220,8 @@ KronProdIAI::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
int id_cols = kpd.cols[0];
KronProdAI akronid(*this);
int in_bl_width;
int out_bl_width;
akronid.kpd.getRC(in_bl_width, out_bl_width);
int in_bl_width, out_bl_width;
std::tie(in_bl_width, out_bl_width) = akronid.kpd.getRC();
for (int i = 0; i < id_cols; i++)
{
@ -279,7 +277,7 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
}
int c;
TwoDMatrix *last = nullptr;
std::unique_ptr<TwoDMatrix> last;
// perform first multiplication AI
/* Here we have to construct $A_1\otimes I$, allocate intermediate
@ -288,30 +286,25 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
{
KronProdAI akronid(*this);
c = akronid.kpd.ncols();
last = new TwoDMatrix(in.nrows(), c);
last = std::make_unique<TwoDMatrix>(in.nrows(), c);
akronid.mult(in, *last);
}
else
{
last = new TwoDMatrix(in.nrows(), in.ncols(), Vector{in.getData()});
}
last = std::make_unique<TwoDMatrix>(in.nrows(), in.ncols(), Vector{in.getData()});
// perform intermediate multiplications IAI
/* Here we go through all $I\otimes A_i\otimes I$, construct the
product, allocate new storage for result |newlast|, perform the
multiplication, deallocate old |last|, and set |last| to |newlast|. */
for (int i = 1; i < dimen()-1; i++)
{
if (matlist[i])
{
KronProdIAI interkron(*this, i);
c = interkron.kpd.ncols();
auto *newlast = new TwoDMatrix(in.nrows(), c);
interkron.mult(*last, *newlast);
delete last;
last = newlast;
}
}
if (matlist[i])
{
KronProdIAI interkron(*this, i);
c = interkron.kpd.ncols();
auto newlast = std::make_unique<TwoDMatrix>(in.nrows(), c);
interkron.mult(*last, *newlast);
last = std::move(newlast);
}
// perform last multiplication IA
/* Here just construct $I\otimes A_n$ and perform multiplication and
@ -322,25 +315,22 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
idkrona.mult(*last, out);
}
else
{
out = *last;
}
delete last;
out = *last;
}
/* This calculates a Kornecker product of rows of matrices, the row
indices are given by the integer sequence. The result is allocated and
returned. The caller is repsonsible for its deallocation. */
returned. */
Vector *
std::unique_ptr<Vector>
KronProdAll::multRows(const IntSequence &irows) const
{
TL_RAISE_IF(irows.size() != dimen(),
"Wrong length of row indices in KronProdAll::multRows");
Vector *last = nullptr;
ConstVector *row;
std::vector<Vector *> to_delete;
std::unique_ptr<Vector> last;
std::unique_ptr<ConstVector> row;
std::vector<std::unique_ptr<Vector>> to_delete;
for (int i = 0; i < dimen(); i++)
{
int j = dimen()-1-i;
@ -352,14 +342,14 @@ KronProdAll::multRows(const IntSequence &irows) const
the |row| as ConstVector of this vector, which sheduled for
deallocation. */
if (matlist[j])
row = new ConstVector(matlist[j]->getRow(irows[j]));
row = std::make_unique<ConstVector>(matlist[j]->getRow(irows[j]));
else
{
auto *aux = new Vector(ncols(j));
auto aux = std::make_unique<Vector>(ncols(j));
aux->zeros();
(*aux)[irows[j]] = 1.0;
to_delete.push_back(aux);
row = new ConstVector(*aux);
row = std::make_unique<ConstVector>(*aux);
to_delete.emplace_back(std::move(aux));
}
// set |last| to product of |row| and |last|
@ -368,24 +358,15 @@ KronProdAll::multRows(const IntSequence &irows) const
then we only make |last| equal to |row|. */
if (last)
{
Vector *newlast;
newlast = new Vector(last->length()*row->length());
auto newlast = std::make_unique<Vector>(last->length()*row->length());
kronMult(*row, ConstVector(*last), *newlast);
delete last;
last = newlast;
last = std::move(newlast);
}
else
{
last = new Vector(*row);
}
delete row;
last = std::make_unique<Vector>(*row);
}
for (auto & i : to_delete)
delete i;
return last;
return std::move(last);
}
/* This permutes the matrices so that the new ordering would minimize
@ -402,7 +383,8 @@ KronProdAllOptim::optimizeOrder()
int swaps = 0;
for (int j = 0; j < dimen()-1; j++)
{
if (((double) kpd.rows[j])/kpd.cols[j] < ((double) kpd.rows[j+1])/kpd.cols[j+1])
if (static_cast<double>(kpd.rows[j])/kpd.cols[j]
< static_cast<double>(kpd.rows[j+1])/kpd.cols[j+1])
{
// swap dimensions and matrices at |j| and |j+1|
int s = kpd.rows[j+1];
@ -423,8 +405,6 @@ KronProdAllOptim::optimizeOrder()
}
}
if (swaps == 0)
{
return;
}
return;
}
}

View File

@ -26,6 +26,10 @@
#ifndef KRON_PROD_H
#define KRON_PROD_H
#include <utility>
#include <vector>
#include <memory>
#include "twod_matrix.hh"
#include "permutation.hh"
#include "int_sequence.hh"
@ -65,14 +69,12 @@ public:
: rows(dim, 0), cols(dim, 0)
{
}
KronProdDimens(const KronProdDimens &kd)
= default;
KronProdDimens(const KronProdDimens &kd) = default;
KronProdDimens(KronProdDimens &&kd) = default;
KronProdDimens(const KronProdDimens &kd, int i);
KronProdDimens &
operator=(const KronProdDimens &kd)
= default;
KronProdDimens &operator=(const KronProdDimens &kd) = default;
KronProdDimens &operator=(KronProdDimens &&kd) = default;
bool
operator==(const KronProdDimens &kd) const
{
@ -87,17 +89,18 @@ public:
void
setRC(int i, int r, int c)
{
rows[i] = r; cols[i] = c;
rows[i] = r;
cols[i] = c;
}
void
getRC(int i, int &r, int &c) const
std::pair<int, int>
getRC(int i) const
{
r = rows[i]; c = cols[i];
return { rows[i], cols[i] };
}
void
getRC(int &r, int &c) const
std::pair<int, int>
getRC() const
{
r = rows.mult(); c = cols.mult();
return { rows.mult(), cols.mult() };
}
int
nrows() const
@ -145,11 +148,9 @@ public:
: kpd(kd)
{
}
KronProd(const KronProd &kp)
= default;
virtual ~KronProd()
= default;
KronProd(const KronProd &kp) = default;
KronProd(KronProd &&kp) = default;
virtual ~KronProd() = default;
int
dimen() const
@ -218,16 +219,13 @@ class KronProdAll : public KronProd
friend class KronProdIAI;
friend class KronProdAI;
protected:
const TwoDMatrix **const matlist;
std::vector<const TwoDMatrix *> matlist;
public:
KronProdAll(int dim)
: KronProd(dim), matlist(new const TwoDMatrix *[dim])
: KronProd(dim), matlist(dim)
{
}
~KronProdAll() override
{
delete [] matlist;
}
~KronProdAll() override = default;
void setMat(int i, const TwoDMatrix &m);
void setUnit(int i, int n);
const TwoDMatrix &
@ -237,7 +235,7 @@ public:
}
void mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const override;
Vector *multRows(const IntSequence &irows) const;
std::unique_ptr<Vector> multRows(const IntSequence &irows) const;
private:
bool isUnit() const;
};

View File

@ -386,14 +386,13 @@ FPSTensor::FPSTensor(const TensorDimens &td, const Equivalence &e, const Permuta
auto sl = a.getMap().lower_bound(c);
if (sl != a.getMap().end())
{
Vector *row_prod = kp.multRows(run.getCoor());
auto row_prod = kp.multRows(run.getCoor());
auto su = a.getMap().upper_bound(c);
for (auto srun = sl; srun != su; ++srun)
{
Vector out_row{getRow((*srun).second.first)};
out_row.add((*srun).second.second, *row_prod);
}
delete row_prod;
}
}
}