Bytecode: simplify Interpreter::compute_complete()

kalman-mex
Sébastien Villemot 2023-10-20 09:38:03 -04:00
parent 0bfcc6d2f5
commit 093a547684
No known key found for this signature in database
GPG Key ID: 2CECE9350ECEBE4A
2 changed files with 42 additions and 21 deletions

View File

@ -24,6 +24,7 @@
#include <cfenv> #include <cfenv>
#include <type_traits> #include <type_traits>
#include <chrono> #include <chrono>
#include <limits>
#include "Interpreter.hh" #include "Interpreter.hh"
@ -2660,8 +2661,8 @@ Interpreter::compute_complete(bool no_derivatives)
return result; return result;
} }
bool pair<bool, double>
Interpreter::compute_complete(double lambda, double *crit) Interpreter::compute_complete(double lambda)
{ {
double res1_ = 0, res2_ = 0, max_res_ = 0; double res1_ = 0, res2_ = 0, max_res_ = 0;
int max_res_idx_ = 0; int max_res_idx_ = 0;
@ -2678,7 +2679,7 @@ Interpreter::compute_complete(double lambda, double *crit)
if (compute_complete(true)) if (compute_complete(true))
res2_ = res2; res2_ = res2;
else else
return false; return { false, numeric_limits<double>::quiet_NaN() };
} }
else else
{ {
@ -2703,14 +2704,14 @@ Interpreter::compute_complete(double lambda, double *crit)
} }
} }
else else
return false; return { false, numeric_limits<double>::quiet_NaN() };
} }
it_ = periods+y_kmin-1; // Do not leave it_ in inconsistent state it_ = periods+y_kmin-1; // Do not leave it_ in inconsistent state
} }
if (verbosity >= 2) if (verbosity >= 2)
mexPrintf(" lambda=%e, res2=%e\n", lambda, res2_); mexPrintf(" lambda=%e, res2=%e\n", lambda, res2_);
*crit = res2_/2; double crit {res2_/2};
return true; return { true, crit };
} }
bool bool
@ -2722,20 +2723,30 @@ Interpreter::mnbrak(double *ax, double *bx, double *cx, double *fa, double *fb,
auto sign = [](double a, double b) { return b >= 0.0 ? fabs(a) : -fabs(a); }; auto sign = [](double a, double b) { return b >= 0.0 ? fabs(a) : -fabs(a); };
bool success;
if (verbosity >= 2) if (verbosity >= 2)
mexPrintf("bracketing *ax=%f, *bx=%f\n", *ax, *bx); mexPrintf("bracketing *ax=%f, *bx=%f\n", *ax, *bx);
if (!compute_complete(*ax, fa))
tie(success, *fa) = compute_complete(*ax);
if (!success)
return false; return false;
if (!compute_complete(*bx, fb))
tie(success, *fb) = compute_complete(*bx);
if (!success)
return false; return false;
if (*fb > *fa) if (*fb > *fa)
{ {
swap(*ax, *bx); swap(*ax, *bx);
swap(*fa, *fb); swap(*fa, *fb);
} }
*cx = (*bx)+GOLD*(*bx-*ax); *cx = (*bx)+GOLD*(*bx-*ax);
if (!compute_complete(*cx, fc)) tie(success, *fc) = compute_complete(*cx);
if (!success)
return false; return false;
while (*fb > *fc) while (*fb > *fc)
{ {
double r = (*bx-*ax)*(*fb-*fc); double r = (*bx-*ax)*(*fb-*fc);
@ -2746,7 +2757,8 @@ Interpreter::mnbrak(double *ax, double *bx, double *cx, double *fa, double *fb,
double fu; double fu;
if ((*bx-u)*(u-*cx) > 0.0) if ((*bx-u)*(u-*cx) > 0.0)
{ {
if (!compute_complete(u, &fu)) tie(success, fu) = compute_complete(u);
if (!success)
return false; return false;
if (fu < *fc) if (fu < *fc)
{ {
@ -2763,12 +2775,14 @@ Interpreter::mnbrak(double *ax, double *bx, double *cx, double *fa, double *fb,
return true; return true;
} }
u = (*cx)+GOLD*(*cx-*bx); u = (*cx)+GOLD*(*cx-*bx);
if (!compute_complete(u, &fu)) tie(success, fu) = compute_complete(u);
if (!success)
return false; return false;
} }
else if ((*cx-u)*(u-ulim) > 0.0) else if ((*cx-u)*(u-ulim) > 0.0)
{ {
if (!compute_complete(u, &fu)) tie(success, fu) = compute_complete(u);
if (!success)
return false; return false;
if (fu < *fc) if (fu < *fc)
{ {
@ -2777,20 +2791,23 @@ Interpreter::mnbrak(double *ax, double *bx, double *cx, double *fa, double *fb,
u = *cx+GOLD*(*cx-*bx); u = *cx+GOLD*(*cx-*bx);
*fb = *fc; *fb = *fc;
*fc = fu; *fc = fu;
if (!compute_complete(u, &fu)) tie(success, fu) = compute_complete(u);
if (!success)
return false; return false;
} }
} }
else if ((u-ulim)*(ulim-*cx) >= 0.0) else if ((u-ulim)*(ulim-*cx) >= 0.0)
{ {
u = ulim; u = ulim;
if (!compute_complete(u, &fu)) tie(success, fu) = compute_complete(u);
if (!success)
return false; return false;
} }
else else
{ {
u = (*cx)+GOLD*(*cx-*bx); u = (*cx)+GOLD*(*cx-*bx);
if (!compute_complete(u, &fu)) tie(success, fu) = compute_complete(u);
if (!success)
return false; return false;
} }
*ax = *bx; *ax = *bx;
@ -2811,7 +2828,7 @@ Interpreter::golden(double ax, double bx, double cx, double tol, double solve_to
if (verbosity >= 2) if (verbosity >= 2)
mexPrintf("golden\n"); mexPrintf("golden\n");
int iter = 0, max_iter = 100; int iter = 0, max_iter = 100;
double f1, f2, x1, x2; double x1, x2;
double x0 = ax; double x0 = ax;
double x3 = cx; double x3 = cx;
if (fabs(cx-bx) > fabs(bx-ax)) if (fabs(cx-bx) > fabs(bx-ax))
@ -2824,9 +2841,11 @@ Interpreter::golden(double ax, double bx, double cx, double tol, double solve_to
x2 = bx; x2 = bx;
x1 = bx-C*(bx-ax); x1 = bx-C*(bx-ax);
} }
if (!compute_complete(x1, &f1)) auto [success, f1] = compute_complete(x1);
if (!success)
return false; return false;
if (!compute_complete(x2, &f2)) auto [success2, f2] = compute_complete(x2);
if (!success2)
return false; return false;
while (fabs(x3-x0) > tol*(fabs(x1)+fabs(x2)) && f1 > solve_tolf && f2 > solve_tolf while (fabs(x3-x0) > tol*(fabs(x1)+fabs(x2)) && f1 > solve_tolf && f2 > solve_tolf
&& iter < max_iter && abs(x1 - x2) > 1e-4) && iter < max_iter && abs(x1 - x2) > 1e-4)
@ -2837,7 +2856,8 @@ Interpreter::golden(double ax, double bx, double cx, double tol, double solve_to
x1 = x2; x1 = x2;
x2 = R*x1+C*x3; x2 = R*x1+C*x3;
f1 = f2; f1 = f2;
if (!compute_complete(x2, &f2)) tie(success, f2) = compute_complete(x2);
if (!success)
return false; return false;
} }
else else
@ -2846,7 +2866,8 @@ Interpreter::golden(double ax, double bx, double cx, double tol, double solve_to
x2 = x1; x2 = x1;
x1 = R*x2+C*x0; x1 = R*x2+C*x0;
f2 = f1; f2 = f1;
if (!compute_complete(x1, &f1)) tie(success, f1) = compute_complete(x1);
if (!success)
return false; return false;
} }
iter++; iter++;

View File

@ -232,7 +232,7 @@ private:
void compute_block_time(int Per_u_, bool evaluate, bool no_derivatives); void compute_block_time(int Per_u_, bool evaluate, bool no_derivatives);
bool compute_complete(bool no_derivatives); bool compute_complete(bool no_derivatives);
bool compute_complete(double lambda, double *crit); pair<bool, double> compute_complete(double lambda);
public: public:
Interpreter(Evaluate &evaluator_arg, double *params_arg, double *y_arg, double *ya_arg, double *x_arg, double *steady_y_arg, Interpreter(Evaluate &evaluator_arg, double *params_arg, double *y_arg, double *ya_arg, double *x_arg, double *steady_y_arg,