From 3307d7527325dc80ac441d845982f93fd84d5311 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Mon, 23 Sep 2024 12:48:01 +0200 Subject: [PATCH] test: better test ids for get_h --- hankel/tools.py | 11 ++++------- tests/test_get_h.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/hankel/tools.py b/hankel/tools.py index 5e7b7de..5096509 100755 --- a/hankel/tools.py +++ b/hankel/tools.py @@ -234,7 +234,7 @@ def get_h( i += 1 if i == maxiter: - raise Exception("Maxiter reached while checking for non-zero values") + raise RuntimeError("Maxiter reached while checking for non-zero values") if K is None: # Do a normal integral of f(x)J_nu(x) K = 1 @@ -263,7 +263,7 @@ def getres(h): res = getres(hstart) if i == maxiter: - raise Exception("Maxiter reached while checking convergence") + raise RuntimeError("Maxiter reached while checking convergence") # Can do some more trimming of N potentially, by seeing where f(x)~0. def consecutive(data, stepsize=1): @@ -272,17 +272,14 @@ def consecutive(data, stepsize=1): hstart *= hdecrement - x = cls(nu, h=hstart, N=int(np.pi / hstart)).x + x = np.atleast_1d(cls(nu, h=hstart, N=int(np.pi / hstart)).x) lastk = np.where(f(x / np.max(K)) == 0)[0] if len(lastk) > 1: # if there are any that are zero, # and if there are more than 1 in a row # (otherwise might just be oscillatory) lastk = consecutive(lastk) # split into arrays of consecutive zeros - if len(lastk[-1]) == 1: - lastk = int(np.pi / hstart) - else: - lastk = lastk[-1][0] + lastk = int(np.pi / hstart) if len(lastk[-1]) == 1 else lastk[-1][0] else: # otherwise set back to N lastk = int(np.pi / hstart) diff --git a/tests/test_get_h.py b/tests/test_get_h.py index bc32ee3..f856c4b 100644 --- a/tests/test_get_h.py +++ b/tests/test_get_h.py @@ -26,7 +26,7 @@ def gammaincc_(a, x): @pytest.mark.parametrize( "f, anl", [ - (lambda x: 1, 1), # Ogata 05 + (lambda x: np.ones_like(x), 1), # Ogata 05 (lambda x: x / (x**2 + 1), k0(1)), # Ogata 05 (lambda x: x**2, -1), # wikipedia (lambda x: x**4, 9), # wikipedia @@ -40,6 +40,15 @@ def gammaincc_(a, x): 1.0 / 2**2 * np.exp(-0.5 / 2**2), ), ], + ids=[ + "ones", + "x / (x^2 + 1)", + "x^2", + "x^4", + "1 / sqrt(x)", + "x / sqrt(x^2 + 1)", + "exp(1/2 * x^2)", + ] ) def test_nu0(f, anl): ans = get_h(f=f, nu=0, hstart=0.5, atol=0, rtol=1e-3, maxiter=20)[1]