Commit f9b07fa

small tweaks - cleanup is still a problem

rjzamora committed Jan 26, 2024
1 parent fbc7832 · commit f9b07fa
Showing 1 changed file with 29 additions and 17 deletions.
dask/checkpoint.py · 46 changes: 29 additions & 17 deletions
@@ -10,7 +10,7 @@
 from dask.blockwise import BlockIndex
 from dask.utils import typename
 
-from distributed import wait, default_client
+from distributed import get_worker, wait, default_client
 from distributed.protocol import dask_deserialize, dask_serialize

@@ -72,6 +72,7 @@ class ParquetHandler(Handler):
     def save(cls, part, path, index):
         fn = f"{path}/part.{index[0]}.parquet"
         part.to_parquet(fn)
+        # return get_worker().worker_address
         return index[0]
 
     def load(self):
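
Note: the filename written by `save` is what `load` later parses back into a
partition index via `get_indices`. A cluster-free sketch of that round trip;
the helper names are illustrative, not from this commit:

    def make_name(path: str, index: int) -> str:
        # Mirrors: fn = f"{path}/part.{index[0]}.parquet"
        return f"{path}/part.{index}.parquet"

    def parse_index(fn: str) -> int:
        # Mirrors get_indices: file-name is <name>.<index>.<fmt>
        return int(fn.split(".")[-2])

    assert parse_index(make_name("/tmp/ckpt", 7)) == 7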
@@ -88,14 +89,14 @@ class Checkpoint:
 
     def __init__(
         self,
-        client,
+        npartitions,
         meta,
        handler,
         path,
         id,
         load_kwargs,
     ):
-        self.client = client
+        self.npartitions = npartitions
         self.meta = meta
         self.backend = typename(meta).partition(".")[0]
         self.handler = handler
@@ -110,6 +111,9 @@ def __repr__(self):
         fmt = self.handler.format
         return f"Checkpoint({ctype})<path={path}, format={fmt}>"
 
+    def __del__(self):
+        self.clean()
+
     @classmethod
     def create(
         cls,
@@ -144,19 +148,19 @@ def create(
 
         wait(client.run(handler.prepare, path))
 
-        result = df.map_partitions(
-            handler.save,
-            path,
-            BlockIndex((df.npartitions,)),
-            meta=meta,
-            enforce_metadata=False,
-            **save_kwargs,
-        ).persist(**(compute_kwargs or {}))
-        wait(result)
-        del result
+        npartitions = len(
+            df.map_partitions(
+                handler.save,
+                path,
+                BlockIndex((df.npartitions,)),
+                meta=meta,
+                enforce_metadata=False,
+                **save_kwargs,
+            ).compute(**(compute_kwargs or {}))
+        )
 
         return cls(
-            client,
+            npartitions,
             meta,
             handler,
             path,
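
Note: `persist` + `wait` is replaced here by a `compute` whose length is taken
as the partition count. This only counts correctly because `handler.save`
returns a single element (the partition index) per partition, so the
concatenated result has one entry per partition. A runnable sketch of the
counting trick; `fake_save` is a stand-in for `handler.save`, not the commit's
code:

    import pandas as pd
    import dask.dataframe as dd
    from dask.blockwise import BlockIndex

    df = dd.from_pandas(pd.DataFrame({"x": range(8)}), npartitions=4)

    def fake_save(part, index):
        # One element per partition: its block index.
        return pd.Series([index[0]])

    counted = df.map_partitions(
        fake_save,
        BlockIndex((df.npartitions,)),
        meta=pd.Series(dtype="int64"),
        enforce_metadata=False,
    ).compute()
    assert len(counted) == df.npartitions  # 4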
@@ -173,20 +177,27 @@ def load(self):
         if not self._valid:
             raise RuntimeError("This checkpoint is no longer valid")
 
+        #
+        # Get client and check workers
+        #
+        client = default_client()
+
         #
         # Find out which partition indices are stored on each worker
         #
         def get_indices(path):
             # Assume file-name is something like: <name>.<index>.<fmt>
             return {int(fn.split(".")[-2]) for fn in glob.glob(path + "/*")}
 
-        worker_indices = self.client.run(get_indices, self.path)
+        worker_indices = client.run(get_indices, self.path)
 
         summary = defaultdict(list)
         for worker, indices in worker_indices.items():
             for index in indices:
                 summary[index].append(worker)
 
+        assert len(summary) == self.npartitions, "Load failed."
+
         #
         # Convert each checkpointed partition to a `Handler` object
         #
@@ -195,7 +206,7 @@ def get_indices(path):
         for i, (worker, indices) in enumerate(summary.items()):
             assignments[worker] = indices[i % len(indices)]
             futures.append(
-                self.client.submit(
+                client.submit(
                     self.handler,
                     self.path,
                     self.backend,
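
Note: `summary` maps each partition index to the workers holding a copy, and
the loop above then pins each index to one of its holders, round-robin, before
`client.submit`. (In the committed code the loop variables are named
`worker, indices`, but `summary` is keyed by partition index with worker lists
as values.) A pure-Python sketch with unambiguous names; the worker addresses
are hypothetical:

    from collections import defaultdict

    worker_indices = {            # shape of client.run(get_indices, path)
        "tcp://worker-1": {0, 2},
        "tcp://worker-2": {1, 2, 3},
    }

    summary = defaultdict(list)   # partition index -> workers holding it
    for worker, indices in worker_indices.items():
        for index in indices:
            summary[index].append(worker)

    assignments = {}              # partition index -> chosen worker
    for i, (index, workers) in enumerate(summary.items()):
        assignments[index] = workers[i % len(workers)]

    assert set(assignments) == {0, 1, 2, 3}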
@@ -224,7 +235,8 @@ def _load_partition(obj):
 
     def clean(self):
         """Clean up this checkpoint"""
-        wait(self.client.run(self.handler.clean, self.path))
+        client = default_client()
+        wait(client.run(self.handler.clean, self.path))
         self._valid = False
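
Note on the commit message's "cleanup is still a problem": the new `__del__`
calls `clean()`, which now reaches for `default_client()`; at interpreter
shutdown, or when no client exists, that lookup raises from inside `__del__`.
A hedged, best-effort guard, not what this commit does:

    class _CheckpointLike:
        def clean(self):
            # Stand-in for default_client() failing at teardown.
            raise RuntimeError("no client")

        def __del__(self):
            try:
                self.clean()  # best-effort; scheduler may already be gone
            except Exception:
                pass

    obj = _CheckpointLike()
    del obj  # no "Exception ignored" noise, unlike an unguarded __del__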

