1
0
forked from pool/python-xarray
python-xarray/xarray-pr8854-np2.patch
Sebastian Wagner d3af5d70ef - Update to 2024.7.0
* Add test for rechunking to a size string by @dcherian in #9117
  * Update docstring in api.py for open_mfdataset(), clarifying
    "chunks" argument by @arthur-e in #9121
  * Grouper refactor by @dcherian in #9122
  * adjust repr tests to account for different platforms (#9127) by
    @mgorny in #9128
  * Support duplicate dimensions in .chunk by @mraspaud in #9099
  * Update Zenodo badge link by @max-sixty in #9133
  * Split out distributed writes in zarr docs by @max-sixty in
    #9132
  * Improve to_zarr docs by @max-sixty in #9139
  * groupby: remove some internal use of IndexVariable by @dcherian
    in #9123
  * Improve zarr chunks docs by @max-sixty in #9140
  * Include numbagg in type checks by @max-sixty in #9159
  * Remove mypy exclusions for a couple more libraries by
    @max-sixty in #9160
  * Add test for #9155 by @max-sixty in #9161
  * switch to datetime unit "D" by @keewis in #9170
  * Slightly improve DataTree repr by @shoyer in #9064
  * Fix example code formatting for CachingFileManager by @djhoese
    in #9178
  * Change np.core.defchararray to np.char (#9165) by @pont-us in
    #9166
  * temporarily remove pydap from CI by @keewis in #9183
  * also pin numpy in the all-but-dask CI by @keewis in #9184
  * promote floating-point numeric datetimes to 64-bit before
    decoding by @keewis in #9182
  * "source" encoding for datasets opened from fsspec objects by

OBS-URL: https://build.opensuse.org/package/show/devel:languages:python:numeric/python-xarray?expand=0&rev=99
2024-09-04 13:02:54 +00:00

773 lines
28 KiB
Diff
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

From e066a6c559e9d7f31c359ea95da42d0e45c585ce Mon Sep 17 00:00:00 2001
From: Justus Magin <keewis@posteo.de>
Date: Tue, 19 Mar 2024 11:32:32 +0100
Subject: [PATCH 01/65] replace the use of `numpy.array_api` with
`array_api_strict`
This would make it a dependency of `namedarray`, and not allow
behavior that is allowed but not required by the array API standard. Otherwise we can:
- use the main `numpy` namespace
- use `array_api_compat` (would also be a new dependency) to allow
optional behavior
---
xarray/namedarray/_array_api.py | 9 ---------
1 file changed, 9 deletions(-)
Index: xarray-2024.05.0/xarray/namedarray/_array_api.py
===================================================================
--- xarray-2024.05.0.orig/xarray/namedarray/_array_api.py
+++ xarray-2024.05.0/xarray/namedarray/_array_api.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import warnings
from types import ModuleType
from typing import Any
@@ -21,14 +20,6 @@ from xarray.namedarray._typing import (
)
from xarray.namedarray.core import NamedArray
-with warnings.catch_warnings():
- warnings.filterwarnings(
- "ignore",
- r"The numpy.array_api submodule is still experimental",
- category=UserWarning,
- )
- import numpy.array_api as nxp # noqa: F401
-
def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType:
if isinstance(x._data, _arrayapi):
@@ -68,13 +59,13 @@ def astype(
Examples
--------
- >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5]))
+ >>> narr = NamedArray(("x",), np.asarray([1.5, 2.5]))
>>> narr
<xarray.NamedArray (x: 2)> Size: 16B
- Array([1.5, 2.5], dtype=float64)
+ array([1.5, 2.5])
>>> astype(narr, np.dtype(np.int32))
<xarray.NamedArray (x: 2)> Size: 8B
- Array([1, 2], dtype=int32)
+ array([1, 2], dtype=int32)
"""
if isinstance(x._data, _arrayapi):
xp = x._data.__array_namespace__()
@@ -109,7 +100,7 @@ def imag(
Examples
--------
- >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp
+ >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j]))
>>> imag(narr)
<xarray.NamedArray (x: 2)> Size: 16B
array([2., 4.])
@@ -141,7 +132,7 @@ def real(
Examples
--------
- >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp
+ >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j]))
>>> real(narr)
<xarray.NamedArray (x: 2)> Size: 16B
array([1., 2.])
@@ -179,15 +170,15 @@ def expand_dims(
Examples
--------
- >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]]))
+ >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]]))
>>> expand_dims(x)
<xarray.NamedArray (dim_2: 1, x: 2, y: 2)> Size: 32B
- Array([[[1., 2.],
- [3., 4.]]], dtype=float64)
+ array([[[1., 2.],
+ [3., 4.]]])
>>> expand_dims(x, dim="z")
<xarray.NamedArray (z: 1, x: 2, y: 2)> Size: 32B
- Array([[[1., 2.],
- [3., 4.]]], dtype=float64)
+ array([[[1., 2.],
+ [3., 4.]]])
"""
xp = _get_data_namespace(x)
dims = x.dims
Index: xarray-2024.05.0/xarray/tests/__init__.py
===================================================================
--- xarray-2024.05.0.orig/xarray/tests/__init__.py
+++ xarray-2024.05.0/xarray/tests/__init__.py
@@ -147,9 +147,10 @@ has_numbagg_or_bottleneck = has_numbagg
requires_numbagg_or_bottleneck = pytest.mark.skipif(
not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
)
-has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0")
has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0")
+has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict")
+
def _importorskip_h5netcdf_ros3():
try:
Index: xarray-2024.05.0/xarray/tests/test_array_api.py
===================================================================
--- xarray-2024.05.0.orig/xarray/tests/test_array_api.py
+++ xarray-2024.05.0/xarray/tests/test_array_api.py
@@ -6,20 +6,9 @@ import xarray as xr
from xarray.testing import assert_equal
np = pytest.importorskip("numpy", minversion="1.22")
+xp = pytest.importorskip("array_api_strict")
-try:
- import warnings
-
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
-
- import numpy.array_api as xp
- from numpy.array_api._array_object import Array
-except ImportError:
- # for `numpy>=2.0`
- xp = pytest.importorskip("array_api_strict")
-
- from array_api_strict._array_object import Array # type: ignore[no-redef]
+from array_api_strict._array_object import Array # isort:skip # type: ignore[no-redef]
@pytest.fixture
@@ -65,8 +54,8 @@ def test_aggregation_skipna(arrays) -> N
def test_astype(arrays) -> None:
np_arr, xp_arr = arrays
expected = np_arr.astype(np.int64)
- actual = xp_arr.astype(np.int64)
- assert actual.dtype == np.int64
+ actual = xp_arr.astype(xp.int64)
+ assert actual.dtype == xp.int64
assert isinstance(actual.data, Array)
assert_equal(actual, expected)
@@ -118,8 +107,10 @@ def test_indexing(arrays: tuple[xr.DataA
def test_properties(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
np_arr, xp_arr = arrays
- assert np_arr.nbytes == np_arr.data.nbytes
- assert xp_arr.nbytes == np_arr.data.nbytes
+
+ expected = np_arr.data.nbytes
+ assert np_arr.nbytes == expected
+ assert xp_arr.nbytes == expected
def test_reorganizing_operation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
Index: xarray-2024.05.0/xarray/tests/test_namedarray.py
===================================================================
--- xarray-2024.05.0.orig/xarray/tests/test_namedarray.py
+++ xarray-2024.05.0/xarray/tests/test_namedarray.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import copy
-import warnings
from abc import abstractmethod
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Generic, cast, overload
@@ -79,6 +78,17 @@ class CustomArrayIndexable(
return np
+def check_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]:
+ # Mypy checks a is valid:
+ b: duckarray[Any, _DType] = a
+
+ # Runtime check if valid:
+ if isinstance(b, _arrayfunction_or_api):
+ return b
+ else:
+ raise TypeError(f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi")
+
+
class NamedArraySubclassobjects:
@pytest.fixture
def target(self, data: np.ndarray[Any, Any]) -> Any:
@@ -328,48 +338,27 @@ class TestNamedArray(NamedArraySubclasso
named_array.dims = new_dims
assert named_array.dims == tuple(new_dims)
- def test_duck_array_class(
- self,
- ) -> None:
- def test_duck_array_typevar(
- a: duckarray[Any, _DType],
- ) -> duckarray[Any, _DType]:
- # Mypy checks a is valid:
- b: duckarray[Any, _DType] = a
-
- # Runtime check if valid:
- if isinstance(b, _arrayfunction_or_api):
- return b
- else:
- raise TypeError(
- f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi"
- )
-
+ def test_duck_array_class(self) -> None:
numpy_a: NDArray[np.int64]
numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64))
- test_duck_array_typevar(numpy_a)
+ check_duck_array_typevar(numpy_a)
masked_a: np.ma.MaskedArray[Any, np.dtype[np.int64]]
masked_a = np.ma.asarray([2.1, 4], dtype=np.dtype(np.int64)) # type: ignore[no-untyped-call]
- test_duck_array_typevar(masked_a)
+ check_duck_array_typevar(masked_a)
custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]]
custom_a = CustomArrayIndexable(numpy_a)
- test_duck_array_typevar(custom_a)
+ check_duck_array_typevar(custom_a)
+ def test_duck_array_class_array_api(self) -> None:
# Test numpy's array api:
- with warnings.catch_warnings():
- warnings.filterwarnings(
- "ignore",
- r"The numpy.array_api submodule is still experimental",
- category=UserWarning,
- )
- import numpy.array_api as nxp
+ nxp = pytest.importorskip("array_api_strict", minversion="1.0")
# TODO: nxp doesn't use dtype typevars, so can only use Any for the moment:
arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]]
- arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64))
- test_duck_array_typevar(arrayapi_a)
+ arrayapi_a = nxp.asarray([2.1, 4], dtype=nxp.int64)
+ check_duck_array_typevar(arrayapi_a)
def test_new_namedarray(self) -> None:
dtype_float = np.dtype(np.float32)
Index: xarray-2024.05.0/xarray/tests/test_strategies.py
===================================================================
--- xarray-2024.05.0.orig/xarray/tests/test_strategies.py
+++ xarray-2024.05.0/xarray/tests/test_strategies.py
@@ -1,6 +1,9 @@
+import warnings
+
import numpy as np
import numpy.testing as npt
import pytest
+from packaging.version import Version
pytest.importorskip("hypothesis")
# isort: split
@@ -19,7 +22,6 @@ from xarray.testing.strategies import (
unique_subset_of,
variables,
)
-from xarray.tests import requires_numpy_array_api
ALLOWED_ATTRS_VALUES_TYPES = (int, bool, str, np.ndarray)
@@ -199,7 +201,6 @@ class TestVariablesStrategy:
)
)
- @requires_numpy_array_api
@given(st.data())
def test_make_strategies_namespace(self, data):
"""
@@ -208,16 +209,24 @@ class TestVariablesStrategy:
We still want to generate dtypes not in the array API by default, but this checks we don't accidentally override
the user's choice of dtypes with non-API-compliant ones.
"""
- from numpy import (
- array_api as np_array_api, # requires numpy>=1.26.0, and we expect a UserWarning to be raised
- )
+ if Version(np.__version__) >= Version("2.0.0.dev0"):
+ nxp = np
+ else:
+ # requires numpy>=1.26.0, and we expect a UserWarning to be raised
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore", category=UserWarning, message=".+See NEP 47."
+ )
+ from numpy import ( # type: ignore[no-redef,unused-ignore]
+ array_api as nxp,
+ )
- np_array_api_st = make_strategies_namespace(np_array_api)
+ nxp_st = make_strategies_namespace(nxp)
data.draw(
variables(
- array_strategy_fn=np_array_api_st.arrays,
- dtype=np_array_api_st.scalar_dtypes(),
+ array_strategy_fn=nxp_st.arrays,
+ dtype=nxp_st.scalar_dtypes(),
)
)
Index: xarray-2024.05.0/xarray/core/duck_array_ops.py
===================================================================
--- xarray-2024.05.0.orig/xarray/core/duck_array_ops.py
+++ xarray-2024.05.0/xarray/core/duck_array_ops.py
@@ -142,17 +142,25 @@ around.__doc__ = str.replace(
def isnull(data):
data = asarray(data)
- scalar_type = data.dtype.type
- if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
+
+ xp = get_array_namespace(data)
+ scalar_type = data.dtype
+ if dtypes.is_datetime_like(scalar_type):
# datetime types use NaT for null
# note: must check timedelta64 before integers, because currently
# timedelta64 inherits from np.integer
return isnat(data)
- elif issubclass(scalar_type, np.inexact):
+ elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
# float types use NaN for null
xp = get_array_namespace(data)
return xp.isnan(data)
- elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
+ elif dtypes.isdtype(scalar_type, ("bool", "integral"), xp=xp) or (
+ isinstance(scalar_type, np.dtype)
+ and (
+ np.issubdtype(scalar_type, np.character)
+ or np.issubdtype(scalar_type, np.void)
+ )
+ ):
# these types cannot represent missing values
return full_like(data, dtype=bool, fill_value=False)
else:
@@ -406,13 +414,22 @@ def _create_nan_agg_method(name, coerce_
if invariant_0d and axis == ():
return values
- values = asarray(values)
+ xp = get_array_namespace(values)
+ values = asarray(values, xp=xp)
- if coerce_strings and values.dtype.kind in "SU":
+ if coerce_strings and dtypes.is_string(values.dtype):
values = astype(values, object)
func = None
- if skipna or (skipna is None and values.dtype.kind in "cfO"):
+ if skipna or (
+ skipna is None
+ and (
+ dtypes.isdtype(
+ values.dtype, ("complex floating", "real floating"), xp=xp
+ )
+ or dtypes.is_object(values.dtype)
+ )
+ ):
nanname = "nan" + name
func = getattr(nanops, nanname)
else:
@@ -477,8 +494,8 @@ def _datetime_nanmin(array):
- numpy nanmin() don't work on datetime64 (all versions at the moment of writing)
- dask min() does not work on datetime64 (all versions at the moment of writing)
"""
- assert array.dtype.kind in "mM"
dtype = array.dtype
+ assert dtypes.is_datetime_like(dtype)
# (NaT).astype(float) does not produce NaN...
array = where(pandas_isnull(array), np.nan, array.astype(float))
array = min(array, skipna=True)
@@ -515,7 +532,7 @@ def datetime_to_numeric(array, offset=No
"""
# Set offset to minimum if not given
if offset is None:
- if array.dtype.kind in "Mm":
+ if dtypes.is_datetime_like(array.dtype):
offset = _datetime_nanmin(array)
else:
offset = min(array)
@@ -527,7 +544,7 @@ def datetime_to_numeric(array, offset=No
# This map_blocks call is for backwards compatibility.
# dask == 2021.04.1 does not support subtracting object arrays
# which is required for cftime
- if is_duck_dask_array(array) and np.issubdtype(array.dtype, object):
+ if is_duck_dask_array(array) and dtypes.is_object(array.dtype):
array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta)
else:
array = array - offset
@@ -537,11 +554,11 @@ def datetime_to_numeric(array, offset=No
array = np.array(array)
# Convert timedelta objects to float by first converting to microseconds.
- if array.dtype.kind in "O":
+ if dtypes.is_object(array.dtype):
return py_timedelta_to_float(array, datetime_unit or "ns").astype(dtype)
# Convert np.NaT to np.nan
- elif array.dtype.kind in "mM":
+ elif dtypes.is_datetime_like(array.dtype):
# Convert to specified timedelta units.
if datetime_unit:
array = array / np.timedelta64(1, datetime_unit)
@@ -641,7 +658,7 @@ def mean(array, axis=None, skipna=None,
from xarray.core.common import _contains_cftime_datetimes
array = asarray(array)
- if array.dtype.kind in "Mm":
+ if dtypes.is_datetime_like(array.dtype):
offset = _datetime_nanmin(array)
# xarray always uses np.datetime64[ns] for np.datetime64 data
@@ -689,7 +706,9 @@ def cumsum(array, axis=None, **kwargs):
def first(values, axis, skipna=None):
"""Return the first non-NA elements in this array along the given axis"""
- if (skipna or skipna is None) and values.dtype.kind not in "iSU":
+ if (skipna or skipna is None) and not (
+ dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
+ ):
# only bother for dtypes that can hold NaN
if is_chunked_array(values):
return chunked_nanfirst(values, axis)
@@ -700,7 +719,9 @@ def first(values, axis, skipna=None):
def last(values, axis, skipna=None):
"""Return the last non-NA elements in this array along the given axis"""
- if (skipna or skipna is None) and values.dtype.kind not in "iSU":
+ if (skipna or skipna is None) and not (
+ dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
+ ):
# only bother for dtypes that can hold NaN
if is_chunked_array(values):
return chunked_nanlast(values, axis)
Index: xarray-2024.05.0/xarray/core/dtypes.py
===================================================================
--- xarray-2024.05.0.orig/xarray/core/dtypes.py
+++ xarray-2024.05.0/xarray/core/dtypes.py
@@ -4,8 +4,9 @@ import functools
from typing import Any
import numpy as np
+from pandas.api.types import is_extension_array_dtype
-from xarray.core import utils
+from xarray.core import npcompat, utils
# Use as a sentinel value to indicate a dtype appropriate NA value.
NA = utils.ReprObject("<NA>")
@@ -60,22 +61,22 @@ def maybe_promote(dtype: np.dtype) -> tu
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
- if np.issubdtype(dtype, np.floating):
+ if isdtype(dtype, "real floating"):
dtype_ = dtype
fill_value = np.nan
- elif np.issubdtype(dtype, np.timedelta64):
+ elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.timedelta64):
# See https://github.com/numpy/numpy/issues/10685
# np.timedelta64 is a subclass of np.integer
# Check np.timedelta64 before np.integer
fill_value = np.timedelta64("NaT")
dtype_ = dtype
- elif np.issubdtype(dtype, np.integer):
+ elif isdtype(dtype, "integral"):
dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
fill_value = np.nan
- elif np.issubdtype(dtype, np.complexfloating):
+ elif isdtype(dtype, "complex floating"):
dtype_ = dtype
fill_value = np.nan + np.nan * 1j
- elif np.issubdtype(dtype, np.datetime64):
+ elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64):
dtype_ = dtype
fill_value = np.datetime64("NaT")
else:
@@ -118,16 +119,16 @@ def get_pos_infinity(dtype, max_for_int=
-------
fill_value : positive infinity value corresponding to this dtype.
"""
- if issubclass(dtype.type, np.floating):
+ if isdtype(dtype, "real floating"):
return np.inf
- if issubclass(dtype.type, np.integer):
+ if isdtype(dtype, "integral"):
if max_for_int:
return np.iinfo(dtype).max
else:
return np.inf
- if issubclass(dtype.type, np.complexfloating):
+ if isdtype(dtype, "complex floating"):
return np.inf + 1j * np.inf
return INF
@@ -146,24 +147,66 @@ def get_neg_infinity(dtype, min_for_int=
-------
fill_value : positive infinity value corresponding to this dtype.
"""
- if issubclass(dtype.type, np.floating):
+ if isdtype(dtype, "real floating"):
return -np.inf
- if issubclass(dtype.type, np.integer):
+ if isdtype(dtype, "integral"):
if min_for_int:
return np.iinfo(dtype).min
else:
return -np.inf
- if issubclass(dtype.type, np.complexfloating):
+ if isdtype(dtype, "complex floating"):
return -np.inf - 1j * np.inf
return NINF
-def is_datetime_like(dtype):
+def is_datetime_like(dtype) -> bool:
"""Check if a dtype is a subclass of the numpy datetime types"""
- return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
+ return _is_numpy_subdtype(dtype, (np.datetime64, np.timedelta64))
+
+
+def is_object(dtype) -> bool:
+ """Check if a dtype is object"""
+ return _is_numpy_subdtype(dtype, object)
+
+
+def is_string(dtype) -> bool:
+ """Check if a dtype is a string dtype"""
+ return _is_numpy_subdtype(dtype, (np.str_, np.character))
+
+
+def _is_numpy_subdtype(dtype, kind) -> bool:
+ if not isinstance(dtype, np.dtype):
+ return False
+
+ kinds = kind if isinstance(kind, tuple) else (kind,)
+ return any(np.issubdtype(dtype, kind) for kind in kinds)
+
+
+def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
+ """Compatibility wrapper for isdtype() from the array API standard.
+
+ Unlike xp.isdtype(), kind must be a string or a tuple of strings.
+ """
+ # TODO(shoyer): remove this wrapper when Xarray requires
+ # numpy>=2 and pandas extensions arrays are implemented in
+ # Xarray via the array API
+ if not isinstance(kind, str) and not (
+ isinstance(kind, tuple) and all(isinstance(k, str) for k in kind)
+ ):
+ raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")
+
+ if isinstance(dtype, np.dtype):
+ return npcompat.isdtype(dtype, kind)
+ elif is_extension_array_dtype(dtype):
+ # we never want to match pandas extension array dtypes
+ return False
+ else:
+ if xp is None:
+ xp = np
+ return xp.isdtype(dtype, kind)
def result_type(
@@ -184,12 +227,26 @@ def result_type(
-------
numpy.dtype for the result.
"""
- types = {np.result_type(t).type for t in arrays_and_dtypes}
+ from xarray.core.duck_array_ops import get_array_namespace
+
+ # TODO(shoyer): consider moving this logic into get_array_namespace()
+ # or another helper function.
+ namespaces = {get_array_namespace(t) for t in arrays_and_dtypes}
+ non_numpy = namespaces - {np}
+ if non_numpy:
+ [xp] = non_numpy
+ else:
+ xp = np
+
+ types = {xp.result_type(t) for t in arrays_and_dtypes}
- for left, right in PROMOTE_TO_OBJECT:
- if any(issubclass(t, left) for t in types) and any(
- issubclass(t, right) for t in types
- ):
- return np.dtype(object)
+ if any(isinstance(t, np.dtype) for t in types):
+ # only check when numpy dtypes are present: the array API does not
+ # define the types we're checking for
+ for left, right in PROMOTE_TO_OBJECT:
+ if any(np.issubdtype(t, left) for t in types) and any(
+ np.issubdtype(t, right) for t in types
+ ):
+ return xp.dtype(object)
- return np.result_type(*arrays_and_dtypes)
+ return xp.result_type(*arrays_and_dtypes)
Index: xarray-2024.05.0/xarray/namedarray/core.py
===================================================================
--- xarray-2024.05.0.orig/xarray/namedarray/core.py
+++ xarray-2024.05.0/xarray/namedarray/core.py
@@ -470,10 +470,28 @@ class NamedArray(NamedArrayAggregations,
If the underlying data array does not include ``nbytes``, estimates
the bytes consumed based on the ``size`` and ``dtype``.
"""
+ from xarray.namedarray._array_api import _get_data_namespace
+
if hasattr(self._data, "nbytes"):
return self._data.nbytes # type: ignore[no-any-return]
+
+ if hasattr(self.dtype, "itemsize"):
+ itemsize = self.dtype.itemsize
+ elif isinstance(self._data, _arrayapi):
+ xp = _get_data_namespace(self)
+
+ if xp.isdtype(self.dtype, "bool"):
+ itemsize = 1
+ elif xp.isdtype(self.dtype, "integral"):
+ itemsize = xp.iinfo(self.dtype).bits // 8
+ else:
+ itemsize = xp.finfo(self.dtype).bits // 8
else:
- return self.size * self.dtype.itemsize
+ raise TypeError(
+ "cannot compute the number of bytes (no array API nor nbytes / itemsize)"
+ )
+
+ return self.size * itemsize
@property
def dims(self) -> _Dims:
Index: xarray-2024.05.0/xarray/tests/test_dtypes.py
===================================================================
--- xarray-2024.05.0.orig/xarray/tests/test_dtypes.py
+++ xarray-2024.05.0/xarray/tests/test_dtypes.py
@@ -4,6 +4,18 @@ import numpy as np
import pytest
from xarray.core import dtypes
+from xarray.tests import requires_array_api_strict
+
+try:
+ import array_api_strict
+except ImportError:
+
+ class DummyArrayAPINamespace:
+ bool = None
+ int32 = None
+ float64 = None
+
+ array_api_strict = DummyArrayAPINamespace
@pytest.mark.parametrize(
@@ -58,7 +70,6 @@ def test_inf(obj) -> None:
@pytest.mark.parametrize(
"kind, expected",
[
- ("a", (np.dtype("O"), "nan")), # dtype('S')
("b", (np.float32, "nan")), # dtype('int8')
("B", (np.float32, "nan")), # dtype('uint8')
("c", (np.dtype("O"), "nan")), # dtype('S1')
@@ -98,3 +109,54 @@ def test_nat_types_membership() -> None:
assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES
assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES
assert np.float64 not in dtypes.NAT_TYPES
+
+
+@pytest.mark.parametrize(
+ ["dtype", "kinds", "xp", "expected"],
+ (
+ (np.dtype("int32"), "integral", np, True),
+ (np.dtype("float16"), "real floating", np, True),
+ (np.dtype("complex128"), "complex floating", np, True),
+ (np.dtype("U"), "numeric", np, False),
+ pytest.param(
+ array_api_strict.int32,
+ "integral",
+ array_api_strict,
+ True,
+ marks=requires_array_api_strict,
+ id="array_api-int",
+ ),
+ pytest.param(
+ array_api_strict.float64,
+ "real floating",
+ array_api_strict,
+ True,
+ marks=requires_array_api_strict,
+ id="array_api-float",
+ ),
+ pytest.param(
+ array_api_strict.bool,
+ "numeric",
+ array_api_strict,
+ False,
+ marks=requires_array_api_strict,
+ id="array_api-bool",
+ ),
+ ),
+)
+def test_isdtype(dtype, kinds, xp, expected) -> None:
+ actual = dtypes.isdtype(dtype, kinds, xp=xp)
+ assert actual == expected
+
+
+@pytest.mark.parametrize(
+ ["dtype", "kinds", "xp", "error", "pattern"],
+ (
+ (np.dtype("int32"), "foo", np, (TypeError, ValueError), "kind"),
+ (np.dtype("int32"), np.signedinteger, np, TypeError, "kind"),
+ (np.dtype("float16"), 1, np, TypeError, "kind"),
+ ),
+)
+def test_isdtype_error(dtype, kinds, xp, error, pattern):
+ with pytest.raises(error, match=pattern):
+ dtypes.isdtype(dtype, kinds, xp=xp)
Index: xarray-2024.05.0/xarray/core/npcompat.py
===================================================================
--- xarray-2024.05.0.orig/xarray/core/npcompat.py
+++ xarray-2024.05.0/xarray/core/npcompat.py
@@ -28,3 +28,33 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+try:
+ # requires numpy>=2.0
+ from numpy import isdtype # type: ignore[attr-defined,unused-ignore]
+except ImportError:
+ import numpy as np
+
+ dtype_kinds = {
+ "bool": np.bool_,
+ "signed integer": np.signedinteger,
+ "unsigned integer": np.unsignedinteger,
+ "integral": np.integer,
+ "real floating": np.floating,
+ "complex floating": np.complexfloating,
+ "numeric": np.number,
+ }
+
+ def isdtype(dtype, kind):
+ kinds = kind if isinstance(kind, tuple) else (kind,)
+
+ unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds]
+ if unknown_dtypes:
+ raise ValueError(f"unknown dtype kinds: {unknown_dtypes}")
+
+ # verified the dtypes already, no need to check again
+ translated_kinds = [dtype_kinds[kind] for kind in kinds]
+ if isinstance(dtype, np.generic):
+ return any(isinstance(dtype, kind) for kind in translated_kinds)
+ else:
+ return any(np.issubdtype(dtype, kind) for kind in translated_kinds)