Initial guess for Lagrange multipliers solved by SVD decomposition to improve conditioning.

time-shift
Stéphane Adjemian (Sedna) 2015-03-25 14:45:09 +01:00
parent 7ca3683281
commit 64346de401
3 changed files with 114 additions and 5 deletions

View File

@ -327,11 +327,11 @@ MultInitSS::MultInitSS(const PlannerBuilder& pb, const Vector& pvals, Vector& yy
ogp::FormulaCustomEvaluator fe(builder.static_tree, terms);
fe.eval(dssav, *this);
// solve overdetermined system b+F*lambda=0 => lambda=-(F^T*F)^{-1}*F^T*b
GeneralMatrix FtF(F, "transpose", F);
Vector lambda(builder.diff_f_static.dim2());
F.multVecTrans(0.0, lambda, -1.0, b);
ConstGeneralMatrix(FtF).multInvLeft(lambda);
// solve overdetermined system b+F*lambda=0 using SVD decomposition
SVDDecomp decomp(F);
Vector lambda(builder.diff_f_static.dim2());
decomp.solve(b, lambda);
lambda.mult(-1);
// take values of lambda and put it to yy
for (int fi = 0; fi < builder.diff_f_static.dim2(); fi++) {

View File

@ -481,3 +481,76 @@ void ConstGeneralMatrix::print() const
printf("\n");
}
}
void SVDDecomp::construct(const GeneralMatrix& A)
{
// quick exit if empty matrix
if (minmn == 0) {
U.unit();
VT.unit();
conv = true;
return;
}
// make copy of the matrix
GeneralMatrix AA(A);
lapack_int m = AA.numRows();
lapack_int n = AA.numCols();
double* a = AA.base();
lapack_int lda = AA.getLD();
double* s = sigma.base();
double* u = U.base();
lapack_int ldu = U.getLD();
double* vt = VT.base();
lapack_int ldvt = VT.getLD();
double tmpwork;
lapack_int lwork = -1;
lapack_int info;
lapack_int* iwork = new lapack_int[8*minmn];
// query for optimal lwork
dgesdd("A", &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, &tmpwork,
&lwork, iwork, &info);
lwork = (lapack_int)tmpwork;
Vector work(lwork);
// do the decomposition
dgesdd("A", &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work.base(),
&lwork, iwork, &info);
delete [] iwork;
if (info < 0)
throw SYLV_MES_EXCEPTION("Internal error in SVDDecomp constructor");
if (info == 0)
conv = true;
}
void SVDDecomp::solve(const GeneralMatrix& B, GeneralMatrix& X) const
{
if (B.numRows() != U.numRows())
throw SYLV_MES_EXCEPTION("Incompatible number of rows ");
// reciprocal condition number for determination of zeros in the
// end of sigma
double rcond = 1e-13;
// solve U: B = U^T*B
GeneralMatrix UTB(U, "trans", B);
// determine nz=number of zeros in the end of sigma
int nz = 0;
while (nz < minmn && sigma[minmn-1-nz] < rcond*sigma[0])
nz++;
// take relevant B for sigma inversion
int m = U.numRows();
int n = VT.numCols();
GeneralMatrix Bprime(UTB, m-minmn, 0, minmn-nz, B.numCols());
// solve sigma
for (int i = 0; i < minmn-nz; i++)
Vector(i, Bprime).mult(1.0/sigma[i]);
// solve VT
X.zeros();
//- copy Bprime to right place of X
for (int i = 0; i < minmn-nz; i++)
Vector(n-minmn+i, X) = ConstVector(i, Bprime);
//- multiply with VT
X.multLeftTrans(VT);
}

View File

@ -7,6 +7,8 @@
#include "Vector.h"
#include <algorithm>
class GeneralMatrix;
class ConstGeneralMatrix {
@ -272,6 +274,40 @@ private:
static int md_length;
};
class SVDDecomp {
protected:
/** Minimum of number of rows and columns of the decomposed
* matrix. */
const int minmn;
/** Singular values. */
Vector sigma;
/** Orthogonal matrix U. */
GeneralMatrix U;
/** Orthogonal matrix V^T. */
GeneralMatrix VT;
/** Convered flag. */
bool conv;
public:
SVDDecomp(const GeneralMatrix& A)
: minmn(std::min<int>(A.numRows(), A.numCols())),
sigma(minmn),
U(A.numRows(), A.numRows()),
VT(A.numCols(), A.numCols()),
conv(false)
{construct(A);}
const GeneralMatrix& getU() const
{return U;}
const GeneralMatrix& getVT() const
{return VT;}
void solve(const GeneralMatrix& B, GeneralMatrix& X) const;
void solve(const Vector& b, Vector& x) const
{
GeneralMatrix xmat(x.base(), x.length(), 1);
solve(GeneralMatrix(b.base(), b.length(), 1), xmat);
}
private:
void construct(const GeneralMatrix& A);
};