From a825faaf60dc75e0365f18a0f24acb0fe288b263 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 9 Feb 2021 16:03:14 +0000 Subject: [PATCH 1/2] Compatibility with dask 2021.02.0 --- ci/requirements/environment-windows.yml | 2 +- ci/requirements/environment.yml | 2 +- xarray/core/dataset.py | 32 +++++++++++++++++++------ 3 files changed, 27 insertions(+), 9 deletions(-) Index: xarray-0.16.2/xarray/core/dataset.py =================================================================== --- xarray-0.16.2.orig/xarray/core/dataset.py +++ xarray-0.16.2/xarray/core/dataset.py @@ -809,13 +809,12 @@ class Dataset(Mapping, ImplementsDataset import dask info = [ - (True, k, v.__dask_postcompute__()) + (k, None) + v.__dask_postcompute__() if dask.is_dask_collection(v) - else (False, k, v) + else (k, v, None, None) for k, v in self._variables.items() ] - args = ( - info, + construct_direct_args = ( self._coord_names, self._dims, self._attrs, @@ -823,19 +822,18 @@ class Dataset(Mapping, ImplementsDataset self._encoding, self._file_obj, ) - return self._dask_postcompute, args + return self._dask_postcompute, (info, construct_direct_args) def __dask_postpersist__(self): import dask info = [ - (True, k, v.__dask_postpersist__()) + (k, None, v.__dask_keys__()) + v.__dask_postpersist__() if dask.is_dask_collection(v) - else (False, k, v) + else (k, v, None, None, None) for k, v in self._variables.items() ] - args = ( - info, + construct_direct_args = ( self._coord_names, self._dims, self._attrs, @@ -843,45 +841,37 @@ class Dataset(Mapping, ImplementsDataset self._encoding, self._file_obj, ) - return self._dask_postpersist, args + return self._dask_postpersist, (info, construct_direct_args) @staticmethod - def _dask_postcompute(results, info, *args): + def _dask_postcompute(results, info, construct_direct_args): variables = {} - results2 = list(results[::-1]) - for is_dask, k, v in info: - if is_dask: - func, args2 = v - r = results2.pop() - result = func(r, *args2) + results_iter = iter(results) + for k, v, rebuild, rebuild_args in info: + if v is None: + variables[k] = rebuild(next(results_iter), *rebuild_args) else: - result = v - variables[k] = result + variables[k] = v - final = Dataset._construct_direct(variables, *args) + final = Dataset._construct_direct(variables, *construct_direct_args) return final @staticmethod - def _dask_postpersist(dsk, info, *args): + def _dask_postpersist(dsk, info, construct_direct_args): + from dask.optimization import cull + variables = {} # postpersist is called in both dask.optimize and dask.persist # When persisting, we want to filter out unrelated keys for # each Variable's task graph. - is_persist = len(dsk) == len(info) - for is_dask, k, v in info: - if is_dask: - func, args2 = v - if is_persist: - name = args2[1][0] - dsk2 = {k: v for k, v in dsk.items() if k[0] == name} - else: - dsk2 = dsk - result = func(dsk2, *args2) + for k, v, dask_keys, rebuild, rebuild_args in info: + if v is None: + dsk2, _ = cull(dsk, dask_keys) + variables[k] = rebuild(dsk2, *rebuild_args) else: - result = v - variables[k] = result + variables[k] = v - return Dataset._construct_direct(variables, *args) + return Dataset._construct_direct(variables, *construct_direct_args) def compute(self, **kwargs) -> "Dataset": """Manually trigger loading and/or computation of this dataset's data