From 9261d30aa9618cb2a5a698d39752263b076f2d4b Mon Sep 17 00:00:00 2001 From: serge-sans-paille <serge.guelton@telecom-bretagne.eu> Date: Tue, 20 Aug 2024 23:50:55 +0200 Subject: [PATCH] Fix numpy.fix output type This one changed with recent numpy upgrade, see https://github.com/numpy/numpy/pull/26766 --- pythran/pythonic/include/numpy/fix.hpp | 17 ++++++++++++++--- pythran/pythonic/numpy/fix.hpp | 6 +++--- pythran/tests/test_numpy_func0.py | 5 +++++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/pythran/pythonic/include/numpy/fix.hpp b/pythran/pythonic/include/numpy/fix.hpp index 2708930d6c..e4a85a5049 100644 --- a/pythran/pythonic/include/numpy/fix.hpp +++ b/pythran/pythonic/include/numpy/fix.hpp @@ -1,18 +1,29 @@ #ifndef PYTHONIC_INCLUDE_NUMPY_FIX_HPP #define PYTHONIC_INCLUDE_NUMPY_FIX_HPP -#include "pythonic/include/utils/functor.hpp" #include "pythonic/include/types/ndarray.hpp" +#include "pythonic/include/utils/functor.hpp" #include "pythonic/include/utils/numpy_traits.hpp" PYTHONIC_NS_BEGIN namespace numpy { + namespace wrapper + { + template <class E> + E fix(E const &e) + { + if (std::is_integral<E>::value) + return e; + else + return std::trunc(e); + } + } // namespace wrapper #define NUMPY_NARY_FUNC_NAME fix -#define NUMPY_NARY_FUNC_SYM std::trunc +#define NUMPY_NARY_FUNC_SYM wrapper::fix #include "pythonic/include/types/numpy_nary_expr.hpp" -} +} // namespace numpy PYTHONIC_NS_END #endif diff --git a/pythran/pythonic/numpy/fix.hpp b/pythran/pythonic/numpy/fix.hpp index 5b1b020dc2..84773b61cf 100644 --- a/pythran/pythonic/numpy/fix.hpp +++ b/pythran/pythonic/numpy/fix.hpp @@ -3,8 +3,8 @@ #include "pythonic/include/numpy/fix.hpp" -#include "pythonic/utils/functor.hpp" #include "pythonic/types/ndarray.hpp" +#include "pythonic/utils/functor.hpp" #include "pythonic/utils/numpy_traits.hpp" PYTHONIC_NS_BEGIN @@ -13,9 +13,9 @@ namespace numpy { #define NUMPY_NARY_FUNC_NAME fix -#define NUMPY_NARY_FUNC_SYM std::trunc +#define NUMPY_NARY_FUNC_SYM 
wrapper::fix #include "pythonic/types/numpy_nary_expr.hpp" -} +} // namespace numpy PYTHONIC_NS_END #endif diff --git a/pythran/tests/test_numpy_func0.py b/pythran/tests/test_numpy_func0.py index 3e11133fec..41f716d900 100644 --- a/pythran/tests/test_numpy_func0.py +++ b/pythran/tests/test_numpy_func0.py @@ -1,12 +1,16 @@ import unittest from pythran.tests import TestEnv import numpy +from packaging import version import tempfile import os from pythran.typing import NDArray, List, Tuple +np_version = version.parse(numpy.version.version) + + class TestNumpyFunc0(TestEnv): def test_extended_sum0(self): @@ -910,6 +914,7 @@ def test_flatnonzero1(self): def test_fix0(self): self.run_test("def np_fix0(x): from numpy import fix ; return fix(x)", 3.14, np_fix0=[float]) + @unittest.skipIf(np_version <= version.Version("2.1"), reason="np.fix used to return float on integral input") def test_fix1(self): self.run_test("def np_fix1(x): from numpy import fix ; return fix(x)", 3, np_fix1=[int])

Review notes (trailing text after the last hunk, not part of the upstream change):
* Purpose: backport of pythran commit 9261d30. It makes numpy.fix keep an
  integral element type by routing the nary functor through wrapper::fix
  instead of std::trunc, and adds a numpy-version-gated test (test_fix1).
* wrapper::fix picks its branch with a plain `if (std::is_integral<E>::value)`.
  So both `return e;` and `return std::trunc(e);` are compiled for every E.
  For integral E the dead branch is still instantiated as std::trunc(e),
  which yields a double that is implicitly converted back to E. That can
  trigger -Wconversion-style warnings. Overloads or tag dispatch (or
  `if constexpr`, if the C++ standard pythran targets allows it -- TODO
  confirm) would avoid instantiating it.
* The new code in include/numpy/fix.hpp uses std::is_integral and std::trunc.
  The hunk adds no <type_traits> or <cmath> include. Presumably both arrive
  transitively through the pythonic headers -- verify.
* The test guard is `np_version <= version.Version("2.1")`. Under PEP 440,
  Version("2.1") == Version("2.1.0"), so numpy 2.1.0 itself skips test_fix1,
  while 2.1.1 and later run it. If numpy#26766 first shipped in 2.1.0, `<`
  is the intended comparison -- confirm against the NumPy 2.1.0 release
  notes before changing it.