From e066a6c559e9d7f31c359ea95da42d0e45c585ce Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 19 Mar 2024 11:32:32 +0100 Subject: [PATCH 01/65] replace the use of `numpy.array_api` with `array_api_strict` This would make it a dependency of `namedarray`, and not allow behavior that is allowed but not required by the array API standard. Otherwise we can: - use the main `numpy` namespace - use `array_api_compat` (would also be a new dependency) to allow optional behavior --- xarray/namedarray/_array_api.py | 9 --------- 1 file changed, 9 deletions(-) Index: xarray-2024.05.0/xarray/namedarray/_array_api.py =================================================================== --- xarray-2024.05.0.orig/xarray/namedarray/_array_api.py +++ xarray-2024.05.0/xarray/namedarray/_array_api.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from types import ModuleType from typing import Any @@ -21,14 +20,6 @@ from xarray.namedarray._typing import ( ) from xarray.namedarray.core import NamedArray -with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - r"The numpy.array_api submodule is still experimental", - category=UserWarning, - ) - import numpy.array_api as nxp # noqa: F401 - def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: if isinstance(x._data, _arrayapi): @@ -68,13 +59,13 @@ def astype( Examples -------- - >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5])) + >>> narr = NamedArray(("x",), np.asarray([1.5, 2.5])) >>> narr Size: 16B - Array([1.5, 2.5], dtype=float64) + array([1.5, 2.5]) >>> astype(narr, np.dtype(np.int32)) Size: 8B - Array([1, 2], dtype=int32) + array([1, 2], dtype=int32) """ if isinstance(x._data, _arrayapi): xp = x._data.__array_namespace__() @@ -109,7 +100,7 @@ def imag( Examples -------- - >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) >>> imag(narr) Size: 16B array([2., 4.]) @@ -141,7 +132,7 @@ def real( Examples -------- - >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) >>> real(narr) Size: 16B array([1., 2.]) @@ -179,15 +170,15 @@ def expand_dims( Examples -------- - >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]])) >>> expand_dims(x) Size: 32B - Array([[[1., 2.], - [3., 4.]]], dtype=float64) + array([[[1., 2.], + [3., 4.]]]) >>> expand_dims(x, dim="z") Size: 32B - Array([[[1., 2.], - [3., 4.]]], dtype=float64) + array([[[1., 2.], + [3., 4.]]]) """ xp = _get_data_namespace(x) dims = x.dims Index: xarray-2024.05.0/xarray/tests/__init__.py =================================================================== --- xarray-2024.05.0.orig/xarray/tests/__init__.py +++ xarray-2024.05.0/xarray/tests/__init__.py @@ -147,9 +147,10 @@ has_numbagg_or_bottleneck = has_numbagg requires_numbagg_or_bottleneck = pytest.mark.skipif( not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" ) -has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0") has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0") +has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict") + def _importorskip_h5netcdf_ros3(): try: Index: xarray-2024.05.0/xarray/tests/test_array_api.py =================================================================== --- xarray-2024.05.0.orig/xarray/tests/test_array_api.py +++ 
xarray-2024.05.0/xarray/tests/test_array_api.py @@ -6,20 +6,9 @@ import xarray as xr from xarray.testing import assert_equal np = pytest.importorskip("numpy", minversion="1.22") +xp = pytest.importorskip("array_api_strict") -try: - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - import numpy.array_api as xp - from numpy.array_api._array_object import Array -except ImportError: - # for `numpy>=2.0` - xp = pytest.importorskip("array_api_strict") - - from array_api_strict._array_object import Array # type: ignore[no-redef] +from array_api_strict._array_object import Array # isort:skip # type: ignore[no-redef] @pytest.fixture @@ -65,8 +54,8 @@ def test_aggregation_skipna(arrays) -> N def test_astype(arrays) -> None: np_arr, xp_arr = arrays expected = np_arr.astype(np.int64) - actual = xp_arr.astype(np.int64) - assert actual.dtype == np.int64 + actual = xp_arr.astype(xp.int64) + assert actual.dtype == xp.int64 assert isinstance(actual.data, Array) assert_equal(actual, expected) @@ -118,8 +107,10 @@ def test_indexing(arrays: tuple[xr.DataA def test_properties(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays - assert np_arr.nbytes == np_arr.data.nbytes - assert xp_arr.nbytes == np_arr.data.nbytes + + expected = np_arr.data.nbytes + assert np_arr.nbytes == expected + assert xp_arr.nbytes == expected def test_reorganizing_operation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: Index: xarray-2024.05.0/xarray/tests/test_namedarray.py =================================================================== --- xarray-2024.05.0.orig/xarray/tests/test_namedarray.py +++ xarray-2024.05.0/xarray/tests/test_namedarray.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import warnings from abc import abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Generic, cast, overload @@ -79,6 +78,17 @@ class CustomArrayIndexable( return np +def check_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]: + # Mypy checks a is valid: + b: duckarray[Any, _DType] = a + + # Runtime check if valid: + if isinstance(b, _arrayfunction_or_api): + return b + else: + raise TypeError(f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi") + + class NamedArraySubclassobjects: @pytest.fixture def target(self, data: np.ndarray[Any, Any]) -> Any: @@ -328,48 +338,27 @@ class TestNamedArray(NamedArraySubclasso named_array.dims = new_dims assert named_array.dims == tuple(new_dims) - def test_duck_array_class( - self, - ) -> None: - def test_duck_array_typevar( - a: duckarray[Any, _DType], - ) -> duckarray[Any, _DType]: - # Mypy checks a is valid: - b: duckarray[Any, _DType] = a - - # Runtime check if valid: - if isinstance(b, _arrayfunction_or_api): - return b - else: - raise TypeError( - f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi" - ) - + def test_duck_array_class(self) -> None: numpy_a: NDArray[np.int64] numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64)) - test_duck_array_typevar(numpy_a) + check_duck_array_typevar(numpy_a) masked_a: np.ma.MaskedArray[Any, np.dtype[np.int64]] masked_a = np.ma.asarray([2.1, 4], dtype=np.dtype(np.int64)) # type: ignore[no-untyped-call] - test_duck_array_typevar(masked_a) + check_duck_array_typevar(masked_a) custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]] custom_a = CustomArrayIndexable(numpy_a) - test_duck_array_typevar(custom_a) + check_duck_array_typevar(custom_a) + def test_duck_array_class_array_api(self) -> None: # Test 
numpy's array api: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - r"The numpy.array_api submodule is still experimental", - category=UserWarning, - ) - import numpy.array_api as nxp + nxp = pytest.importorskip("array_api_strict", minversion="1.0") # TODO: nxp doesn't use dtype typevars, so can only use Any for the moment: arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]] - arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64)) - test_duck_array_typevar(arrayapi_a) + arrayapi_a = nxp.asarray([2.1, 4], dtype=nxp.int64) + check_duck_array_typevar(arrayapi_a) def test_new_namedarray(self) -> None: dtype_float = np.dtype(np.float32) Index: xarray-2024.05.0/xarray/tests/test_strategies.py =================================================================== --- xarray-2024.05.0.orig/xarray/tests/test_strategies.py +++ xarray-2024.05.0/xarray/tests/test_strategies.py @@ -1,6 +1,9 @@ +import warnings + import numpy as np import numpy.testing as npt import pytest +from packaging.version import Version pytest.importorskip("hypothesis") # isort: split @@ -19,7 +22,6 @@ from xarray.testing.strategies import ( unique_subset_of, variables, ) -from xarray.tests import requires_numpy_array_api ALLOWED_ATTRS_VALUES_TYPES = (int, bool, str, np.ndarray) @@ -199,7 +201,6 @@ class TestVariablesStrategy: ) ) - @requires_numpy_array_api @given(st.data()) def test_make_strategies_namespace(self, data): """ @@ -208,16 +209,24 @@ class TestVariablesStrategy: We still want to generate dtypes not in the array API by default, but this checks we don't accidentally override the user's choice of dtypes with non-API-compliant ones. """ - from numpy import ( - array_api as np_array_api, # requires numpy>=1.26.0, and we expect a UserWarning to be raised - ) + if Version(np.__version__) >= Version("2.0.0.dev0"): + nxp = np + else: + # requires numpy>=1.26.0, and we expect a UserWarning to be raised + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=UserWarning, message=".+See NEP 47." 
+ ) + from numpy import ( # type: ignore[no-redef,unused-ignore] + array_api as nxp, + ) - np_array_api_st = make_strategies_namespace(np_array_api) + nxp_st = make_strategies_namespace(nxp) data.draw( variables( - array_strategy_fn=np_array_api_st.arrays, - dtype=np_array_api_st.scalar_dtypes(), + array_strategy_fn=nxp_st.arrays, + dtype=nxp_st.scalar_dtypes(), ) ) Index: xarray-2024.05.0/xarray/core/duck_array_ops.py =================================================================== --- xarray-2024.05.0.orig/xarray/core/duck_array_ops.py +++ xarray-2024.05.0/xarray/core/duck_array_ops.py @@ -142,17 +142,25 @@ around.__doc__ = str.replace( def isnull(data): data = asarray(data) - scalar_type = data.dtype.type - if issubclass(scalar_type, (np.datetime64, np.timedelta64)): + + xp = get_array_namespace(data) + scalar_type = data.dtype + if dtypes.is_datetime_like(scalar_type): # datetime types use NaT for null # note: must check timedelta64 before integers, because currently # timedelta64 inherits from np.integer return isnat(data) - elif issubclass(scalar_type, np.inexact): + elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp): # float types use NaN for null xp = get_array_namespace(data) return xp.isnan(data) - elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)): + elif dtypes.isdtype(scalar_type, ("bool", "integral"), xp=xp) or ( + isinstance(scalar_type, np.dtype) + and ( + np.issubdtype(scalar_type, np.character) + or np.issubdtype(scalar_type, np.void) + ) + ): # these types cannot represent missing values return full_like(data, dtype=bool, fill_value=False) else: @@ -406,13 +414,22 @@ def _create_nan_agg_method(name, coerce_ if invariant_0d and axis == (): return values - values = asarray(values) + xp = get_array_namespace(values) + values = asarray(values, xp=xp) - if coerce_strings and values.dtype.kind in "SU": + if coerce_strings and dtypes.is_string(values.dtype): values = astype(values, object) func = None - if skipna or (skipna is None and values.dtype.kind in "cfO"): + if skipna or ( + skipna is None + and ( + dtypes.isdtype( + values.dtype, ("complex floating", "real floating"), xp=xp + ) + or dtypes.is_object(values.dtype) + ) + ): nanname = "nan" + name func = getattr(nanops, nanname) else: @@ -477,8 +494,8 @@ def _datetime_nanmin(array): - numpy nanmin() don't work on datetime64 (all versions at the moment of writing) - dask min() does not work on datetime64 (all versions at the moment of writing) """ - assert array.dtype.kind in "mM" dtype = array.dtype + assert dtypes.is_datetime_like(dtype) # (NaT).astype(float) does not produce NaN... array = where(pandas_isnull(array), np.nan, array.astype(float)) array = min(array, skipna=True) @@ -515,7 +532,7 @@ def datetime_to_numeric(array, offset=No """ # Set offset to minimum if not given if offset is None: - if array.dtype.kind in "Mm": + if dtypes.is_datetime_like(array.dtype): offset = _datetime_nanmin(array) else: offset = min(array) @@ -527,7 +544,7 @@ def datetime_to_numeric(array, offset=No # This map_blocks call is for backwards compatibility. 
# dask == 2021.04.1 does not support subtracting object arrays # which is required for cftime - if is_duck_dask_array(array) and np.issubdtype(array.dtype, object): + if is_duck_dask_array(array) and dtypes.is_object(array.dtype): array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta) else: array = array - offset @@ -537,11 +554,11 @@ def datetime_to_numeric(array, offset=No array = np.array(array) # Convert timedelta objects to float by first converting to microseconds. - if array.dtype.kind in "O": + if dtypes.is_object(array.dtype): return py_timedelta_to_float(array, datetime_unit or "ns").astype(dtype) # Convert np.NaT to np.nan - elif array.dtype.kind in "mM": + elif dtypes.is_datetime_like(array.dtype): # Convert to specified timedelta units. if datetime_unit: array = array / np.timedelta64(1, datetime_unit) @@ -641,7 +658,7 @@ def mean(array, axis=None, skipna=None, from xarray.core.common import _contains_cftime_datetimes array = asarray(array) - if array.dtype.kind in "Mm": + if dtypes.is_datetime_like(array.dtype): offset = _datetime_nanmin(array) # xarray always uses np.datetime64[ns] for np.datetime64 data @@ -689,7 +706,9 @@ def cumsum(array, axis=None, **kwargs): def first(values, axis, skipna=None): """Return the first non-NA elements in this array along the given axis""" - if (skipna or skipna is None) and values.dtype.kind not in "iSU": + if (skipna or skipna is None) and not ( + dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype) + ): # only bother for dtypes that can hold NaN if is_chunked_array(values): return chunked_nanfirst(values, axis) @@ -700,7 +719,9 @@ def first(values, axis, skipna=None): def last(values, axis, skipna=None): """Return the last non-NA elements in this array along the given axis""" - if (skipna or skipna is None) and values.dtype.kind not in "iSU": + if (skipna or skipna is None) and not ( + dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype) + ): # only bother for dtypes that can hold NaN if is_chunked_array(values): return chunked_nanlast(values, axis) Index: xarray-2024.05.0/xarray/core/dtypes.py =================================================================== --- xarray-2024.05.0.orig/xarray/core/dtypes.py +++ xarray-2024.05.0/xarray/core/dtypes.py @@ -4,8 +4,9 @@ import functools from typing import Any import numpy as np +from pandas.api.types import is_extension_array_dtype -from xarray.core import utils +from xarray.core import npcompat, utils # Use as a sentinel value to indicate a dtype appropriate NA value. NA = utils.ReprObject("") @@ -60,22 +61,22 @@ def maybe_promote(dtype: np.dtype) -> tu # N.B. 
these casting rules should match pandas dtype_: np.typing.DTypeLike fill_value: Any - if np.issubdtype(dtype, np.floating): + if isdtype(dtype, "real floating"): dtype_ = dtype fill_value = np.nan - elif np.issubdtype(dtype, np.timedelta64): + elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.timedelta64): # See https://github.com/numpy/numpy/issues/10685 # np.timedelta64 is a subclass of np.integer # Check np.timedelta64 before np.integer fill_value = np.timedelta64("NaT") dtype_ = dtype - elif np.issubdtype(dtype, np.integer): + elif isdtype(dtype, "integral"): dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 fill_value = np.nan - elif np.issubdtype(dtype, np.complexfloating): + elif isdtype(dtype, "complex floating"): dtype_ = dtype fill_value = np.nan + np.nan * 1j - elif np.issubdtype(dtype, np.datetime64): + elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64): dtype_ = dtype fill_value = np.datetime64("NaT") else: @@ -118,16 +119,16 @@ def get_pos_infinity(dtype, max_for_int= ------- fill_value : positive infinity value corresponding to this dtype. """ - if issubclass(dtype.type, np.floating): + if isdtype(dtype, "real floating"): return np.inf - if issubclass(dtype.type, np.integer): + if isdtype(dtype, "integral"): if max_for_int: return np.iinfo(dtype).max else: return np.inf - if issubclass(dtype.type, np.complexfloating): + if isdtype(dtype, "complex floating"): return np.inf + 1j * np.inf return INF @@ -146,24 +147,66 @@ def get_neg_infinity(dtype, min_for_int= ------- fill_value : positive infinity value corresponding to this dtype. """ - if issubclass(dtype.type, np.floating): + if isdtype(dtype, "real floating"): return -np.inf - if issubclass(dtype.type, np.integer): + if isdtype(dtype, "integral"): if min_for_int: return np.iinfo(dtype).min else: return -np.inf - if issubclass(dtype.type, np.complexfloating): + if isdtype(dtype, "complex floating"): return -np.inf - 1j * np.inf return NINF -def is_datetime_like(dtype): +def is_datetime_like(dtype) -> bool: """Check if a dtype is a subclass of the numpy datetime types""" - return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + return _is_numpy_subdtype(dtype, (np.datetime64, np.timedelta64)) + + +def is_object(dtype) -> bool: + """Check if a dtype is object""" + return _is_numpy_subdtype(dtype, object) + + +def is_string(dtype) -> bool: + """Check if a dtype is a string dtype""" + return _is_numpy_subdtype(dtype, (np.str_, np.character)) + + +def _is_numpy_subdtype(dtype, kind) -> bool: + if not isinstance(dtype, np.dtype): + return False + + kinds = kind if isinstance(kind, tuple) else (kind,) + return any(np.issubdtype(dtype, kind) for kind in kinds) + + +def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool: + """Compatibility wrapper for isdtype() from the array API standard. + + Unlike xp.isdtype(), kind must be a string. 
+ """ + # TODO(shoyer): remove this wrapper when Xarray requires + # numpy>=2 and pandas extensions arrays are implemented in + # Xarray via the array API + if not isinstance(kind, str) and not ( + isinstance(kind, tuple) and all(isinstance(k, str) for k in kind) + ): + raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}") + + if isinstance(dtype, np.dtype): + return npcompat.isdtype(dtype, kind) + elif is_extension_array_dtype(dtype): + # we never want to match pandas extension array dtypes + return False + else: + if xp is None: + xp = np + return xp.isdtype(dtype, kind) def result_type( @@ -184,12 +227,26 @@ def result_type( ------- numpy.dtype for the result. """ - types = {np.result_type(t).type for t in arrays_and_dtypes} + from xarray.core.duck_array_ops import get_array_namespace + + # TODO(shoyer): consider moving this logic into get_array_namespace() + # or another helper function. + namespaces = {get_array_namespace(t) for t in arrays_and_dtypes} + non_numpy = namespaces - {np} + if non_numpy: + [xp] = non_numpy + else: + xp = np + + types = {xp.result_type(t) for t in arrays_and_dtypes} - for left, right in PROMOTE_TO_OBJECT: - if any(issubclass(t, left) for t in types) and any( - issubclass(t, right) for t in types - ): - return np.dtype(object) + if any(isinstance(t, np.dtype) for t in types): + # only check if there's numpy dtypes – the array API does not + # define the types we're checking for + for left, right in PROMOTE_TO_OBJECT: + if any(np.issubdtype(t, left) for t in types) and any( + np.issubdtype(t, right) for t in types + ): + return xp.dtype(object) - return np.result_type(*arrays_and_dtypes) + return xp.result_type(*arrays_and_dtypes) Index: xarray-2024.05.0/xarray/namedarray/core.py =================================================================== --- xarray-2024.05.0.orig/xarray/namedarray/core.py +++ xarray-2024.05.0/xarray/namedarray/core.py @@ -470,10 +470,28 @@ class NamedArray(NamedArrayAggregations, If the underlying data array does not include ``nbytes``, estimates the bytes consumed based on the ``size`` and ``dtype``. 
""" + from xarray.namedarray._array_api import _get_data_namespace + if hasattr(self._data, "nbytes"): return self._data.nbytes # type: ignore[no-any-return] + + if hasattr(self.dtype, "itemsize"): + itemsize = self.dtype.itemsize + elif isinstance(self._data, _arrayapi): + xp = _get_data_namespace(self) + + if xp.isdtype(self.dtype, "bool"): + itemsize = 1 + elif xp.isdtype(self.dtype, "integral"): + itemsize = xp.iinfo(self.dtype).bits // 8 + else: + itemsize = xp.finfo(self.dtype).bits // 8 else: - return self.size * self.dtype.itemsize + raise TypeError( + "cannot compute the number of bytes (no array API nor nbytes / itemsize)" + ) + + return self.size * itemsize @property def dims(self) -> _Dims: Index: xarray-2024.05.0/xarray/tests/test_dtypes.py =================================================================== --- xarray-2024.05.0.orig/xarray/tests/test_dtypes.py +++ xarray-2024.05.0/xarray/tests/test_dtypes.py @@ -4,6 +4,18 @@ import numpy as np import pytest from xarray.core import dtypes +from xarray.tests import requires_array_api_strict + +try: + import array_api_strict +except ImportError: + + class DummyArrayAPINamespace: + bool = None + int32 = None + float64 = None + + array_api_strict = DummyArrayAPINamespace @pytest.mark.parametrize( @@ -58,7 +70,6 @@ def test_inf(obj) -> None: @pytest.mark.parametrize( "kind, expected", [ - ("a", (np.dtype("O"), "nan")), # dtype('S') ("b", (np.float32, "nan")), # dtype('int8') ("B", (np.float32, "nan")), # dtype('uint8') ("c", (np.dtype("O"), "nan")), # dtype('S1') @@ -98,3 +109,54 @@ def test_nat_types_membership() -> None: assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES assert np.float64 not in dtypes.NAT_TYPES + + +@pytest.mark.parametrize( + ["dtype", "kinds", "xp", "expected"], + ( + (np.dtype("int32"), "integral", np, True), + (np.dtype("float16"), "real floating", np, True), + (np.dtype("complex128"), "complex floating", np, True), + (np.dtype("U"), "numeric", np, False), + pytest.param( + array_api_strict.int32, + "integral", + array_api_strict, + True, + marks=requires_array_api_strict, + id="array_api-int", + ), + pytest.param( + array_api_strict.float64, + "real floating", + array_api_strict, + True, + marks=requires_array_api_strict, + id="array_api-float", + ), + pytest.param( + array_api_strict.bool, + "numeric", + array_api_strict, + False, + marks=requires_array_api_strict, + id="array_api-bool", + ), + ), +) +def test_isdtype(dtype, kinds, xp, expected) -> None: + actual = dtypes.isdtype(dtype, kinds, xp=xp) + assert actual == expected + + +@pytest.mark.parametrize( + ["dtype", "kinds", "xp", "error", "pattern"], + ( + (np.dtype("int32"), "foo", np, (TypeError, ValueError), "kind"), + (np.dtype("int32"), np.signedinteger, np, TypeError, "kind"), + (np.dtype("float16"), 1, np, TypeError, "kind"), + ), +) +def test_isdtype_error(dtype, kinds, xp, error, pattern): + with pytest.raises(error, match=pattern): + dtypes.isdtype(dtype, kinds, xp=xp) Index: xarray-2024.05.0/xarray/core/npcompat.py =================================================================== --- xarray-2024.05.0.orig/xarray/core/npcompat.py +++ xarray-2024.05.0/xarray/core/npcompat.py @@ -28,3 +28,33 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +try: + # requires numpy>=2.0 + from numpy import isdtype # type: ignore[attr-defined,unused-ignore] +except ImportError: + import numpy as np + + dtype_kinds = { + "bool": np.bool_, + "signed integer": np.signedinteger, + "unsigned integer": np.unsignedinteger, + "integral": np.integer, + "real floating": np.floating, + "complex floating": np.complexfloating, + "numeric": np.number, + } + + def isdtype(dtype, kind): + kinds = kind if isinstance(kind, tuple) else (kind,) + + unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds] + if unknown_dtypes: + raise ValueError(f"unknown dtype kinds: {unknown_dtypes}") + + # verified the dtypes already, no need to check again + translated_kinds = [dtype_kinds[kind] for kind in kinds] + if isinstance(dtype, np.generic): + return any(isinstance(dtype, kind) for kind in translated_kinds) + else: + return any(np.issubdtype(dtype, kind) for kind in translated_kinds)
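
For reference, the behaviour of the new `dtypes.isdtype()` compatibility wrapper can be exercised directly; the following minimal sketch mirrors the parametrized cases added in `test_dtypes.py` above, and assumes `array_api_strict` is installed and that xarray includes the changes from this patch:

    import numpy as np
    import array_api_strict as xp

    from xarray.core import dtypes

    # numpy dtypes are handled via xarray.core.npcompat.isdtype()
    assert dtypes.isdtype(np.dtype("int32"), "integral", xp=np)
    assert dtypes.isdtype(np.dtype("float16"), "real floating", xp=np)
    assert not dtypes.isdtype(np.dtype("U"), "numeric", xp=np)

    # non-numpy dtypes are dispatched to the given namespace's own isdtype()
    assert dtypes.isdtype(xp.int32, "integral", xp=xp)
    assert not dtypes.isdtype(xp.bool, "numeric", xp=xp)

    # kind must be a string or a tuple of strings; dtype classes are rejected
    try:
        dtypes.isdtype(np.dtype("int32"), np.signedinteger, xp=np)
    except TypeError as err:
        print(err)  # "kind must be a string or a tuple of strings: ..."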