2012-02-28 08:43:28 +01:00
|
|
|
|
/*
|
2022-03-18 18:18:24 +01:00
|
|
|
|
* Copyright © 2010-2022 Dynare Team
|
2012-02-28 08:43:28 +01:00
|
|
|
|
*
|
|
|
|
|
* This file is part of Dynare.
|
|
|
|
|
*
|
|
|
|
|
* Dynare is free software: you can redistribute it and/or modify
|
|
|
|
|
* it under the terms of the GNU General Public License as published by
|
|
|
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
|
|
|
* (at your option) any later version.
|
|
|
|
|
*
|
|
|
|
|
* Dynare is distributed in the hope that it will be useful,
|
|
|
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
|
* GNU General Public License for more details.
|
|
|
|
|
*
|
|
|
|
|
* You should have received a copy of the GNU General Public License
|
2021-06-09 17:33:48 +02:00
|
|
|
|
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
|
2012-02-28 08:43:28 +01:00
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
/*
 * This mex file computes particles at time t+1 given particles and innovations at time t,
 * using a second order approximation of the nonlinear state space model.
 */
|
|
|
|
|
|
2019-04-30 15:22:19 +02:00
|
|
|
|
#include <algorithm>
|
2022-03-18 18:18:24 +01:00
|
|
|
|
#include <string>
|
2023-11-29 19:00:21 +01:00
|
|
|
|
#include <tuple>
|
|
|
|
|
#include <vector>
|
2019-04-30 15:22:19 +02:00
|
|
|
|
|
2012-02-28 08:43:28 +01:00
|
|
|
|
#include <dynblas.h>
|
2023-11-29 19:00:21 +01:00
|
|
|
|
#include <dynmex.h>
|
2012-02-28 08:43:28 +01:00
|
|
|
|
|
2019-06-25 15:42:32 +02:00
|
|
|
|
#include <omp.h>
|
2012-02-28 08:43:28 +01:00
|
|
|
|
|
2021-10-20 15:50:34 +02:00
|
|
|
|
/*
  Uncomment the following line to use BLAS instead of loops when computing
  ghx·ŷ and ghu·ε.
  N.B.: Under MATLAB, this only works in single-threaded mode, otherwise one
  gets a crash (because of the incompatibility between the Intel and GNU OpenMP
  runtimes).
*/
|
2023-11-29 19:00:21 +01:00
|
|
|
|
// #define USE_BLAS_AT_FIRST_ORDER
|
2012-02-28 08:43:28 +01:00
|
|
|
|
|
2019-06-27 17:34:10 +02:00
|
|
|
|
/*
  Enumerates the pairs (i, j) with 0 ≤ i ≤ j < n, row by row, and returns three
  vectors of length n·(n+1)/2 whose k-th entries describe the k-th pair:
   – first vector:  i
   – second vector: j
   – third vector:  (i·n + j)·r, i.e. the offset of column i·n + j in a
     column-major matrix that has r rows.
*/
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>>
set_vector_of_indices(int n, int r)
{
  const int m = n * (n + 1) / 2;
  // The vectors are created directly inside the tuple that is returned, so that
  // they are never copied on the way out of the function.
  std::tuple<std::vector<int>, std::vector<int>, std::vector<int>> indices {
      std::vector<int>(m, 0), std::vector<int>(m, 0), std::vector<int>(m, 0)};
  auto& [first, second, offset] = indices;
  for (int i = 0, index = 0; i < n; i++)
    for (int j = i; j < n; j++, index++)
      {
        first[index] = i;
        second[index] = j;
        offset[index] = (i * n + j) * r;
      }
  return indices;
}
|
|
|
|
|
|
2017-05-16 16:30:27 +02:00
|
|
|
|
void
|
2023-11-29 19:00:21 +01:00
|
|
|
|
ss2Iteration_pruning(double* y2, double* y1, const double* yhat2, const double* yhat1,
|
|
|
|
|
const double* epsilon, const double* ghx, const double* ghu,
|
|
|
|
|
const double* constant, const double* ghxx, const double* ghuu,
|
|
|
|
|
const double* ghxu, const double* ss, blas_int m, blas_int n, blas_int q,
|
|
|
|
|
blas_int s, int number_of_threads)
|
2012-02-28 08:43:28 +01:00
|
|
|
|
{
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#ifdef USE_BLAS_AT_FIRST_ORDER
|
2017-05-16 16:30:27 +02:00
|
|
|
|
const double one = 1.0;
|
|
|
|
|
const blas_int ONE = 1;
|
|
|
|
|
#endif
|
2019-09-11 16:06:35 +02:00
|
|
|
|
auto [ii1, ii2, ii3] = set_vector_of_indices(n, m); // vector indices for ghxx
|
|
|
|
|
auto [jj1, jj2, jj3] = set_vector_of_indices(q, m); // vector indices for ghuu
|
2019-06-25 15:42:32 +02:00
|
|
|
|
#pragma omp parallel for num_threads(number_of_threads)
|
2017-05-16 16:30:27 +02:00
|
|
|
|
for (int particle = 0; particle < s; particle++)
|
|
|
|
|
{
|
2023-11-29 19:00:21 +01:00
|
|
|
|
int particle_ = particle * m;
|
|
|
|
|
int particle__ = particle * n;
|
|
|
|
|
int particle___ = particle * q;
|
2019-04-30 15:22:19 +02:00
|
|
|
|
std::copy_n(constant, m, &y2[particle_]);
|
|
|
|
|
std::copy_n(ss, m, &y1[particle_]);
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#ifdef USE_BLAS_AT_FIRST_ORDER
|
2019-04-30 15:22:19 +02:00
|
|
|
|
dgemv("N", &m, &n, &one, ghx, &m, &yhat2[particle__], &ONE, &one, &y2[particle_], &ONE);
|
|
|
|
|
dgemv("N", &m, &q, &one, ghu, &m, &epsilon[particle___], &ONE, &one, &y2[particle_], &ONE);
|
2017-05-16 16:30:27 +02:00
|
|
|
|
#endif
|
|
|
|
|
for (int variable = 0; variable < m; variable++)
|
|
|
|
|
{
|
|
|
|
|
int variable_ = variable + particle_;
|
2021-10-22 17:57:29 +02:00
|
|
|
|
// +ghx·ŷ₂+ghu·ε
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#ifndef USE_BLAS_AT_FIRST_ORDER
|
2021-10-14 16:18:17 +02:00
|
|
|
|
for (int column = 0, column_ = 0; column < n; column++, column_ += m)
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y2[variable_] += ghx[variable + column_] * yhat2[column + particle__];
|
2021-10-14 16:18:17 +02:00
|
|
|
|
for (int column = 0, column_ = 0; column < q; column++, column_ += m)
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y2[variable_] += ghu[variable + column_] * epsilon[column + particle___];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
#endif
|
2021-10-22 17:57:29 +02:00
|
|
|
|
// +½ghxx·ŷ₁⊗ŷ₁
|
2023-11-29 19:00:21 +01:00
|
|
|
|
for (int i = 0; i < n * (n + 1) / 2; i++)
|
2017-05-16 16:30:27 +02:00
|
|
|
|
{
|
2023-11-29 19:00:21 +01:00
|
|
|
|
int i1 = particle__ + ii1[i];
|
|
|
|
|
int i2 = particle__ + ii2[i];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
if (i1 == i2)
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y2[variable_] += .5 * ghxx[variable + ii3[i]] * yhat1[i1] * yhat1[i1];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
else
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y2[variable_] += ghxx[variable + ii3[i]] * yhat1[i1] * yhat1[i2];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
}
|
2021-10-22 17:57:29 +02:00
|
|
|
|
// +½ghuu·ε⊗ε
|
2023-11-29 19:00:21 +01:00
|
|
|
|
for (int j = 0; j < q * (q + 1) / 2; j++)
|
2017-05-16 16:30:27 +02:00
|
|
|
|
{
|
2023-11-29 19:00:21 +01:00
|
|
|
|
int j1 = particle___ + jj1[j];
|
|
|
|
|
int j2 = particle___ + jj2[j];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
if (j1 == j2)
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y2[variable_] += .5 * ghuu[variable + jj3[j]] * epsilon[j1] * epsilon[j1];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
else
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y2[variable_] += ghuu[variable + jj3[j]] * epsilon[j1] * epsilon[j2];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
}
|
2021-10-22 17:57:29 +02:00
|
|
|
|
// +ghxu·ŷ₁⊗ε
|
2023-11-29 19:00:21 +01:00
|
|
|
|
for (int v = particle__, i = 0; v < particle__ + n; v++)
|
|
|
|
|
for (int s = particle___; s < particle___ + q; s++, i += m)
|
|
|
|
|
y2[variable_] += ghxu[variable + i] * epsilon[s] * yhat1[v];
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#ifndef USE_BLAS_AT_FIRST_ORDER
|
2017-05-16 16:30:27 +02:00
|
|
|
|
for (int column = 0, column_ = 0; column < q; column++, column_ += m)
|
|
|
|
|
{
|
2023-11-29 19:00:21 +01:00
|
|
|
|
int i1 = variable + column_;
|
|
|
|
|
int i2 = column + particle__;
|
|
|
|
|
int i3 = column + particle___;
|
|
|
|
|
y1[variable_] += ghx[i1] * yhat1[i2];
|
|
|
|
|
y1[variable_] += ghu[i1] * epsilon[i3];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
}
|
2023-11-29 19:00:21 +01:00
|
|
|
|
for (int column = q, column_ = q * m; column < n; column++, column_ += m)
|
|
|
|
|
y1[variable_] += ghx[variable + column_] * yhat1[column + particle__];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
#endif
|
|
|
|
|
}
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#ifdef USE_BLAS_AT_FIRST_ORDER
|
2019-04-30 15:22:19 +02:00
|
|
|
|
dgemv("N", &m, &n, &one, &ghx[0], &m, &yhat1[particle__], &ONE, &one, &y1[particle_], &ONE);
|
2023-11-29 19:00:21 +01:00
|
|
|
|
dgemv("N", &m, &q, &one, &ghu[0], &m, &epsilon[particle___], &ONE, &one, &y1[particle_],
|
|
|
|
|
&ONE);
|
2017-05-16 16:30:27 +02:00
|
|
|
|
#endif
|
|
|
|
|
}
|
2012-02-28 08:43:28 +01:00
|
|
|
|
}
|
|
|
|
|
|
2017-05-16 16:30:27 +02:00
|
|
|
|
void
|
2023-11-29 19:00:21 +01:00
|
|
|
|
ss2Iteration(double* y, const double* yhat, const double* epsilon, const double* ghx,
|
|
|
|
|
const double* ghu, const double* constant, const double* ghxx, const double* ghuu,
|
|
|
|
|
const double* ghxu, blas_int m, blas_int n, blas_int q, blas_int s,
|
|
|
|
|
int number_of_threads)
|
2012-02-28 08:43:28 +01:00
|
|
|
|
{
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#ifdef USE_BLAS_AT_FIRST_ORDER
|
2017-05-16 16:30:27 +02:00
|
|
|
|
const double one = 1.0;
|
|
|
|
|
const blas_int ONE = 1;
|
|
|
|
|
#endif
|
2019-09-11 16:06:35 +02:00
|
|
|
|
auto [ii1, ii2, ii3] = set_vector_of_indices(n, m); // vector indices for ghxx
|
|
|
|
|
auto [jj1, jj2, jj3] = set_vector_of_indices(q, m); // vector indices for ghuu
|
2019-06-25 15:42:32 +02:00
|
|
|
|
#pragma omp parallel for num_threads(number_of_threads)
|
2017-05-16 16:30:27 +02:00
|
|
|
|
for (int particle = 0; particle < s; particle++)
|
|
|
|
|
{
|
2023-11-29 19:00:21 +01:00
|
|
|
|
int particle_ = particle * m;
|
|
|
|
|
int particle__ = particle * n;
|
|
|
|
|
int particle___ = particle * q;
|
2019-04-30 15:22:19 +02:00
|
|
|
|
std::copy_n(constant, m, &y[particle_]);
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#ifdef USE_BLAS_AT_FIRST_ORDER
|
2019-04-30 15:22:19 +02:00
|
|
|
|
dgemv("N", &m, &n, &one, ghx, &m, &yhat[particle__], &ONE, &one, &y[particle_], &ONE);
|
|
|
|
|
dgemv("N", &m, &q, &one, ghu, &m, &epsilon[particle___], &ONE, &one, &y[particle_], &ONE);
|
2017-05-16 16:30:27 +02:00
|
|
|
|
#endif
|
|
|
|
|
for (int variable = 0; variable < m; variable++)
|
|
|
|
|
{
|
|
|
|
|
int variable_ = variable + particle_;
|
2021-10-22 17:57:29 +02:00
|
|
|
|
// +ghx·ŷ+ghu·ε
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#ifndef USE_BLAS_AT_FIRST_ORDER
|
2021-10-14 16:18:17 +02:00
|
|
|
|
for (int column = 0, column_ = 0; column < n; column++, column_ += m)
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y[variable_] += ghx[variable + column_] * yhat[column + particle__];
|
2021-10-14 16:18:17 +02:00
|
|
|
|
for (int column = 0, column_ = 0; column < q; column++, column_ += m)
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y[variable_] += ghu[variable + column_] * epsilon[column + particle___];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
#endif
|
2021-10-22 17:57:29 +02:00
|
|
|
|
// +½ghxx·ŷ⊗ŷ
|
2023-11-29 19:00:21 +01:00
|
|
|
|
for (int i = 0; i < n * (n + 1) / 2; i++)
|
2017-05-16 16:30:27 +02:00
|
|
|
|
{
|
2023-11-29 19:00:21 +01:00
|
|
|
|
int i1 = particle__ + ii1[i];
|
|
|
|
|
int i2 = particle__ + ii2[i];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
if (i1 == i2)
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y[variable_] += .5 * ghxx[variable + ii3[i]] * yhat[i1] * yhat[i1];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
else
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y[variable_] += ghxx[variable + ii3[i]] * yhat[i1] * yhat[i2];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
}
|
2021-10-22 17:57:29 +02:00
|
|
|
|
// +½ghuu·ε⊗ε
|
2023-11-29 19:00:21 +01:00
|
|
|
|
for (int j = 0; j < q * (q + 1) / 2; j++)
|
2017-05-16 16:30:27 +02:00
|
|
|
|
{
|
2023-11-29 19:00:21 +01:00
|
|
|
|
int j1 = particle___ + jj1[j];
|
|
|
|
|
int j2 = particle___ + jj2[j];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
if (j1 == j2)
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y[variable_] += .5 * ghuu[variable + jj3[j]] * epsilon[j1] * epsilon[j1];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
else
|
2023-11-29 19:00:21 +01:00
|
|
|
|
y[variable_] += ghuu[variable + jj3[j]] * epsilon[j1] * epsilon[j2];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
}
|
2021-10-22 17:57:29 +02:00
|
|
|
|
// +ghxu·ŷ⊗ε
|
2023-11-29 19:00:21 +01:00
|
|
|
|
for (int v = particle__, i = 0; v < particle__ + n; v++)
|
|
|
|
|
for (int s = particle___; s < particle___ + q; s++, i += m)
|
|
|
|
|
y[variable_] += ghxu[variable + i] * epsilon[s] * yhat[v];
|
2017-05-16 16:30:27 +02:00
|
|
|
|
}
|
|
|
|
|
}
|
2012-02-28 08:43:28 +01:00
|
|
|
|
}
|
|
|
|
|
|
2017-05-16 16:30:27 +02:00
|
|
|
|
void
|
2023-11-29 19:00:21 +01:00
|
|
|
|
mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
|
2012-02-28 08:43:28 +01:00
|
|
|
|
{
|
|
|
|
|
/*
|
2019-04-30 15:22:19 +02:00
|
|
|
|
prhs[0] yhat [double] n×s array, time t particles.
|
|
|
|
|
prhs[1] epsilon [double] q×s array, time t innovations.
|
|
|
|
|
prhs[2] ghx [double] m×n array, first order reduced form.
|
|
|
|
|
prhs[3] ghu [double] m×q array, first order reduced form.
|
2023-11-29 19:00:21 +01:00
|
|
|
|
prhs[4] constant [double] m×1 array, deterministic steady state + second order correction
|
|
|
|
|
for the union of the states and observed variables. prhs[5] ghxx [double] m×n² array,
|
|
|
|
|
second order reduced form. prhs[6] ghuu [double] m×q² array, second order reduced
|
|
|
|
|
form. prhs[7] ghxu [double] m×nq array, second order reduced form. prhs[8] yhat_
|
|
|
|
|
[double] [OPTIONAL] n×s array, time t particles (pruning additional latent variables). prhs[9]
|
|
|
|
|
ss [double] [OPTIONAL] m×1 array, steady state for the union of the states and the
|
|
|
|
|
observed variables (needed for the pruning mode).
|
2019-12-20 14:50:19 +01:00
|
|
|
|
|
2021-10-20 15:33:47 +02:00
|
|
|
|
prhs[8 or 10] [double] num of threads
|
2019-06-28 17:48:39 +02:00
|
|
|
|
|
2021-10-21 16:34:57 +02:00
|
|
|
|
plhs[0] y [double] m×s array, time t+1 particles.
|
|
|
|
|
plhs[1] y_ [double] m×s array, time t+1 particles for the pruning latent variables.
|
2012-02-28 08:43:28 +01:00
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
// Check the number of input and output.
|
2019-04-30 15:22:19 +02:00
|
|
|
|
if (nrhs != 9 && nrhs != 11)
|
2019-06-28 17:48:39 +02:00
|
|
|
|
mexErrMsgTxt("Nine or eleven input arguments are required.");
|
2019-04-30 15:22:19 +02:00
|
|
|
|
|
2012-03-05 23:10:14 +01:00
|
|
|
|
if (nlhs > 2)
|
2019-04-30 15:22:19 +02:00
|
|
|
|
mexErrMsgTxt("Too many output arguments.");
|
|
|
|
|
|
2023-11-29 19:00:21 +01:00
|
|
|
|
auto check_input_real_dense_array = [=](int i) {
|
2022-03-18 18:18:24 +01:00
|
|
|
|
if (!mxIsDouble(prhs[i]) || mxIsComplex(prhs[i]) || mxIsSparse(prhs[i]))
|
2023-11-29 19:00:21 +01:00
|
|
|
|
mexErrMsgTxt(
|
|
|
|
|
("Input argument " + std::to_string(i + 1) + " should be a real dense array").c_str());
|
2022-03-18 18:18:24 +01:00
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 8; i++)
|
|
|
|
|
check_input_real_dense_array(i);
|
|
|
|
|
|
2012-02-28 08:43:28 +01:00
|
|
|
|
// Get dimensions.
|
2019-12-20 14:50:19 +01:00
|
|
|
|
size_t n = mxGetM(prhs[0]); // Number of states.
|
|
|
|
|
size_t s = mxGetN(prhs[0]); // Number of particles.
|
|
|
|
|
size_t q = mxGetM(prhs[1]); // Number of innovations.
|
|
|
|
|
size_t m = mxGetM(prhs[2]); // Number of elements in the union of states and observed variables.
|
2012-02-28 08:43:28 +01:00
|
|
|
|
// Check the dimensions.
|
2023-11-29 19:00:21 +01:00
|
|
|
|
if (s != mxGetN(prhs[1]) // Number of columns for epsilon
|
2019-04-30 15:22:19 +02:00
|
|
|
|
|| n != mxGetN(prhs[2]) // Number of columns for ghx
|
|
|
|
|
|| m != mxGetM(prhs[3]) // Number of rows for ghu
|
|
|
|
|
|| q != mxGetN(prhs[3]) // Number of columns for ghu
|
2023-11-29 19:00:21 +01:00
|
|
|
|
|| m != mxGetM(prhs[4]) // Number of rows for 2nd order constant correction + deterministic
|
|
|
|
|
// steady state
|
2019-04-30 15:22:19 +02:00
|
|
|
|
|| m != mxGetM(prhs[5]) // Number of rows for ghxx
|
2023-11-29 19:00:21 +01:00
|
|
|
|
|| n * n != mxGetN(prhs[5]) // Number of columns for ghxx
|
|
|
|
|
|| m != mxGetM(prhs[6]) // Number of rows for ghuu
|
|
|
|
|
|| q * q != mxGetN(prhs[6]) // Number of columns for ghuu
|
|
|
|
|
|| m != mxGetM(prhs[7]) // Number of rows for ghxu
|
|
|
|
|
|| n * q != mxGetN(prhs[7])) // Number of rows for ghxu
|
2019-04-30 15:22:19 +02:00
|
|
|
|
mexErrMsgTxt("Input dimension mismatch!.");
|
2017-05-16 16:30:27 +02:00
|
|
|
|
if (nrhs > 9)
|
2022-03-18 18:18:24 +01:00
|
|
|
|
{
|
|
|
|
|
for (int i = 8; i < 10; i++)
|
|
|
|
|
check_input_real_dense_array(i);
|
|
|
|
|
|
2023-11-29 19:00:21 +01:00
|
|
|
|
if (n != mxGetM(prhs[8]) // Number of rows for yhat_
|
|
|
|
|
|| s != mxGetN(prhs[8]) // Number of columns for yhat_
|
2022-03-18 18:18:24 +01:00
|
|
|
|
|| m != mxGetM(prhs[9])) // Number of rows for ss
|
|
|
|
|
mexErrMsgTxt("Input dimension mismatch!.");
|
|
|
|
|
}
|
2019-04-30 15:22:19 +02:00
|
|
|
|
|
2012-02-28 08:43:28 +01:00
|
|
|
|
// Get Input arrays.
|
2023-11-29 19:00:21 +01:00
|
|
|
|
const double* yhat = mxGetPr(prhs[0]);
|
|
|
|
|
const double* epsilon = mxGetPr(prhs[1]);
|
|
|
|
|
const double* ghx = mxGetPr(prhs[2]);
|
|
|
|
|
const double* ghu = mxGetPr(prhs[3]);
|
|
|
|
|
const double* constant = mxGetPr(prhs[4]);
|
|
|
|
|
const double* ghxx = mxGetPr(prhs[5]);
|
|
|
|
|
const double* ghuu = mxGetPr(prhs[6]);
|
|
|
|
|
const double* ghxu = mxGetPr(prhs[7]);
|
2019-04-30 15:22:19 +02:00
|
|
|
|
const double *yhat_ = nullptr, *ss = nullptr;
|
2017-05-16 16:30:27 +02:00
|
|
|
|
if (nrhs > 9)
|
2012-02-28 08:43:28 +01:00
|
|
|
|
{
|
|
|
|
|
yhat_ = mxGetPr(prhs[8]);
|
|
|
|
|
ss = mxGetPr(prhs[9]);
|
2012-03-05 23:10:14 +01:00
|
|
|
|
}
|
2022-03-18 18:18:24 +01:00
|
|
|
|
|
2023-11-29 19:00:21 +01:00
|
|
|
|
const mxArray* numthreads_mx = prhs[nrhs == 9 ? 8 : 10];
|
2022-03-18 18:18:24 +01:00
|
|
|
|
if (!(mxIsScalar(numthreads_mx) && mxIsNumeric(numthreads_mx)))
|
|
|
|
|
mexErrMsgTxt("Last argument should be a numeric scalar");
|
|
|
|
|
int numthreads = static_cast<int>(mxGetScalar(numthreads_mx));
|
|
|
|
|
if (numthreads <= 0)
|
|
|
|
|
mexErrMsgTxt("Last argument should be a positive integer");
|
|
|
|
|
|
2021-10-20 15:50:34 +02:00
|
|
|
|
#if defined(USE_BLAS_AT_FIRST_ORDER) && defined(MATLAB_MEX_FILE)
|
|
|
|
|
if (numthreads != 1)
|
|
|
|
|
mexErrMsgTxt("Parallelization is not possible when compiled with USE_BLAS_AT_FIRST_ORDER.");
|
|
|
|
|
#endif
|
2017-05-16 16:30:27 +02:00
|
|
|
|
if (nrhs == 9)
|
2012-02-28 08:43:28 +01:00
|
|
|
|
{
|
|
|
|
|
plhs[0] = mxCreateDoubleMatrix(m, s, mxREAL);
|
2023-11-29 19:00:21 +01:00
|
|
|
|
double* y = mxGetPr(plhs[0]);
|
|
|
|
|
ss2Iteration(y, yhat, epsilon, ghx, ghu, constant, ghxx, ghuu, ghxu, static_cast<int>(m),
|
|
|
|
|
static_cast<int>(n), static_cast<int>(q), static_cast<int>(s), numthreads);
|
2012-02-28 08:43:28 +01:00
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
plhs[0] = mxCreateDoubleMatrix(m, s, mxREAL);
|
|
|
|
|
plhs[1] = mxCreateDoubleMatrix(m, s, mxREAL);
|
2023-11-29 19:00:21 +01:00
|
|
|
|
double* y = mxGetPr(plhs[0]);
|
|
|
|
|
double* y_ = mxGetPr(plhs[1]);
|
|
|
|
|
ss2Iteration_pruning(y, y_, yhat, yhat_, epsilon, ghx, ghu, constant, ghxx, ghuu, ghxu, ss,
|
|
|
|
|
static_cast<int>(m), static_cast<int>(n), static_cast<int>(q),
|
|
|
|
|
static_cast<int>(s), numthreads);
|
2012-02-28 08:43:28 +01:00
|
|
|
|
}
|
|
|
|
|
}
|