Initial guess for Lagrange multipliers solved by SVD decomposition to improve conditioning.
parent
7ca3683281
commit
64346de401
|
@ -327,11 +327,11 @@ MultInitSS::MultInitSS(const PlannerBuilder& pb, const Vector& pvals, Vector& yy
|
|||
ogp::FormulaCustomEvaluator fe(builder.static_tree, terms);
|
||||
fe.eval(dssav, *this);
|
||||
|
||||
// solve overdetermined system b+F*lambda=0 => lambda=-(F^T*F)^{-1}*F^T*b
|
||||
GeneralMatrix FtF(F, "transpose", F);
|
||||
Vector lambda(builder.diff_f_static.dim2());
|
||||
F.multVecTrans(0.0, lambda, -1.0, b);
|
||||
ConstGeneralMatrix(FtF).multInvLeft(lambda);
|
||||
// solve overdetermined system b+F*lambda=0 using SVD decomposition
|
||||
SVDDecomp decomp(F);
|
||||
Vector lambda(builder.diff_f_static.dim2());
|
||||
decomp.solve(b, lambda);
|
||||
lambda.mult(-1);
|
||||
|
||||
// take values of lambda and put it to yy
|
||||
for (int fi = 0; fi < builder.diff_f_static.dim2(); fi++) {
|
||||
|
|
|
@ -481,3 +481,76 @@ void ConstGeneralMatrix::print() const
|
|||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
void SVDDecomp::construct(const GeneralMatrix& A)
|
||||
{
|
||||
// quick exit if empty matrix
|
||||
if (minmn == 0) {
|
||||
U.unit();
|
||||
VT.unit();
|
||||
conv = true;
|
||||
return;
|
||||
}
|
||||
|
||||
// make copy of the matrix
|
||||
GeneralMatrix AA(A);
|
||||
|
||||
lapack_int m = AA.numRows();
|
||||
lapack_int n = AA.numCols();
|
||||
double* a = AA.base();
|
||||
lapack_int lda = AA.getLD();
|
||||
double* s = sigma.base();
|
||||
double* u = U.base();
|
||||
lapack_int ldu = U.getLD();
|
||||
double* vt = VT.base();
|
||||
lapack_int ldvt = VT.getLD();
|
||||
double tmpwork;
|
||||
lapack_int lwork = -1;
|
||||
lapack_int info;
|
||||
|
||||
lapack_int* iwork = new lapack_int[8*minmn];
|
||||
// query for optimal lwork
|
||||
dgesdd("A", &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, &tmpwork,
|
||||
&lwork, iwork, &info);
|
||||
lwork = (lapack_int)tmpwork;
|
||||
Vector work(lwork);
|
||||
// do the decomposition
|
||||
dgesdd("A", &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work.base(),
|
||||
&lwork, iwork, &info);
|
||||
delete [] iwork;
|
||||
if (info < 0)
|
||||
throw SYLV_MES_EXCEPTION("Internal error in SVDDecomp constructor");
|
||||
if (info == 0)
|
||||
conv = true;
|
||||
}
|
||||
|
||||
void SVDDecomp::solve(const GeneralMatrix& B, GeneralMatrix& X) const
|
||||
{
|
||||
if (B.numRows() != U.numRows())
|
||||
throw SYLV_MES_EXCEPTION("Incompatible number of rows ");
|
||||
|
||||
// reciprocal condition number for determination of zeros in the
|
||||
// end of sigma
|
||||
double rcond = 1e-13;
|
||||
|
||||
// solve U: B = U^T*B
|
||||
GeneralMatrix UTB(U, "trans", B);
|
||||
// determine nz=number of zeros in the end of sigma
|
||||
int nz = 0;
|
||||
while (nz < minmn && sigma[minmn-1-nz] < rcond*sigma[0])
|
||||
nz++;
|
||||
// take relevant B for sigma inversion
|
||||
int m = U.numRows();
|
||||
int n = VT.numCols();
|
||||
GeneralMatrix Bprime(UTB, m-minmn, 0, minmn-nz, B.numCols());
|
||||
// solve sigma
|
||||
for (int i = 0; i < minmn-nz; i++)
|
||||
Vector(i, Bprime).mult(1.0/sigma[i]);
|
||||
// solve VT
|
||||
X.zeros();
|
||||
//- copy Bprime to right place of X
|
||||
for (int i = 0; i < minmn-nz; i++)
|
||||
Vector(n-minmn+i, X) = ConstVector(i, Bprime);
|
||||
//- multiply with VT
|
||||
X.multLeftTrans(VT);
|
||||
}
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
|
||||
#include "Vector.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
class GeneralMatrix;
|
||||
|
||||
class ConstGeneralMatrix {
|
||||
|
@ -272,6 +274,40 @@ private:
|
|||
static int md_length;
|
||||
};
|
||||
|
||||
class SVDDecomp {
|
||||
protected:
|
||||
/** Minimum of number of rows and columns of the decomposed
|
||||
* matrix. */
|
||||
const int minmn;
|
||||
/** Singular values. */
|
||||
Vector sigma;
|
||||
/** Orthogonal matrix U. */
|
||||
GeneralMatrix U;
|
||||
/** Orthogonal matrix V^T. */
|
||||
GeneralMatrix VT;
|
||||
/** Convered flag. */
|
||||
bool conv;
|
||||
public:
|
||||
SVDDecomp(const GeneralMatrix& A)
|
||||
: minmn(std::min<int>(A.numRows(), A.numCols())),
|
||||
sigma(minmn),
|
||||
U(A.numRows(), A.numRows()),
|
||||
VT(A.numCols(), A.numCols()),
|
||||
conv(false)
|
||||
{construct(A);}
|
||||
const GeneralMatrix& getU() const
|
||||
{return U;}
|
||||
const GeneralMatrix& getVT() const
|
||||
{return VT;}
|
||||
void solve(const GeneralMatrix& B, GeneralMatrix& X) const;
|
||||
void solve(const Vector& b, Vector& x) const
|
||||
{
|
||||
GeneralMatrix xmat(x.base(), x.length(), 1);
|
||||
solve(GeneralMatrix(b.base(), b.length(), 1), xmat);
|
||||
}
|
||||
private:
|
||||
void construct(const GeneralMatrix& A);
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue