Dynare++ tensor library: modernize normal moments computation

time-shift
Sébastien Villemot 2019-02-19 12:53:02 +01:00
parent d08ca8ca7f
commit 0f37649755
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
1 changed files with 22 additions and 26 deletions

View File

@ -1,5 +1,7 @@
// Copyright 2004, Ondra Kamenik
#include <memory>
#include "normal_moments.hh"
#include "permutation.hh"
#include "kron_prod.hh"
@ -30,16 +32,15 @@ UNormalMoments::generateMoments(int maxdim, const TwoDMatrix &v)
auto *mom2 = new URSingleTensor(nv, 2);
mom2->getData() = v.getData();
insert(mom2);
auto *kronv = new URSingleTensor(nv, 2);
auto kronv = std::make_unique<URSingleTensor>(nv, 2);
kronv->getData() = v.getData();
for (int d = 4; d <= maxdim; d += 2)
{
auto *newkronv = new URSingleTensor(nv, d);
auto newkronv = std::make_unique<URSingleTensor>(nv, d);
KronProd::kronMult(ConstVector(v.getData()),
ConstVector(kronv->getData()),
newkronv->getData());
delete kronv;
kronv = newkronv;
kronv = std::move(newkronv);
auto *mom = new URSingleTensor(nv, d);
// apply $F_n$ to |kronv|
/* Here we go through all equivalences, select only those having 2
@ -52,24 +53,21 @@ UNormalMoments::generateMoments(int maxdim, const TwoDMatrix &v)
how the |Equivalence::apply| method works. */
mom->zeros();
const EquivalenceSet eset = ebundle.get(d);
for (const auto & cit : eset)
{
if (selectEquiv(cit))
{
Permutation per(cit);
per.inverse();
for (Tensor::index it = kronv->begin(); it != kronv->end(); ++it)
{
IntSequence ind(kronv->dimen());
per.apply(it.getCoor(), ind);
Tensor::index it2(*mom, ind);
mom->get(*it2, 0) += kronv->get(*it, 0);
}
}
}
for (const auto &cit : eset)
if (selectEquiv(cit))
{
Permutation per(cit);
per.inverse();
for (Tensor::index it = kronv->begin(); it != kronv->end(); ++it)
{
IntSequence ind(kronv->dimen());
per.apply(it.getCoor(), ind);
Tensor::index it2(*mom, ind);
mom->get(*it2, 0) += kronv->get(*it, 0);
}
}
insert(mom);
}
delete kronv;
}
/* We return |true| for an equivalence whose each class has 2 elements. */
@ -79,11 +77,9 @@ UNormalMoments::selectEquiv(const Equivalence &e)
{
if (2*e.numClasses() != e.getN())
return false;
for (const auto & si : e)
{
if (si.length() != 2)
return false;
}
for (const auto &si : e)
if (si.length() != 2)
return false;
return true;
}
@ -92,7 +88,7 @@ UNormalMoments::selectEquiv(const Equivalence &e)
FNormalMoments::FNormalMoments(const UNormalMoments &moms)
: TensorContainer<FRSingleTensor>(1)
{
for (const auto & mom : moms)
for (const auto &mom : moms)
{
auto *fm = new FRSingleTensor(*(mom.second));
insert(fm);