
checkpoint prototype
rjzamora committed Jan 26, 2024
1 parent 008f4d4 commit fbc7832
Showing 1 changed file: dask/checkpoint.py (231 additions, 0 deletions)
@@ -0,0 +1,231 @@
import contextlib
import glob
import os
import pickle
from collections import defaultdict
from importlib import import_module

import dask.dataframe as dd
from dask.base import tokenize
from dask.blockwise import BlockIndex
from dask.utils import typename

from distributed import wait, default_client
from distributed.protocol import dask_deserialize, dask_serialize


class Handler:
    """Base class for format-specific checkpointing handlers

    A ``Handler`` object is responsible for a single partition.
    """

    format = None  # General format label

    def __init__(self, path, backend, index, **kwargs):
        self.path = path
        self.backend = backend
        self.index = index
        self.kwargs = kwargs

    @classmethod
    def clean(cls, dirpath):
        """Remove the target directory"""
        import shutil

        if os.path.isdir(dirpath):
            with contextlib.suppress(FileNotFoundError):
                shutil.rmtree(dirpath)

    @classmethod
    def prepare(cls, dirpath):
        """Create the target directory"""
        os.makedirs(dirpath, exist_ok=True)

    @classmethod
    def save(cls, part, path, index):
        """Persist the target partition to disk"""
        raise NotImplementedError()  # Logic depends on format

    def load(self):
        """Collect the saved partition"""
        raise NotImplementedError()  # Logic depends on format


@dask_serialize.register(Handler)
def _serialize_unloaded(obj):
    # Make sure we read the partition into memory if
    # this partition is moved to a different worker
    return None, [pickle.dumps(obj.load())]


@dask_deserialize.register(Handler)
def _deserialize_unloaded(header, frames):
    # Deserializing a `Handler` object returns the wrapped data
    return pickle.loads(frames[0])
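

# For example (hypothetical): gathering a future that wraps a ``Handler``
# yields the underlying partition data, because the hooks above call
# ``obj.load()`` whenever the object is serialized off of its worker:
#
#     fut = client.submit(ParquetHandler, path, "pandas", 0)
#     fut.result()  # the loaded pandas DataFrame, not a Handler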


class ParquetHandler(Handler):
    """Parquet-specific checkpointing handler for DataFrame collections"""

    format = "parquet"

    @classmethod
    def save(cls, part, path, index):
        fn = f"{path}/part.{index[0]}.parquet"
        part.to_parquet(fn)
        return index[0]

    def load(self):
        lib = import_module(self.backend)
        # Exactly one file is written per partition index, so take
        # the single glob match (backends like pandas expect a path,
        # not a list of paths)
        fn = glob.glob(f"{self.path}/*.{self.index}.parquet")[0]
        return lib.read_parquet(fn, **self.kwargs)

class Checkpoint:
    """Checkpoint a Dask collection on disk

    The storage location does not need to be shared between workers.
    """

    def __init__(
        self,
        client,
        meta,
        handler,
        path,
        id,
        load_kwargs,
    ):
        self.client = client
        self.meta = meta
        self.backend = typename(meta).partition(".")[0]
        self.handler = handler
        self.path = path
        self.id = id
        self.load_kwargs = load_kwargs or {}
        self._valid = True

    def __repr__(self):
        ctype = type(self.meta).__name__
        path = self.path
        fmt = self.handler.format
        return f"Checkpoint({ctype})<path={path}, format={fmt}>"

    @classmethod
    def create(
        cls,
        df,
        dirpath,
        id=None,
        format="parquet",
        overwrite=False,
        compute_kwargs=None,
        load_kwargs=None,
        **save_kwargs,
    ):
        """Create a new Checkpoint object"""

        # Get handler
        if format == "parquet":
            handler = ParquetHandler
        else:
            # Only parquet is supported for now
            raise NotImplementedError()

        id = id or tokenize(df, dirpath, format)
        meta = df._meta

        client = default_client()

        wait(client.run(handler.prepare, dirpath))
        path = f"{dirpath}/{id}"

        if overwrite:
            wait(client.run(handler.clean, path))

        wait(client.run(handler.prepare, path))

        result = df.map_partitions(
            handler.save,
            path,
            BlockIndex((df.npartitions,)),
            meta=meta,
            enforce_metadata=False,
            **save_kwargs,
        ).persist(**(compute_kwargs or {}))
        wait(result)
        del result

        return cls(
            client,
            meta,
            handler,
            path,
            id,
            load_kwargs,
        )

    def load(self):
        """Load a checkpointed collection

        Note that this will not immediately persist the partitions
        in memory. Rather, it will output a lazy Dask collection.
        """
        if not self._valid:
            raise RuntimeError("This checkpoint is no longer valid")

        #
        # Find out which partition indices are stored on each worker
        #
        def get_indices(path):
            # Assume the file name is something like: <name>.<index>.<fmt>
            return {int(fn.split(".")[-2]) for fn in glob.glob(path + "/*")}

        worker_indices = self.client.run(get_indices, self.path)

        # Map each partition index to the workers holding a copy of it
        summary = defaultdict(list)
        for worker, indices in worker_indices.items():
            for index in indices:
                summary[index].append(worker)

        #
        # Convert each checkpointed partition to a `Handler` object,
        # pinned to a worker that holds that partition on disk
        #
        assignments = {}
        futures = []
        for index, workers in sorted(summary.items()):
            # Round-robin among the workers holding a copy
            assignments[index] = workers[index % len(workers)]
            futures.append(
                self.client.submit(
                    self.handler,
                    self.path,
                    self.backend,
                    index,
                    workers=[assignments[index]],
                    **self.load_kwargs,
                )
            )
        wait(futures)

        #
        # Create a new collection from the delayed `Handler` objects
        #
        meta = self.meta
        return dd.from_delayed(futures, meta=meta, verify_meta=False).map_partitions(
            self._load_partition,
            meta=meta,
        )

    @staticmethod
    def _load_partition(obj):
        """Load a checkpointed partition"""
        if isinstance(obj, Handler):
            return obj.load()
        return obj

    def clean(self):
        """Clean up this checkpoint"""
        wait(self.client.run(self.handler.clean, self.path))
        self._valid = False


checkpoint = Checkpoint.create # Expose "public" API
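
A minimal usage sketch (hypothetical: it assumes this module is importable
as ``dask.checkpoint`` and that a distributed cluster is running; the
timeseries dataset and the /tmp/checkpoints path are only illustrative):

    from distributed import Client

    import dask.datasets
    from dask.checkpoint import checkpoint

    if __name__ == "__main__":
        client = Client()  # checkpoint() picks up the default client

        df = dask.datasets.timeseries()  # any Dask DataFrame works

        # Persist every partition to worker-local storage
        ckpt = checkpoint(df, "/tmp/checkpoints", overwrite=True)

        # Rebuild a lazy collection backed by the on-disk partitions
        df2 = ckpt.load()
        print(df2.head())

        # Remove the on-disk data and invalidate the checkpoint object
        ckpt.clean()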
