From 4206cabe529350cc35a2809f2359f65bee59d7ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Adjemia=20=28Scylla=29?= Date: Fri, 14 Dec 2018 21:28:59 +0100 Subject: [PATCH] Streamlined gamrnd algorithms. --- matlab/distributions/+gamrnd/ahrens_dieter.m | 44 +++++------- matlab/distributions/+gamrnd/berman.m | 9 +-- matlab/distributions/+gamrnd/best_1978.m | 42 +++++------ matlab/distributions/+gamrnd/best_1983.m | 70 ++++++------------- matlab/distributions/+gamrnd/johnk.m | 9 +-- matlab/distributions/+gamrnd/knuth.m | 27 +++---- .../distributions/+gamrnd/weibull_rejection.m | 1 - 7 files changed, 68 insertions(+), 134 deletions(-) diff --git a/matlab/distributions/+gamrnd/ahrens_dieter.m b/matlab/distributions/+gamrnd/ahrens_dieter.m index a4d5fac9e..811f91e72 100644 --- a/matlab/distributions/+gamrnd/ahrens_dieter.m +++ b/matlab/distributions/+gamrnd/ahrens_dieter.m @@ -1,6 +1,6 @@ function g = ahrens_dieter(a, b) -% Returns gamma variates, see Devroye (1986) page 410. +% Returns gamma variates, see Devroye (1986) page 425. % % INPUTS % - a [double] n*1 vector, first hyperparameter. @@ -30,34 +30,24 @@ nn = length(a); mm = nn; bb = (exp(1)+a)/exp(1); cc = 1./a; -INDEX = 1:mm; -index = INDEX; -UW = NaN(nn,2); -V = NaN(nn,1); -X = NaN(nn,1); +index = 1:nn; +U = NaN(nn,1); +W = NaN(nn,1); +V = NaN(nn,1); +X = NaN(nn,1); while mm - UW(index,:) = rand(mm,2); - V(index) = UW(index,1).*bb(index); - state1 = find(V(index)<=1); - state2 = find(V(index)>1); - ID = []; - if ~isempty(state1) - X(index(state1)) = V(index(state1)).^cc(index(state1)); - ID = INDEX(index(state1(UW(index(state1),2)>exp(-X(index(state1)))))); - end - if ~isempty(state2) - X(index(state2)) = -log(cc(index(state2)).*(bb(index(state2))-V(index(state2)))); - if isempty(ID) - ID = INDEX(index(state2(UW(index(state2),2)>X(index(state2)).^(a(index(state2))-1)))); - else - ID = [ID, INDEX(index(state2(UW(index(state2),2)>X(index(state2)).^(a(index(state2))-1))))]; - end - end - mm = length(ID); - if mm - index = ID; - end + U(index) = rand(mm,1); + W(index) = rand(mm,1); + V(index) = U(index).*bb(index); + id1 = index(V(index)<=1); + id2 = setdiff(index, id1); + X(id1) = V(id1).^cc(id1); + id3 = id1(W(id1)>exp(-X(id1))); + X(id2) = -log(cc(id2).*(bb(id2)-V(id2))); + id4 = id2(W(id2)>X(id2).^(a(id2)-1)); + index = [id3, id4]; + mm = length(index); end g = X.*b ; diff --git a/matlab/distributions/+gamrnd/berman.m b/matlab/distributions/+gamrnd/berman.m index a3f330d20..91fe30761 100644 --- a/matlab/distributions/+gamrnd/berman.m +++ b/matlab/distributions/+gamrnd/berman.m @@ -40,13 +40,8 @@ while mm UV(index,:) = rand(mm,2); X(index) = UV(index,1).^aa(index); Y(index) = UV(index,2).^cc(index); - id = find(X+Y>1); - if isempty(id) - mm = 0; - else - index = INDEX(id); - mm = length(index); - end + index = index(X(index)+Y(index)>1); + mm = length(index); end Z = gamrnd(2*ones(nn,1), ones(nn,1)); diff --git a/matlab/distributions/+gamrnd/best_1978.m b/matlab/distributions/+gamrnd/best_1978.m index 8c2d9b9a2..e2820a4d9 100644 --- a/matlab/distributions/+gamrnd/best_1978.m +++ b/matlab/distributions/+gamrnd/best_1978.m @@ -1,6 +1,6 @@ function g = best_1978(a ,b) -% Returns gamma variates, see Devroye (1986) page 410. +% Returns gamma variates, see Devroye (1986) page 410 and Best (1978). % % INPUTS % - a [double] n*1 vector, first hyperparameter. @@ -30,34 +30,24 @@ nn = length(a); mm = nn; bb = a-1; cc = 3*a-.75; -UV = NaN(nn,2); -Y = NaN(nn,1); -X = NaN(nn,1); -Z = NaN(nn,1); -W = NaN(nn,1); +U = NaN(nn,1); +Y = NaN(nn,1); +X = NaN(nn,1); +Z = NaN(nn,1); +W = NaN(nn,1); index = 1:nn; -INDEX = index; while mm - UV(index,:) = rand(mm,2); - W(index) = UV(index,1).*(1-UV(index,1)); - Y(index) = sqrt(cc(index)./W(index)).*(UV(index,1)-.5); - X(index) = bb(index)+Y(index); - jndex = index(X(index)>=0); - Jndex = setdiff(index,jndex); - if ~isempty(jndex) - Z(jndex) = 64*W(jndex).*W(jndex).*W(jndex).*UV(jndex,2).*UV(jndex,2); - kndex = jndex(Z(jndex)<=1-2*Y(jndex).*Y(jndex)./X(jndex)); - Kndex = setdiff(jndex, kndex); - if ~isempty(Kndex) - lndex = Kndex(log(Z(Kndex))<=2*(bb(Kndex).*log(X(Kndex)./bb(Kndex))-Y(Kndex))); - Lndex = setdiff(Kndex, lndex); - else - Lndex = []; - end - new_index = INDEX(Lndex); - end - index = union(new_index, INDEX(Jndex)); + U(index) = rand(mm,1); + W(index) = U(index).*(1-U(index)); % e + Y(index) = sqrt(cc(index)./W(index)).*(U(index)-.5); % f + X(index) = bb(index)+Y(index); % x + id1 = index(X(index)<0); % Reject. + id2 = setdiff(index, id1); + Z(id2) = 64.0*(W(id2).^3).*(rand(length(id2),1).^2); % d + id3 = id2(Z(id2)>1.0-2.0*Y(id2).*Y(id2)./X(id2)); % Reject. + id4 = id3(log(Z(id3))>2.0*(bb(id3).*log(X(id3)./bb(id3))-Y(id3))); % Reject. + index = [id1, id4]; mm = length(index); end diff --git a/matlab/distributions/+gamrnd/best_1983.m b/matlab/distributions/+gamrnd/best_1983.m index 0be8736a0..a49d1cb45 100644 --- a/matlab/distributions/+gamrnd/best_1983.m +++ b/matlab/distributions/+gamrnd/best_1983.m @@ -1,6 +1,6 @@ function g = best_1983(a, b) -% Returns gamma variates, see Devroye (1986) page 426. +% Returns gamma variates, see Devroye (1986) page 426 and Best (1983) page 187. % % INPUTS % - a [double] n*1 vector, first hyperparameter. @@ -28,55 +28,31 @@ function g = best_1983(a, b) nn = length(a); mm = nn; -tt = .07 + .75*sqrt(1-a); -bb = 1 + exp(-tt).*a./tt; +index = 1:nn; +U = NaN(nn,1); +Ustar = NaN(nn, 1); +P = NaN(nn,1); +X = NaN(nn,1); +Y = NaN(nn,1); +zz = .07 + .75*sqrt(1-a); +bb = 1 + exp(-zz).*a./zz; cc = 1./a; -INDEX = 1:mm; -index = INDEX; -UW = NaN(nn,2); -V = NaN(nn,1); -X = NaN(nn,1); -Y = NaN(nn,1); while mm - UW(index,:) = rand(mm,2); - V(index) = UW(index,1).*bb(index); - state1 = find(V(index)<=1); - state2 = find(V(index)>1); - ID = []; - if ~isempty(state1) - X(index(state1)) = tt(index(state1)).*V(index(state1)).^cc(index(state1)); - test11 = UW(index(state1),2) <= (2-X(index(state1)))./(2+X(index(state1))) ; - id11 = find(~test11); - if ~isempty(id11) - test12 = UW(index(state1(id11)),2) <= exp(-X(index(state1(id11)))) ; - id12 = find(~test12); - else - id12 = []; - end - ID = INDEX(index(state1(id11(id12)))); - end - if ~isempty(state2) - X(index(state2)) = -log(cc(index(state2)).*tt(index(state2)).*(bb(index(state2))-V(index(state2)))) ; - Y(index(state2)) = X(index(state2))./tt(index(state2)) ; - test21 = UW(index(state2),2).*(a(index(state2)) + Y(index(state2)) - a(index(state2)).*Y(index(state2)) ) <= 1 ; - id21 = find(~test21); - if ~isempty(id21) - test22 = UW(index(state2(id21)),2) <= Y(index(state2(id21))).^(a(index(state2(id21)))-1) ; - id22 = find(~test22); - else - id22 = []; - end - if isempty(ID) - ID = INDEX(index(state2(id21(id22)))); - else - ID = [ID,INDEX(index(state2(id21(id22))))]; - end - end - mm = length(ID); - if mm - index = ID; - end + U(index) = rand(mm,1); + Ustar(index) = rand(mm, 1); + P(index) = U(index).*bb(index); + id1 = index(P(index)<=1); + id2 = setdiff(index, id1); % Goto 4. + X(id1) = zz(id1).*(P(id1).^cc(id1)); + id3 = id1(Ustar(id1)>((2-X(id1))./(2+X(id1)))); + id5 = id3(Ustar(id3)>exp(-X(id3))); + X(id2) = -log(cc(id2).*zz(id2).*(bb(id2)-P(id2))); % This is 4. + Y(id2) = X(id2)./zz(id2); + id4 = id2(Ustar(id2).*(a(id2)+(1-a(id2)).*Y(id2))>1); + id6 = id4(Ustar(id4)>Y(id4).^(a(id4)-1)); + index = [id5, id6]; + mm = length(index); end g = X.*b; diff --git a/matlab/distributions/+gamrnd/johnk.m b/matlab/distributions/+gamrnd/johnk.m index 0939446ec..45c5b1d7a 100644 --- a/matlab/distributions/+gamrnd/johnk.m +++ b/matlab/distributions/+gamrnd/johnk.m @@ -40,13 +40,8 @@ while mm UV(index,:) = rand(mm,2); X(index) = UV(index,1).^aa(index); Y(index) = UV(index,2).^bb(index); - id = find(X+Y>1); - if isempty(id) - mm = 0; - else - index = INDEX(id); - mm = length(index); - end + index = index(X(index)+Y(index)>1); + mm = length(index); end g = (exprnd(ones(nn,1)).*(X./(X+Y))).*b; \ No newline at end of file diff --git a/matlab/distributions/+gamrnd/knuth.m b/matlab/distributions/+gamrnd/knuth.m index 4099f5049..dfb4d9102 100644 --- a/matlab/distributions/+gamrnd/knuth.m +++ b/matlab/distributions/+gamrnd/knuth.m @@ -1,6 +1,6 @@ function g = knuth(a, b) -% Returns gamma variates, see Bauwens, Lubrano & Richard (1999) page 316. +% Returns gamma variates, see Knuth (1981) page 129. % % INPUTS % - a [double] n*1 vector, first hyperparameter. @@ -29,29 +29,18 @@ function g = knuth(a, b) nn = length(a); mm = nn; bb = sqrt(2*a-1); -dd = 1./(a-1); Y = NaN(nn,1); X = NaN(nn,1); -INDEX = 1:mm; -index = INDEX; +index = 1:mm; while mm Y(index) = tan(pi*rand(mm,1)); - X(index) = Y(index).*bb(index) + a(index) - 1 ; - idy1 = find(X(index)>=0); - idn1 = setdiff(index,index(idy1)); - if ~isempty(idy1) - test = log(rand(length(idy1),1)) <= ... - log(1+Y(index(idy1)).*Y(index(idy1))) + ... - (a(index(idy1))-1).*log(X(index(idy1)).*dd(index(idy1))) - ... - Y(index(idy1)).*bb(index(idy1)) ; - idy2 = find(test); - idn2 = setdiff(idy1, idy1(idy2)); - else - idy2 = []; - idn2 = []; - end - index = [ INDEX(idn1) , INDEX(index(idn2)) ] ; + X(index) = Y( + index).*bb(index) + a(index) - 1; + id1 = index(X(index)<=0); % Rejected draws. + id2 = setdiff(index, id1); + id3 = id2(rand(length(id2), 1)>(1+Y(id2).*Y(id2)).*exp((a(id2)-1).*(log(X(id2))-log(a(id2)-1))-bb(id2).*Y(id2))); % Rejected draws. + index = [id1, id3]; mm = length(index); end diff --git a/matlab/distributions/+gamrnd/weibull_rejection.m b/matlab/distributions/+gamrnd/weibull_rejection.m index e69f47415..d3fee6142 100644 --- a/matlab/distributions/+gamrnd/weibull_rejection.m +++ b/matlab/distributions/+gamrnd/weibull_rejection.m @@ -37,7 +37,6 @@ X = NaN(nn, 1); index = 1:nn; while mm - % Generate Weibull Z(index) = -log(rand(mm, 1)); Y(index) = Z(index).^cc(index); INDEX = index(rand(mm,1)>aa(index).*exp(Z(index)-Y(index)));