diff --git a/dask/checkpoint.py b/dask/checkpoint.py
index 8f24a3666f8..eeeddddaf37 100644
--- a/dask/checkpoint.py
+++ b/dask/checkpoint.py
@@ -10,7 +10,7 @@
 from dask.blockwise import BlockIndex
 from dask.utils import typename
 
-from distributed import wait, default_client
+from distributed import get_worker, wait, default_client
 from distributed.protocol import dask_deserialize, dask_serialize
 
 
@@ -72,6 +72,7 @@ class ParquetHandler(Handler):
     def save(cls, part, path, index):
         fn = f"{path}/part.{index[0]}.parquet"
         part.to_parquet(fn)
+        # return get_worker().worker_address
         return index[0]
 
     def load(self):
@@ -88,14 +89,14 @@ class Checkpoint:
 
     def __init__(
         self,
-        client,
+        npartitions,
         meta,
         handler,
         path,
         id,
         load_kwargs,
     ):
-        self.client = client
+        self.npartitions = npartitions
         self.meta = meta
         self.backend = typename(meta).partition(".")[0]
         self.handler = handler
@@ -110,6 +111,9 @@ def __repr__(self):
         fmt = self.handler.format
         return f"Checkpoint({ctype})"
 
+    def __del__(self):
+        self.clean()
+
     @classmethod
     def create(
         cls,
@@ -144,19 +148,19 @@ def create(
 
         wait(client.run(handler.prepare, path))
 
-        result = df.map_partitions(
-            handler.save,
-            path,
-            BlockIndex((df.npartitions,)),
-            meta=meta,
-            enforce_metadata=False,
-            **save_kwargs,
-        ).persist(**(compute_kwargs or {}))
-        wait(result)
-        del result
+        npartitions = len(
+            df.map_partitions(
+                handler.save,
+                path,
+                BlockIndex((df.npartitions,)),
+                meta=meta,
+                enforce_metadata=False,
+                **save_kwargs,
+            ).compute(**(compute_kwargs or {}))
+        )
 
         return cls(
-            client,
+            npartitions,
             meta,
             handler,
             path,
@@ -173,6 +177,11 @@ def load(self):
         if not self._valid:
             raise RuntimeError("This checkpoint is no longer valid")
 
+        #
+        # Get client and check workers
+        #
+        client = default_client()
+
         #
         # Find out which partition indices are stored on each worker
         #
@@ -180,13 +189,15 @@ def get_indices(path):
             # Assume file-name is something like: ..
            return {int(fn.split(".")[-2]) for fn in glob.glob(path + "/*")}
 
-        worker_indices = self.client.run(get_indices, self.path)
+        worker_indices = client.run(get_indices, self.path)
         summary = defaultdict(list)
         for worker, indices in worker_indices.items():
             for index in indices:
                 summary[index].append(worker)
 
+        assert len(summary) == self.npartitions, "Load failed."
+
         #
         # Convert each checkpointed partition to a `Handler` object
         #
 
@@ -195,7 +206,7 @@ def get_indices(path):
         for i, (worker, indices) in enumerate(summary.items()):
             assignments[worker] = indices[i % len(indices)]
             futures.append(
-                self.client.submit(
+                client.submit(
                     self.handler,
                     self.path,
                     self.backend,
@@ -224,7 +235,8 @@ def _load_partition(obj):
 
     def clean(self):
         """Clean up this checkpoint"""
-        wait(self.client.run(self.handler.clean, self.path))
+        client = default_client()
+        wait(client.run(self.handler.clean, self.path))
         self._valid = False
 
 
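For reviewers, a minimal usage sketch of the reworked API follows; it is not part of the
diff. It assumes `Checkpoint.create` accepts the DataFrame, a handler/format selection,
and a target path, roughly as the hunks above suggest (the full `create` signature is
not visible here), and that a `distributed` client has been registered before use,
since `load()` and `clean()` now locate one via `default_client()` instead of a stored
`self.client`.

    # Hypothetical usage sketch; the keyword names `format=` and `path=` are
    # assumptions, since the full Checkpoint.create signature is not shown.
    import pandas as pd
    import dask.dataframe as dd
    from distributed import Client

    from dask.checkpoint import Checkpoint

    if __name__ == "__main__":
        # load()/clean() call distributed.default_client(), so a client must
        # exist before the checkpoint is used (or garbage-collected, now that
        # __del__ triggers clean()).
        client = Client()

        df = dd.from_pandas(pd.DataFrame({"a": range(100)}), npartitions=4)

        # Writes every partition to disk via handler.save; the length of the
        # computed map_partitions result is stored as cp.npartitions.
        cp = Checkpoint.create(df, format="parquet", path="/tmp/ckpt")

        # Rebuilds a collection from the stored parts; load() first checks,
        # via client.run, that the files found on the workers cover all of
        # cp.npartitions partitions.
        df2 = cp.load()

        # Optional explicit cleanup; __del__ would otherwise do this when cp
        # is collected.
        cp.clean()

One design note on the `__del__` hook: because cleanup now resolves the client through
`default_client()`, it can raise if the checkpoint object outlives the client (for
example at interpreter shutdown), so explicit `clean()` calls remain the safer path.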