DAOS-16362: pydaos.torch checkpointing #15691

Status: Open. Wants to merge 28 commits into base: master.

Commits (28):
- 2ecf810 client: initialize DAOS on Dataset creation rather than on module import (0xE0F, Jan 2, 2025)
- a15fe94 tests: PoC of torch tests (0xE0F, Jan 2, 2025)
- 1303be9 client: torch module needs lazy DAOS init (0xE0F, Jan 3, 2025)
- caa6fb4 ftests: add tests for pydaos.torch.Dataset (0xE0F, Jan 3, 2025)
- 70cd817 client: add checkpoint interface to pydaos.torch (0xE0F, Jan 7, 2025)
- eaf6b98 ftest: torch checkpoint test (0xE0F, Jan 7, 2025)
- 8692003 docs: for pydaos.torch module (0xE0F, Jan 7, 2025)
- de147e4 Linter fixes (0xE0F, Jan 7, 2025)
- 9cac883 Apply suggestions from code review (enakta, Jan 7, 2025)
- b313874 Fixes for copyright linter (0xE0F, Jan 7, 2025)
- 5790c94 client: pass mode, oflags, class and chunk size to torch shim layer (0xE0F, Jan 8, 2025)
- 65e7c9c nlt: remove torch tests in favour of functional tests (0xE0F, Jan 8, 2025)
- eb0059d torch: uniform way to return error (0xE0F, Jan 8, 2025)
- bd11078 ftest: add iterable dataset to tests (0xE0F, Jan 8, 2025)
- b6512c9 linter: oflags -> open_flags (0xE0F, Jan 8, 2025)
- 770f45d torch: add parallel, chunked checkpoint writes (0xE0F, Jan 9, 2025)
- 2f3488f Linter fixes (0xE0F, Jan 9, 2025)
- 20d417a ftests: bump timeout for the checkpoint tests (0xE0F, Jan 9, 2025)
- 9354f20 Bandit's checker suggestions (enakta, Jan 9, 2025)
- 21d8222 ftest: add torch to the requirements (0xE0F, Jan 9, 2025)
- 82c629f ftests: drop version requirement for pytorch (0xE0F, Jan 13, 2025)
- 57301cd ftest: bump pool size for tests (0xE0F, Jan 16, 2025)
- 4e189fc ftest: be modest with vm test machines (0xE0F, Jan 19, 2025)
- d1e2f51 client: torch: use old multiprocess API to support python 3.6 (0xE0F, Jan 20, 2025)
- f19fe56 torch: linter fix (0xE0F, Jan 20, 2025)
- b17a6f7 ftest: bump timout for checkpoint tests (0xE0F, Jan 20, 2025)
- 038fb6f Merge branch 'master' into 0xe0f/pydaos.torch.checkpointing (0xE0F, Jan 21, 2025)
- d1d5580 ftest: restarting tests with latest master (0xE0F, Jan 22, 2025)
77 changes: 77 additions & 0 deletions docs/user/pytorch.md
@@ -0,0 +1,77 @@
# DAOS PyTorch interface

PyTorch is a fully featured framework for building and training deep learning models.
It is widely used both in research and in industry.
PyTorch can load data from various sources, and DAOS can be used as a storage backend for training data and model checkpoints.

The [DFS plugin](https://github.com/daos-stack/daos/tree/master/src/client/pydaos/torch) implements PyTorch interfaces for loading data from DAOS: Map-style and Iterable-style datasets.
This allows using all features of `torch.utils.data.DataLoader` to load data from DAOS POSIX containers, including parallel data loading, batching, shuffling, etc.

## Installation

To install the plugin, you need to have PyTorch installed; please follow the official [PyTorch installation guide](https://pytorch.org/get-started/).
The `pydaos.torch` module comes with the DAOS client package. Please refer to the DAOS installation guide for your distribution.


## Usage

To use DAOS as a storage backend for PyTorch, you need a DAOS agent running on the nodes where PyTorch runs, and correctly configured ACLs for the container.

Here is an example of how to use a Map-style dataset with DAOS directly:

```python
import torch
from torch.utils.data import DataLoader
from pydaos.torch import Dataset

dataset = Dataset(pool='pool', container='container', path='/training/samples')
# That's it: once created, the Dataset connects to DAOS, scans the namespace
# of the container, and is ready to load data from it.

for i, sample in enumerate(dataset):
    print(f"Sample {i} size: {len(sample)}")
```

To use the Dataset with a DataLoader, you can pass it directly to the DataLoader constructor:

```python

dataloader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=4,
                        worker_init_fn=dataset.worker_init)

# and use DataLoader as usual
for batch in dataloader:
    print(f"Batch size: {len(batch)}")
```

The only notable difference is that you need to pass the dataset's `worker_init` method as the DataLoader's `worker_init_fn`, so that the DAOS connection is correctly initialized in the worker processes.
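The reason for this hook can be illustrated with a small stdlib sketch (the `Connection` and `LazyDataset` names below are hypothetical stand-ins, not the pydaos.torch code): connection handles opened in the parent process are not valid after `fork`, so each DataLoader worker must open its own.

```python
import os


class Connection:
    """Stand-in for a DAOS handle that is only valid in the process that opened it."""

    def __init__(self):
        self.pid = os.getpid()


class LazyDataset:
    """Sketch of a dataset that re-opens its connection per worker process."""

    def __init__(self):
        self._conn = None  # do not connect eagerly; workers are forked later

    def worker_init(self, worker_id):
        # DataLoader calls this in each worker after fork:
        # drop any inherited handle and open a fresh one.
        self._conn = Connection()

    def connection(self):
        # Re-open if the handle was inherited from another process.
        if self._conn is None or self._conn.pid != os.getpid():
            self._conn = Connection()
        return self._conn
```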

## Checkpoints

DAOS can be used to store model checkpoints as well.
PyTorch saves and loads model checkpoints using the [torch.save](https://pytorch.org/docs/main/generated/torch.save.html) and [torch.load](https://pytorch.org/docs/main/generated/torch.load.html) functions.

`pydaos.torch` provides a way to save and load model checkpoints directly to/from a DAOS container (which can be the same container used for the data, or a different one):

```python
import torch
from pydaos.torch import Checkpoint

# ...

chkp = Checkpoint(pool, cont, prefix='/training/checkpoints')

with chkp.writer('model.pt') as w:
    torch.save(model.state_dict(), w)

# Later, to load the model:

with chkp.reader('model.pt') as r:
    state = torch.load(r)
    model.load_state_dict(state)

```

See the [pydaos.torch](https://github.com/daos-stack/daos/blob/master/src/client/pydaos/torch/Readme.md) plugin Readme for an example of how to use checkpoints with the DLIO benchmark.
1 change: 1 addition & 0 deletions requirements-ftest.txt
@@ -4,3 +4,4 @@ avocado-framework-plugin-varianter-yaml-to-mux==82
clustershell
paramiko
distro
torch
71 changes: 71 additions & 0 deletions src/client/pydaos/torch/Readme.md
@@ -62,3 +62,74 @@ for i in range(1, cols * rows + 1):
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```


### Checkpoint interface

The Torch framework provides a way to save and load model checkpoints: the `torch.save` and `torch.load` functions are used to save and load the model state dictionary.
The `torch.save` function expects a state dictionary object and a file-like object (`Union[str, PathLike, BinaryIO, IO[bytes]]`).
To implement such an interface, the `pydaos.torch.WriteBuffer` class is introduced: a wrapper around an `io.BufferedIOBase` object that behaves like a writable stream.
`WriteBuffer` can operate in two modes: in-memory buffer and chunked buffer. The in-memory buffer accumulates data in memory and writes it to the DAOS container when the `close()` method is called.
The chunked buffer writes the data to the DAOS container in chunks of a fixed size. Optional parameters limit the number of chunks in flight and the number of worker processes to use.
The implementation of the loader is straightforward: it reads the data from the file with the existing API and returns it as a buffer.

For convenience, the `pydaos.torch.Checkpoint` class is provided, which manages the DAOS connections and provides the `reader` and `writer` methods.
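The chunked mode described above can be sketched with the standard library (threads stand in for the plugin's worker processes; all names here are hypothetical, not the pydaos.torch API): fixed-size chunks are submitted to a pool, with a bound on how many may be in flight at once.

```python
from concurrent.futures import ThreadPoolExecutor


def write_chunked(payload, chunk_size, max_workers=2, max_in_flight=4):
    """Write payload in fixed-size chunks with a bounded number in flight."""
    written = {}  # offset -> chunk, standing in for writes to a DAOS file

    def store(offset, part):
        written[offset] = part

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for offset in range(0, len(payload), chunk_size):
            chunk = payload[offset:offset + chunk_size]
            pending.append(pool.submit(store, offset, chunk))
            if len(pending) >= max_in_flight:
                pending.pop(0).result()  # cap memory held by queued chunks
        for fut in pending:
            fut.result()

    return b"".join(written[o] for o in sorted(written))
```

The bound on `pending` is what keeps memory flat for large checkpoints: the writer never holds more than `max_in_flight` unwritten chunks.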


An example of using the checkpointing interface in the DLIO benchmark:

```python
import logging
import torch
from pydaos.torch import Checkpoint as DaosCheckpoint

from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_benchmark.utils.utility import Profile
from dlio_benchmark.utils.config import ConfigArguments

from dlio_benchmark.common.constants import MODULE_CHECKPOINT

dlp = Profile(MODULE_CHECKPOINT)


class PyDaosTorchCheckpointing(BaseCheckpointing):
    __instance = None

    @staticmethod
    def get_instance():
        """ Static access method. """
        if PyDaosTorchCheckpointing.__instance is None:
            logging.basicConfig(level=logging.INFO)
            PyDaosTorchCheckpointing.__instance = PyDaosTorchCheckpointing()
        return PyDaosTorchCheckpointing.__instance

    @dlp.log_init
    def __init__(self):
        super().__init__("pt")

        args = ConfigArguments.get_instance()
        prefix = args.checkpoint_folder
        pool = args.checkpoint_daos_pool
        cont = args.checkpoint_daos_cont

        logging.info(f"Checkpointing is set to DAOS pool: {pool}, container: {cont} with prefix: {prefix}")
        self.ckpt = DaosCheckpoint(pool, cont, prefix)

    @dlp.log
    def get_tensor(self, size):
        return torch.randint(high=1, size=(size,), dtype=torch.int8)

    @dlp.log
    def save_state(self, suffix, state):
        name = self.get_name(suffix)
        with self.ckpt.writer(name) as f:
            torch.save(state, f)

    @dlp.log
    def checkpoint(self, epoch, step_number):
        super().checkpoint(epoch, step_number)

    @dlp.log
    def finalize(self):
        super().finalize()
```
64 changes: 52 additions & 12 deletions src/client/pydaos/torch/__init__.py
@@ -1,32 +1,72 @@
# (C) Copyright 2024 Intel Corporation.
# (C) Copyright 2024 Google LLC
# (C) Copyright 2024 Enakta Labs Ltd
# (C) Copyright 2024-2025 Intel Corporation.
# (C) Copyright 2025 Hewlett Packard Enterprise Development LP
# (C) Copyright 2024-2025 Google LLC
> **Contributor:** You should probably not touch those.
>
> **Author:** Linting was failing without these 🤷‍♂️
>
> **Contributor:** `utils/cq/check_update_copyright.sh`: just add handling for Enakta there and you should be fine.
>
> **Contributor:** We can still land it with copyright issues though, if you just want to change them.
# (C) Copyright 2024-2025 Enakta Labs Ltd
#
# SPDX-License-Identifier: BSD-2-Clause-Patent
#
# pylint: disable=cyclic-import
"""
PyTorch DAOS Module allowing using DFS as Dataset
"""

import atexit

from . import torch_shim # pylint: disable=relative-beyond-top-level,import-self

DAOS_MAGIC = 0x7A8B


# The module loader procedure guarantees that __init__.py is going to be run only once
_rc = torch_shim.module_init()
if _rc != 0:
    raise ValueError(f"Could not initialize DAOS module: rc={_rc}")
class DaosClient():
    # pylint: disable=too-few-public-methods
    # pylint: disable=attribute-defined-outside-init
    """
    DaosClient is responsible for handling DAOS init/fini.

    The class implements the Singleton pattern and only
    allows a single instance to be instantiated during
    the lifetime of a process.
    """
    _instance = None

    @classmethod
    def cleanup(cls):
        """Trigger the instance cleanup process."""
        if cls._instance is None:
            return
        cls._instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # pylint: disable=protected-access
            cls._instance._open()
        return cls._instance

    def _open(self):
        # Initialize DAOS
        self.connected = False
        _rc = torch_shim.module_init(DAOS_MAGIC)
        if _rc != 0:
            raise ValueError(f"Could not initialize DAOS module: rc={_rc}")
        self.connected = True

    def _close(self):
        if not self.connected:
            return
        _rc = torch_shim.module_fini(DAOS_MAGIC)
        if _rc != 0:
            raise ValueError(f"Could not finalize DAOS module: rc={_rc}")
        self.connected = False

    def __del__(self):
        if not torch_shim or not self.connected:
            return
        self._close()


@atexit.register
def _fini():
    rc = torch_shim.module_fini()
    if rc != 0:
        raise ValueError(f"Could not finalize DAOS module, rc={rc}")
def _cleanup():
    DaosClient.cleanup()


from .torch_api import * # noqa: F403,E402