Bytecode: fix bug in sparse matrix multiplication routines

Since those routines compute Aᵀ·B, the number of rows of the output is equal to
the number of columns of A.
trustregion
Sébastien Villemot 2022-02-28 12:18:58 +01:00
parent 0147faad5f
commit a7cc466285
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
1 changed files with 3 additions and 6 deletions

View File

@ -1775,13 +1775,12 @@ mxArray *
dynSparseMatrix::mult_SAT_B(const mxArray *A_m, const mxArray *B_m)
{
size_t n_A = mxGetN(A_m);
size_t m_A = mxGetM(A_m);
mwIndex *A_i = mxGetIr(A_m);
mwIndex *A_j = mxGetJc(A_m);
double *A_d = mxGetPr(A_m);
size_t n_B = mxGetN(B_m);
double *B_d = mxGetPr(B_m);
mxArray *C_m = mxCreateDoubleMatrix(m_A, n_B, mxREAL);
mxArray *C_m = mxCreateDoubleMatrix(n_A, n_B, mxREAL);
double *C_d = mxGetPr(C_m);
for (int j = 0; j < static_cast<int>(n_B); j++)
for (unsigned int i = 0; i < n_A; i++)
@ -1802,14 +1801,13 @@ mxArray *
dynSparseMatrix::Sparse_mult_SAT_B(const mxArray *A_m, const mxArray *B_m)
{
size_t n_A = mxGetN(A_m);
size_t m_A = mxGetM(A_m);
mwIndex *A_i = mxGetIr(A_m);
mwIndex *A_j = mxGetJc(A_m);
double *A_d = mxGetPr(A_m);
size_t n_B = mxGetN(B_m);
size_t m_B = mxGetM(B_m);
double *B_d = mxGetPr(B_m);
mxArray *C_m = mxCreateSparse(m_A, n_B, m_A*n_B, mxREAL);
mxArray *C_m = mxCreateSparse(n_A, n_B, n_A*n_B, mxREAL);
mwIndex *C_i = mxGetIr(C_m);
mwIndex *C_j = mxGetJc(C_m);
double *C_d = mxGetPr(C_m);
@ -1847,7 +1845,6 @@ mxArray *
dynSparseMatrix::Sparse_mult_SAT_SB(const mxArray *A_m, const mxArray *B_m)
{
size_t n_A = mxGetN(A_m);
size_t m_A = mxGetM(A_m);
mwIndex *A_i = mxGetIr(A_m);
mwIndex *A_j = mxGetJc(A_m);
double *A_d = mxGetPr(A_m);
@ -1855,7 +1852,7 @@ dynSparseMatrix::Sparse_mult_SAT_SB(const mxArray *A_m, const mxArray *B_m)
mwIndex *B_i = mxGetIr(B_m);
mwIndex *B_j = mxGetJc(B_m);
double *B_d = mxGetPr(B_m);
mxArray *C_m = mxCreateSparse(m_A, n_B, m_A*n_B, mxREAL);
mxArray *C_m = mxCreateSparse(n_A, n_B, n_A*n_B, mxREAL);
mwIndex *C_i = mxGetIr(C_m);
mwIndex *C_j = mxGetJc(C_m);
double *C_d = mxGetPr(C_m);