Dynare++ tensor library: modernization of Kronecker product classes
parent
6e747b5dba
commit
d08ca8ca7f
|
@ -3,7 +3,7 @@
|
||||||
#include "kron_prod.hh"
|
#include "kron_prod.hh"
|
||||||
#include "tl_exception.hh"
|
#include "tl_exception.hh"
|
||||||
|
|
||||||
#include <cstdio>
|
#include <tuple>
|
||||||
|
|
||||||
/* Here we construct Kronecker product dimensions from Kronecker
|
/* Here we construct Kronecker product dimensions from Kronecker
|
||||||
product dimensions by picking a given matrix and all other set to
|
product dimensions by picking a given matrix and all other set to
|
||||||
|
@ -20,8 +20,8 @@
|
||||||
Then we fork according to |i|. */
|
Then we fork according to |i|. */
|
||||||
|
|
||||||
KronProdDimens::KronProdDimens(const KronProdDimens &kd, int i)
|
KronProdDimens::KronProdDimens(const KronProdDimens &kd, int i)
|
||||||
: rows((i == 0 || i == kd.dimen()-1) ? (2) : (3)),
|
: rows((i == 0 || i == kd.dimen()-1) ? 2 : 3),
|
||||||
cols((i == 0 || i == kd.dimen()-1) ? (2) : (3))
|
cols((i == 0 || i == kd.dimen()-1) ? 2 : 3)
|
||||||
{
|
{
|
||||||
TL_RAISE_IF(i < 0 || i >= kd.dimen(),
|
TL_RAISE_IF(i < 0 || i >= kd.dimen(),
|
||||||
"Wrong index for pickup in KronProdDimens constructor");
|
"Wrong index for pickup in KronProdDimens constructor");
|
||||||
|
@ -72,9 +72,8 @@ KronProdDimens::KronProdDimens(const KronProdDimens &kd, int i)
|
||||||
void
|
void
|
||||||
KronProd::checkDimForMult(const ConstTwoDMatrix &in, const TwoDMatrix &out) const
|
KronProd::checkDimForMult(const ConstTwoDMatrix &in, const TwoDMatrix &out) const
|
||||||
{
|
{
|
||||||
int my_rows;
|
int my_rows, my_cols;
|
||||||
int my_cols;
|
std::tie(my_rows, my_cols) = kpd.getRC();
|
||||||
kpd.getRC(my_rows, my_cols);
|
|
||||||
TL_RAISE_IF(in.nrows() != out.nrows() || in.ncols() != my_rows,
|
TL_RAISE_IF(in.nrows() != out.nrows() || in.ncols() != my_rows,
|
||||||
"Wrong dimensions for KronProd in KronProd::checkDimForMult");
|
"Wrong dimensions for KronProd in KronProd::checkDimForMult");
|
||||||
}
|
}
|
||||||
|
@ -137,8 +136,8 @@ KronProdIA::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
|
||||||
|
|
||||||
for (int i = 0; i < id_cols; i++)
|
for (int i = 0; i < id_cols; i++)
|
||||||
{
|
{
|
||||||
TwoDMatrix outi(out, i *a.ncols(), a.ncols());
|
TwoDMatrix outi(out, i * a.ncols(), a.ncols());
|
||||||
ConstTwoDMatrix ini(in, i *a.nrows(), a.nrows());
|
ConstTwoDMatrix ini(in, i * a.nrows(), a.nrows());
|
||||||
outi.mult(ini, a);
|
outi.mult(ini, a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -221,9 +220,8 @@ KronProdIAI::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
|
||||||
int id_cols = kpd.cols[0];
|
int id_cols = kpd.cols[0];
|
||||||
|
|
||||||
KronProdAI akronid(*this);
|
KronProdAI akronid(*this);
|
||||||
int in_bl_width;
|
int in_bl_width, out_bl_width;
|
||||||
int out_bl_width;
|
std::tie(in_bl_width, out_bl_width) = akronid.kpd.getRC();
|
||||||
akronid.kpd.getRC(in_bl_width, out_bl_width);
|
|
||||||
|
|
||||||
for (int i = 0; i < id_cols; i++)
|
for (int i = 0; i < id_cols; i++)
|
||||||
{
|
{
|
||||||
|
@ -279,7 +277,7 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
|
||||||
}
|
}
|
||||||
|
|
||||||
int c;
|
int c;
|
||||||
TwoDMatrix *last = nullptr;
|
std::unique_ptr<TwoDMatrix> last;
|
||||||
|
|
||||||
// perform first multiplication AI
|
// perform first multiplication AI
|
||||||
/* Here we have to construct $A_1\otimes I$, allocate intermediate
|
/* Here we have to construct $A_1\otimes I$, allocate intermediate
|
||||||
|
@ -288,30 +286,25 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
|
||||||
{
|
{
|
||||||
KronProdAI akronid(*this);
|
KronProdAI akronid(*this);
|
||||||
c = akronid.kpd.ncols();
|
c = akronid.kpd.ncols();
|
||||||
last = new TwoDMatrix(in.nrows(), c);
|
last = std::make_unique<TwoDMatrix>(in.nrows(), c);
|
||||||
akronid.mult(in, *last);
|
akronid.mult(in, *last);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
last = std::make_unique<TwoDMatrix>(in.nrows(), in.ncols(), Vector{in.getData()});
|
||||||
last = new TwoDMatrix(in.nrows(), in.ncols(), Vector{in.getData()});
|
|
||||||
}
|
|
||||||
|
|
||||||
// perform intermediate multiplications IAI
|
// perform intermediate multiplications IAI
|
||||||
/* Here we go through all $I\otimes A_i\otimes I$, construct the
|
/* Here we go through all $I\otimes A_i\otimes I$, construct the
|
||||||
product, allocate new storage for result |newlast|, perform the
|
product, allocate new storage for result |newlast|, perform the
|
||||||
multiplication, deallocate old |last|, and set |last| to |newlast|. */
|
multiplication, deallocate old |last|, and set |last| to |newlast|. */
|
||||||
for (int i = 1; i < dimen()-1; i++)
|
for (int i = 1; i < dimen()-1; i++)
|
||||||
{
|
if (matlist[i])
|
||||||
if (matlist[i])
|
{
|
||||||
{
|
KronProdIAI interkron(*this, i);
|
||||||
KronProdIAI interkron(*this, i);
|
c = interkron.kpd.ncols();
|
||||||
c = interkron.kpd.ncols();
|
auto newlast = std::make_unique<TwoDMatrix>(in.nrows(), c);
|
||||||
auto *newlast = new TwoDMatrix(in.nrows(), c);
|
interkron.mult(*last, *newlast);
|
||||||
interkron.mult(*last, *newlast);
|
last = std::move(newlast);
|
||||||
delete last;
|
}
|
||||||
last = newlast;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// perform last multiplication IA
|
// perform last multiplication IA
|
||||||
/* Here just construct $I\otimes A_n$ and perform multiplication and
|
/* Here just construct $I\otimes A_n$ and perform multiplication and
|
||||||
|
@ -322,25 +315,22 @@ KronProdAll::mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const
|
||||||
idkrona.mult(*last, out);
|
idkrona.mult(*last, out);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
out = *last;
|
||||||
out = *last;
|
|
||||||
}
|
|
||||||
delete last;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* This calculates a Kornecker product of rows of matrices, the row
|
/* This calculates a Kornecker product of rows of matrices, the row
|
||||||
indices are given by the integer sequence. The result is allocated and
|
indices are given by the integer sequence. The result is allocated and
|
||||||
returned. The caller is repsonsible for its deallocation. */
|
returned. */
|
||||||
|
|
||||||
Vector *
|
std::unique_ptr<Vector>
|
||||||
KronProdAll::multRows(const IntSequence &irows) const
|
KronProdAll::multRows(const IntSequence &irows) const
|
||||||
{
|
{
|
||||||
TL_RAISE_IF(irows.size() != dimen(),
|
TL_RAISE_IF(irows.size() != dimen(),
|
||||||
"Wrong length of row indices in KronProdAll::multRows");
|
"Wrong length of row indices in KronProdAll::multRows");
|
||||||
|
|
||||||
Vector *last = nullptr;
|
std::unique_ptr<Vector> last;
|
||||||
ConstVector *row;
|
std::unique_ptr<ConstVector> row;
|
||||||
std::vector<Vector *> to_delete;
|
std::vector<std::unique_ptr<Vector>> to_delete;
|
||||||
for (int i = 0; i < dimen(); i++)
|
for (int i = 0; i < dimen(); i++)
|
||||||
{
|
{
|
||||||
int j = dimen()-1-i;
|
int j = dimen()-1-i;
|
||||||
|
@ -352,14 +342,14 @@ KronProdAll::multRows(const IntSequence &irows) const
|
||||||
the |row| as ConstVector of this vector, which sheduled for
|
the |row| as ConstVector of this vector, which sheduled for
|
||||||
deallocation. */
|
deallocation. */
|
||||||
if (matlist[j])
|
if (matlist[j])
|
||||||
row = new ConstVector(matlist[j]->getRow(irows[j]));
|
row = std::make_unique<ConstVector>(matlist[j]->getRow(irows[j]));
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
auto *aux = new Vector(ncols(j));
|
auto aux = std::make_unique<Vector>(ncols(j));
|
||||||
aux->zeros();
|
aux->zeros();
|
||||||
(*aux)[irows[j]] = 1.0;
|
(*aux)[irows[j]] = 1.0;
|
||||||
to_delete.push_back(aux);
|
row = std::make_unique<ConstVector>(*aux);
|
||||||
row = new ConstVector(*aux);
|
to_delete.emplace_back(std::move(aux));
|
||||||
}
|
}
|
||||||
|
|
||||||
// set |last| to product of |row| and |last|
|
// set |last| to product of |row| and |last|
|
||||||
|
@ -368,24 +358,15 @@ KronProdAll::multRows(const IntSequence &irows) const
|
||||||
then we only make |last| equal to |row|. */
|
then we only make |last| equal to |row|. */
|
||||||
if (last)
|
if (last)
|
||||||
{
|
{
|
||||||
Vector *newlast;
|
auto newlast = std::make_unique<Vector>(last->length()*row->length());
|
||||||
newlast = new Vector(last->length()*row->length());
|
|
||||||
kronMult(*row, ConstVector(*last), *newlast);
|
kronMult(*row, ConstVector(*last), *newlast);
|
||||||
delete last;
|
last = std::move(newlast);
|
||||||
last = newlast;
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
last = std::make_unique<Vector>(*row);
|
||||||
last = new Vector(*row);
|
|
||||||
}
|
|
||||||
|
|
||||||
delete row;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto & i : to_delete)
|
return std::move(last);
|
||||||
delete i;
|
|
||||||
|
|
||||||
return last;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* This permutes the matrices so that the new ordering would minimize
|
/* This permutes the matrices so that the new ordering would minimize
|
||||||
|
@ -402,7 +383,8 @@ KronProdAllOptim::optimizeOrder()
|
||||||
int swaps = 0;
|
int swaps = 0;
|
||||||
for (int j = 0; j < dimen()-1; j++)
|
for (int j = 0; j < dimen()-1; j++)
|
||||||
{
|
{
|
||||||
if (((double) kpd.rows[j])/kpd.cols[j] < ((double) kpd.rows[j+1])/kpd.cols[j+1])
|
if (static_cast<double>(kpd.rows[j])/kpd.cols[j]
|
||||||
|
< static_cast<double>(kpd.rows[j+1])/kpd.cols[j+1])
|
||||||
{
|
{
|
||||||
// swap dimensions and matrices at |j| and |j+1|
|
// swap dimensions and matrices at |j| and |j+1|
|
||||||
int s = kpd.rows[j+1];
|
int s = kpd.rows[j+1];
|
||||||
|
@ -423,8 +405,6 @@ KronProdAllOptim::optimizeOrder()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (swaps == 0)
|
if (swaps == 0)
|
||||||
{
|
return;
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,10 @@
|
||||||
#ifndef KRON_PROD_H
|
#ifndef KRON_PROD_H
|
||||||
#define KRON_PROD_H
|
#define KRON_PROD_H
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "twod_matrix.hh"
|
#include "twod_matrix.hh"
|
||||||
#include "permutation.hh"
|
#include "permutation.hh"
|
||||||
#include "int_sequence.hh"
|
#include "int_sequence.hh"
|
||||||
|
@ -65,14 +69,12 @@ public:
|
||||||
: rows(dim, 0), cols(dim, 0)
|
: rows(dim, 0), cols(dim, 0)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
KronProdDimens(const KronProdDimens &kd)
|
KronProdDimens(const KronProdDimens &kd) = default;
|
||||||
|
KronProdDimens(KronProdDimens &&kd) = default;
|
||||||
= default;
|
|
||||||
KronProdDimens(const KronProdDimens &kd, int i);
|
KronProdDimens(const KronProdDimens &kd, int i);
|
||||||
|
|
||||||
KronProdDimens &
|
KronProdDimens &operator=(const KronProdDimens &kd) = default;
|
||||||
operator=(const KronProdDimens &kd)
|
KronProdDimens &operator=(KronProdDimens &&kd) = default;
|
||||||
= default;
|
|
||||||
bool
|
bool
|
||||||
operator==(const KronProdDimens &kd) const
|
operator==(const KronProdDimens &kd) const
|
||||||
{
|
{
|
||||||
|
@ -87,17 +89,18 @@ public:
|
||||||
void
|
void
|
||||||
setRC(int i, int r, int c)
|
setRC(int i, int r, int c)
|
||||||
{
|
{
|
||||||
rows[i] = r; cols[i] = c;
|
rows[i] = r;
|
||||||
|
cols[i] = c;
|
||||||
}
|
}
|
||||||
void
|
std::pair<int, int>
|
||||||
getRC(int i, int &r, int &c) const
|
getRC(int i) const
|
||||||
{
|
{
|
||||||
r = rows[i]; c = cols[i];
|
return { rows[i], cols[i] };
|
||||||
}
|
}
|
||||||
void
|
std::pair<int, int>
|
||||||
getRC(int &r, int &c) const
|
getRC() const
|
||||||
{
|
{
|
||||||
r = rows.mult(); c = cols.mult();
|
return { rows.mult(), cols.mult() };
|
||||||
}
|
}
|
||||||
int
|
int
|
||||||
nrows() const
|
nrows() const
|
||||||
|
@ -145,11 +148,9 @@ public:
|
||||||
: kpd(kd)
|
: kpd(kd)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
KronProd(const KronProd &kp)
|
KronProd(const KronProd &kp) = default;
|
||||||
|
KronProd(KronProd &&kp) = default;
|
||||||
= default;
|
virtual ~KronProd() = default;
|
||||||
virtual ~KronProd()
|
|
||||||
= default;
|
|
||||||
|
|
||||||
int
|
int
|
||||||
dimen() const
|
dimen() const
|
||||||
|
@ -218,16 +219,13 @@ class KronProdAll : public KronProd
|
||||||
friend class KronProdIAI;
|
friend class KronProdIAI;
|
||||||
friend class KronProdAI;
|
friend class KronProdAI;
|
||||||
protected:
|
protected:
|
||||||
const TwoDMatrix **const matlist;
|
std::vector<const TwoDMatrix *> matlist;
|
||||||
public:
|
public:
|
||||||
KronProdAll(int dim)
|
KronProdAll(int dim)
|
||||||
: KronProd(dim), matlist(new const TwoDMatrix *[dim])
|
: KronProd(dim), matlist(dim)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
~KronProdAll() override
|
~KronProdAll() override = default;
|
||||||
{
|
|
||||||
delete [] matlist;
|
|
||||||
}
|
|
||||||
void setMat(int i, const TwoDMatrix &m);
|
void setMat(int i, const TwoDMatrix &m);
|
||||||
void setUnit(int i, int n);
|
void setUnit(int i, int n);
|
||||||
const TwoDMatrix &
|
const TwoDMatrix &
|
||||||
|
@ -237,7 +235,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
void mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const override;
|
void mult(const ConstTwoDMatrix &in, TwoDMatrix &out) const override;
|
||||||
Vector *multRows(const IntSequence &irows) const;
|
std::unique_ptr<Vector> multRows(const IntSequence &irows) const;
|
||||||
private:
|
private:
|
||||||
bool isUnit() const;
|
bool isUnit() const;
|
||||||
};
|
};
|
||||||
|
|
|
@ -386,14 +386,13 @@ FPSTensor::FPSTensor(const TensorDimens &td, const Equivalence &e, const Permuta
|
||||||
auto sl = a.getMap().lower_bound(c);
|
auto sl = a.getMap().lower_bound(c);
|
||||||
if (sl != a.getMap().end())
|
if (sl != a.getMap().end())
|
||||||
{
|
{
|
||||||
Vector *row_prod = kp.multRows(run.getCoor());
|
auto row_prod = kp.multRows(run.getCoor());
|
||||||
auto su = a.getMap().upper_bound(c);
|
auto su = a.getMap().upper_bound(c);
|
||||||
for (auto srun = sl; srun != su; ++srun)
|
for (auto srun = sl; srun != su; ++srun)
|
||||||
{
|
{
|
||||||
Vector out_row{getRow((*srun).second.first)};
|
Vector out_row{getRow((*srun).second.first)};
|
||||||
out_row.add((*srun).second.second, *row_prod);
|
out_row.add((*srun).second.second, *row_prod);
|
||||||
}
|
}
|
||||||
delete row_prod;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue