From 3bd3c78e0eac4a81470f7753a1d671cf40ceda1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= Date: Fri, 4 Jun 2021 12:56:01 +0200 Subject: [PATCH] A_times_B_kronecker_C MEX: rewrite in Fortran --- mex/build/kronecker.am | 7 +- .../kronecker/A_times_B_kronecker_C.cc | 114 ----------------- .../kronecker/A_times_B_kronecker_C.f08 | 115 ++++++++++++++++++ 3 files changed, 121 insertions(+), 115 deletions(-) delete mode 100644 mex/sources/kronecker/A_times_B_kronecker_C.cc create mode 100644 mex/sources/kronecker/A_times_B_kronecker_C.f08 diff --git a/mex/build/kronecker.am b/mex/build/kronecker.am index 734e1ec53..ac07bb602 100644 --- a/mex/build/kronecker.am +++ b/mex/build/kronecker.am @@ -1,7 +1,7 @@ mex_PROGRAMS = sparse_hessian_times_B_kronecker_C A_times_B_kronecker_C nodist_sparse_hessian_times_B_kronecker_C_SOURCES = sparse_hessian_times_B_kronecker_C.cc -nodist_A_times_B_kronecker_C_SOURCES = A_times_B_kronecker_C.cc +nodist_A_times_B_kronecker_C_SOURCES = A_times_B_kronecker_C.f08 matlab_mex.F08 blas_lapack.F08 sparse_hessian_times_B_kronecker_C_CXXFLAGS = $(AM_CXXFLAGS) -fopenmp sparse_hessian_times_B_kronecker_C_LDFLAGS = $(AM_LDFLAGS) $(OPENMP_LDFLAGS) @@ -11,3 +11,8 @@ CLEANFILES = $(nodist_sparse_hessian_times_B_kronecker_C_SOURCES) $(nodist_A_tim %.cc: $(top_srcdir)/../../sources/kronecker/%.cc $(LN_S) -f $< $@ + +A_times_B_kronecker_C.o : matlab_mex.mod lapack.mod + +%.f08: $(top_srcdir)/../../sources/kronecker/%.f08 + $(LN_S) -f $< $@ diff --git a/mex/sources/kronecker/A_times_B_kronecker_C.cc b/mex/sources/kronecker/A_times_B_kronecker_C.cc deleted file mode 100644 index 95d0f7268..000000000 --- a/mex/sources/kronecker/A_times_B_kronecker_C.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright © 2007-2020 Dynare Team - * - * This file is part of Dynare. 
- * - * Dynare is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dynare is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with Dynare. If not, see . - */ - -/* - * This mex file computes A·(B⊗C) or A·(B⊗B) without explicitly building B⊗C or B⊗B, so that - * one can consider large matrices B and/or C. - */ - -#include -#include - -void -full_A_times_kronecker_B_C(const double *A, const double *B, const double *C, double *D, - blas_int mA, blas_int nA, blas_int mB, blas_int nB, blas_int mC, blas_int nC) -{ - const blas_int shiftA = mA*mC; - const blas_int shiftD = mA*nC; - blas_int kd = 0, ka = 0; - double one = 1.0; - for (blas_int col = 0; col < nB; col++) - { - ka = 0; - for (blas_int row = 0; row < mB; row++) - { - dgemm("N", "N", &mA, &nC, &mC, &B[mB*col+row], &A[ka], &mA, C, &mC, &one, &D[kd], &mA); - ka += shiftA; - } - kd += shiftD; - } -} - -void -full_A_times_kronecker_B_B(const double *A, const double *B, double *D, blas_int mA, blas_int nA, blas_int mB, blas_int nB) -{ - const blas_int shiftA = mA*mB; - const blas_int shiftD = mA*nB; - blas_int kd = 0, ka = 0; - double one = 1.0; - for (blas_int col = 0; col < nB; col++) - { - ka = 0; - for (blas_int row = 0; row < mB; row++) - { - dgemm("N", "N", &mA, &nB, &mB, &B[mB*col+row], &A[ka], &mA, B, &mB, &one, &D[kd], &mA); - ka += shiftA; - } - kd += shiftD; - } -} - -void -mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) -{ - // Check input and output: - if (nrhs > 3 || nrhs < 2 || nlhs != 1) - { - 
mexErrMsgTxt("A_times_B_kronecker_C takes 2 or 3 input arguments and provides 1 output argument."); - return; // Needed to shut up some GCC warnings - } - - // Get & Check dimensions (columns and rows): - size_t mA = mxGetM(prhs[0]); - size_t nA = mxGetN(prhs[0]); - size_t mB = mxGetM(prhs[1]); - size_t nB = mxGetN(prhs[1]); - size_t mC, nC; - if (nrhs == 3) // A·(B⊗C) is to be computed. - { - mC = mxGetM(prhs[2]); - nC = mxGetN(prhs[2]); - if (mB*mC != nA) - mexErrMsgTxt("Input dimension error!"); - } - else // A·(B⊗B) is to be computed. - { - if (mB*mB != nA) - mexErrMsgTxt("Input dimension error!"); - } - // Get input matrices: - const double *A = mxGetPr(prhs[0]); - const double *B = mxGetPr(prhs[1]); - const double *C{nullptr}; - if (nrhs == 3) - C = mxGetPr(prhs[2]); - - // Initialization of the ouput: - if (nrhs == 3) - plhs[0] = mxCreateDoubleMatrix(mA, nB*nC, mxREAL); - else - plhs[0] = mxCreateDoubleMatrix(mA, nB*nB, mxREAL); - double *D = mxGetPr(plhs[0]); - - // Computational part: - if (nrhs == 2) - full_A_times_kronecker_B_B(A, B, D, mA, nA, mB, nB); - else - full_A_times_kronecker_B_C(A, B, C, D, mA, nA, mB, nB, mC, nC); -} diff --git a/mex/sources/kronecker/A_times_B_kronecker_C.f08 b/mex/sources/kronecker/A_times_B_kronecker_C.f08 new file mode 100644 index 000000000..5c91fe4dd --- /dev/null +++ b/mex/sources/kronecker/A_times_B_kronecker_C.f08 @@ -0,0 +1,115 @@ +! This MEX file computes A·(B⊗C) or A·(B⊗B) without explicitly building B⊗C or +! B⊗B, so that one can consider large matrices B and/or C. + +! Copyright © 2007-2021 Dynare Team +! +! This file is part of Dynare. +! +! Dynare is free software: you can redistribute it and/or modify +! it under the terms of the GNU General Public License as published by +! the Free Software Foundation, either version 3 of the License, or +! (at your option) any later version. +! +! Dynare is distributed in the hope that it will be useful, +! but WITHOUT ANY WARRANTY; without even the implied warranty of +! 
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+! GNU General Public License for more details.
+!
+! You should have received a copy of the GNU General Public License
+! along with Dynare. If not, see https://www.gnu.org/licenses/.
+
+subroutine mexFunction(nlhs, plhs, nrhs, prhs) bind(c, name='mexFunction')
+  use iso_fortran_env, only: real64
+  use iso_c_binding, only: c_int
+  use matlab_mex
+  use blas
+  implicit none
+  ! MEX gateway: D = A·(B⊗C) for 3 inputs (A, B, C), D = A·(B⊗B) for 2 inputs (A, B); D is returned in plhs(1)
+  type(c_ptr), dimension(*), intent(in), target :: prhs ! input mxArray pointers: A, B and optionally C
+  type(c_ptr), dimension(*), intent(out) :: plhs ! output mxArray pointers: plhs(1) receives D
+  integer(c_int), intent(in), value :: nlhs, nrhs ! numbers of output and input arguments
+
+  integer(c_size_t) :: mA, nA, mB, nB, mC, nC ! row (m) and column (n) counts of A, B and C
+  real(real64), dimension(:, :), pointer, contiguous :: A, B, C, D ! rank-2 views over the mxArray data
+
+  if (nrhs > 3 .or. nrhs < 2 .or. nlhs /= 1) then
+     call mexErrMsgTxt("A_times_B_kronecker_C takes 2 or 3 input arguments and provides 1 output argument")
+  end if
+
+  if (.not. mxIsDouble(prhs(1)) .or. mxIsComplex(prhs(1)) &
+       .or. .not. mxIsDouble(prhs(2)) .or. mxIsComplex(prhs(2))) then ! NOTE(review): no sparsity check; confirm dense inputs
+     call mexErrMsgTxt("A_times_B_kronecker_C: first two arguments should be real matrices")
+  end if
+  mA = mxGetM(prhs(1))
+  nA = mxGetN(prhs(1))
+  mB = mxGetM(prhs(2))
+  nB = mxGetN(prhs(2))
+  A(1:mA,1:nA) => mxGetPr(prhs(1))
+  B(1:mB,1:nB) => mxGetPr(prhs(2))
+
+  if (nrhs == 3) then
+     ! A·(B⊗C) is to be computed: A must have mB·mC columns, D is mA × (nB·nC).
+     if (.not. mxIsDouble(prhs(3)) .or. mxIsComplex(prhs(3))) then
+        call mexErrMsgTxt("A_times_B_kronecker_C: third argument should be a real matrix")
+     end if
+     mC = mxGetM(prhs(3))
+     nC = mxGetN(prhs(3))
+     if (mB*mC /= nA) then
+        call mexErrMsgTxt("Input dimension error!")
+     end if
+
+     C(1:mC,1:nC) => mxGetPr(prhs(3))
+
+     plhs(1) = mxCreateDoubleMatrix(mA, nB*nC, mxREAL)
+     D(1:mA,1:nB*nC) => mxGetPr(plhs(1))
+
+     call full_A_times_kronecker_B_C
+  else
+     ! A·(B⊗B) is to be computed: A must have mB·mB columns, D is mA × (nB·nB).
+     if (mB*mB /= nA) then
+        call mexErrMsgTxt("Input dimension error!")
+     end if
+
+     plhs(1) = mxCreateDoubleMatrix(mA, nB*nB, mxREAL)
+     D(1:mA,1:nB*nB) => mxGetPr(plhs(1))
+
+     call full_A_times_kronecker_B_B
+  end if
+
+contains
+  ! Accumulates D += A·(B⊗C) one block at a time (beta=1); presumably D starts zero-filled by mxCreateDoubleMatrix
+  subroutine full_A_times_kronecker_B_C
+    integer(c_size_t) :: i, j, ka, kd ! ka/kd: first column of the current block of A/D
+
+    kd = 1
+    do j = 1,nB
+       ka = 1
+       do i = 1,mB
+          ! D(:,kd:kd+nC-1) += B(i,j)·A(:,ka:ka+mC-1)·C  (blocks are exactly nC and mC columns wide)
+          call dgemm("N", "N", int(mA, blint), int(nC, blint), int(mC, blint), B(i,j), &
+               A(:,ka:ka+mC-1), int(mA, blint), C, int(mC, blint), 1._real64, &
+               D(:,kd:kd+nC-1), int(mA, blint))
+          ka = ka + mC
+       end do
+       kd = kd + nC
+    end do
+  end subroutine full_A_times_kronecker_B_C
+
+  ! Accumulates D += A·(B⊗B) one block at a time (beta=1); presumably D starts zero-filled by mxCreateDoubleMatrix
+  subroutine full_A_times_kronecker_B_B
+    integer(c_size_t) :: i, j, ka, kd ! ka/kd: first column of the current block of A/D
+
+    kd = 1
+    do j = 1,nB
+       ka = 1
+       do i = 1,mB
+          ! D(:,kd:kd+nB-1) += B(i,j)·A(:,ka:ka+mB-1)·B  (blocks are exactly nB and mB columns wide)
+          call dgemm("N", "N", int(mA, blint), int(nB, blint), int(mB, blint), B(i,j), &
+               A(:,ka:ka+mB-1), int(mA, blint), B, int(mB, blint), 1._real64, &
+               D(:,kd:kd+nB-1), int(mA, blint))
+          ka = ka + mB
+       end do
+       kd = kd + nB
+    end do
+  end subroutine full_A_times_kronecker_B_B
+end subroutine mexFunction