Dynare++ tensor library: modernize normal moments computation

time-shift
Sébastien Villemot 2019-02-19 12:53:02 +01:00
parent d08ca8ca7f
commit 0f37649755
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
1 changed files with 22 additions and 26 deletions

View File

@ -1,5 +1,7 @@
// Copyright 2004, Ondra Kamenik // Copyright 2004, Ondra Kamenik
#include <memory>
#include "normal_moments.hh" #include "normal_moments.hh"
#include "permutation.hh" #include "permutation.hh"
#include "kron_prod.hh" #include "kron_prod.hh"
@ -30,16 +32,15 @@ UNormalMoments::generateMoments(int maxdim, const TwoDMatrix &v)
auto *mom2 = new URSingleTensor(nv, 2); auto *mom2 = new URSingleTensor(nv, 2);
mom2->getData() = v.getData(); mom2->getData() = v.getData();
insert(mom2); insert(mom2);
auto *kronv = new URSingleTensor(nv, 2); auto kronv = std::make_unique<URSingleTensor>(nv, 2);
kronv->getData() = v.getData(); kronv->getData() = v.getData();
for (int d = 4; d <= maxdim; d += 2) for (int d = 4; d <= maxdim; d += 2)
{ {
auto *newkronv = new URSingleTensor(nv, d); auto newkronv = std::make_unique<URSingleTensor>(nv, d);
KronProd::kronMult(ConstVector(v.getData()), KronProd::kronMult(ConstVector(v.getData()),
ConstVector(kronv->getData()), ConstVector(kronv->getData()),
newkronv->getData()); newkronv->getData());
delete kronv; kronv = std::move(newkronv);
kronv = newkronv;
auto *mom = new URSingleTensor(nv, d); auto *mom = new URSingleTensor(nv, d);
// apply $F_n$ to |kronv| // apply $F_n$ to |kronv|
/* Here we go through all equivalences, select only those having 2 /* Here we go through all equivalences, select only those having 2
@ -52,24 +53,21 @@ UNormalMoments::generateMoments(int maxdim, const TwoDMatrix &v)
how the |Equivalence::apply| method works. */ how the |Equivalence::apply| method works. */
mom->zeros(); mom->zeros();
const EquivalenceSet eset = ebundle.get(d); const EquivalenceSet eset = ebundle.get(d);
for (const auto & cit : eset) for (const auto &cit : eset)
{ if (selectEquiv(cit))
if (selectEquiv(cit)) {
{ Permutation per(cit);
Permutation per(cit); per.inverse();
per.inverse(); for (Tensor::index it = kronv->begin(); it != kronv->end(); ++it)
for (Tensor::index it = kronv->begin(); it != kronv->end(); ++it) {
{ IntSequence ind(kronv->dimen());
IntSequence ind(kronv->dimen()); per.apply(it.getCoor(), ind);
per.apply(it.getCoor(), ind); Tensor::index it2(*mom, ind);
Tensor::index it2(*mom, ind); mom->get(*it2, 0) += kronv->get(*it, 0);
mom->get(*it2, 0) += kronv->get(*it, 0); }
} }
}
}
insert(mom); insert(mom);
} }
delete kronv;
} }
/* We return |true| for an equivalence whose each class has 2 elements. */ /* We return |true| for an equivalence whose each class has 2 elements. */
@ -79,11 +77,9 @@ UNormalMoments::selectEquiv(const Equivalence &e)
{ {
if (2*e.numClasses() != e.getN()) if (2*e.numClasses() != e.getN())
return false; return false;
for (const auto & si : e) for (const auto &si : e)
{ if (si.length() != 2)
if (si.length() != 2) return false;
return false;
}
return true; return true;
} }
@ -92,7 +88,7 @@ UNormalMoments::selectEquiv(const Equivalence &e)
FNormalMoments::FNormalMoments(const UNormalMoments &moms) FNormalMoments::FNormalMoments(const UNormalMoments &moms)
: TensorContainer<FRSingleTensor>(1) : TensorContainer<FRSingleTensor>(1)
{ {
for (const auto & mom : moms) for (const auto &mom : moms)
{ {
auto *fm = new FRSingleTensor(*(mom.second)); auto *fm = new FRSingleTensor(*(mom.second));
insert(fm); insert(fm);