Commit

add HuggingFace streaming support in data input pipeline

aireenmei committed Oct 4, 2024
1 parent 33bb598 commit cc51757
Showing 21 changed files with 854 additions and 306 deletions.
4 changes: 4 additions & 0 deletions docs/README.md
@@ -15,3 +15,7 @@ This folder contains documentation for getting started with and using MaxDiffusion
## Training

* **[Common Training Guide](train_README.md)** - Provides a comprehensive guide to training MaxDiffusion models, including script usage, configuration options, and sharding strategies.

## Data Input

* **[Common Data Input Guide](data_README.md)** - Provides a comprehensive guide to data input pipelines.
50 changes: 50 additions & 0 deletions docs/data_README.md
@@ -0,0 +1,50 @@
# Data Input Guide

## Overview
MaxDiffusion currently supports three data input pipelines, controlled by the flag `dataset_type`:
| Pipeline | Dataset Location | Dataset formats | Features or limitations |
| -------- | ---------------- | --------------- | ----------------------- |
| HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Data is streamed from the saved location rather than loaded into memory; good for large datasets |
| tf | dataset is downloaded from HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Reads the whole dataset into memory; works for small datasets |
| tfrecord | local/Cloud Storage | tfrecord | Data is streamed from the saved location rather than loaded into memory; good for large datasets |

## Usage examples

### HuggingFace Streaming (dataset_type=hf)
#### Example config for streaming from HuggingFace Hub (no download needed):
```
dataset_type: hf
dataset_name: BleachNick/UltraEdit_500k # for using https://huggingface.co/datasets/BleachNick/UltraEdit_500k
image_column: source_image
caption_column: source_caption
train_split: FreeForm
hf_access_token: '' # provide token if using gated dataset or tokenizer
```
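
Before launching training, it can help to confirm that `train_split`, `image_column`, and `caption_column` match the dataset's actual schema. Below is a minimal sketch (not part of the MaxDiffusion pipeline; assumes the `datasets` library is installed) that peeks at one streamed record:
```
# Peek at one record of the streaming dataset to verify the configured
# split and column names. Pass token=... if the dataset is gated.
from datasets import load_dataset

ds = load_dataset("BleachNick/UltraEdit_500k", split="FreeForm", streaming=True)
example = next(iter(ds))
print(example.keys())  # expect "source_image" and "source_caption" among the keys
```
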
#### Example config for streaming from downloaded data in a GCS bucket:
```
dataset_type: hf
dataset_name: parquet # or json, arrow, etc.
hf_train_files: gs://<bucket>/<folder>/*-train-*.parquet # match the train files
```
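
One possible way to produce such parquet files is to download the dataset once, write it out as multiple parquet shards, and copy the shards to the bucket (for example with gsutil). A rough sketch using the `datasets` library, with placeholder paths and shard count:
```
# Rough sketch: materialize a Hub dataset as parquet shards that can later be
# streamed with dataset_type=hf and dataset_name=parquet. Paths and the shard
# count are placeholders; keep the shard count above the number of hosts.
import os
from datasets import load_dataset

num_shards = 64
os.makedirs("/tmp/ultraedit", exist_ok=True)
ds = load_dataset("BleachNick/UltraEdit_500k", split="FreeForm")
for i in range(num_shards):
    shard = ds.shard(num_shards=num_shards, index=i, contiguous=True)
    shard.to_parquet(f"/tmp/ultraedit/{i:05d}-train-{num_shards:05d}.parquet")
```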

### tf.data in-memory dataset (dataset_type=tf)
#### Example config
```
dataset_type: tf
dataset_name: diffusers/pokemon-gpt4-captions # will download https://huggingface.co/datasets/diffusers/pokemon-gpt4-captions
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
# If cache_latents_text_encoder_outputs=True, the VAE is applied to the images and the captions are
# encoded while the dataset is downloaded, so the saved dataset contains latents and text encoder outputs.
cache_latents_text_encoder_outputs: True
```

### tf.data.TFRecordDataset (dataset_type=tfrecord)
#### Example config
```
dataset_type: tfrecord
train_data_dir: gs://<bucket>/<folder> # will use all TFRecord files under the directory
```
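
As a quick sanity check (assuming TensorFlow is installed and the placeholders are filled in), one can verify that the directory contains readable TFRecord files:
```
# Minimal sketch: list the files under train_data_dir and read a few raw
# records to confirm they are accessible. Bucket/folder are placeholders.
import tensorflow as tf

files = tf.io.gfile.glob("gs://<bucket>/<folder>/*")
print(f"found {len(files)} files")

raw_ds = tf.data.TFRecordDataset(files)
for _ in raw_ds.take(3):
    pass  # iteration raises if the files are not valid TFRecords
```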

## Best Practice
### Multihost Dataloading
In a multihost environment, when using a streaming input pipeline whose data format only supports sequential reads (`dataset_type` of `hf` or `tfrecord` in MaxDiffusion), the most performant approach is to have each data file accessed by only one host, with each host reading its own subset of the data files (shuffling happens within that subset). This requires (# of data files) > (# of hosts loading data). We recommend resharding the dataset if this requirement is not met.
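
A rough way to check this condition before training (assuming a JAX multihost setup, the `datasets` library, and placeholder GCS paths):
```
# Rough sketch: compare the number of file shards seen by the streaming
# dataset against the number of data-loading hosts. If shards are too few,
# reshard the dataset into more files (see the parquet sketch above).
import jax
from datasets import load_dataset

ds = load_dataset("parquet", data_files="gs://<bucket>/<folder>/*-train-*.parquet", streaming=True)
if ds.n_shards < jax.process_count():
    print(f"only {ds.n_shards} shards for {jax.process_count()} hosts; consider resharding")
```
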
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,5 +1,6 @@
jax>=0.4.30
jaxlib>=0.4.30
grain-nightly
google-cloud-storage==2.17.0
absl-py
datasets
@@ -22,6 +23,6 @@ tensorflow>=2.17.0
tensorflow-datasets>=4.9.6
ruff>=0.1.5,<=0.2
git+https://github.com/mlperf/logging.git
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
orbax-checkpoint>=0.5.20
tokenizers==0.20.0
@@ -54,7 +54,7 @@ def __init__(self, config, checkpoint_type):
self.rng = jax.random.PRNGKey(self.config.seed)
devices_array = max_utils.create_device_mesh(config)
self.mesh = Mesh(devices_array, self.config.mesh_axes)
self.total_train_batch_size = max_utils.get_global_batch_size(self.config)
self.total_train_batch_size = self.config.total_train_batch_size

self.checkpoint_manager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type
19 changes: 12 additions & 7 deletions src/maxdiffusion/configs/base14.yml
@@ -125,11 +125,21 @@ ici_tensor_parallelism: 1
# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
# saves transformed dataset of dataset_name.
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
# and only to small datasets that fit in memory.
# It prepares the image latents and text encoder outputs ahead of time,
# reducing memory consumption and step time during training.
# The transformed dataset is saved at dataset_save_location.
dataset_save_location: '/tmp/pokemon-gpt4-captions_sd15'
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 512
@@ -145,11 +155,6 @@ enable_data_shuffling: True
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
cache_latents_text_encoder_outputs: True


# Training loop
learning_rate: 1.e-7
scale_lr: False
@@ -205,4 +210,4 @@ class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
cache_dreambooth_dataset: False
14 changes: 12 additions & 2 deletions src/maxdiffusion/configs/base21.yml
@@ -127,11 +127,21 @@ ici_tensor_parallelism: 1
# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
# saves transformed dataset of dataset_name.
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
# and only to small datasets that fit in memory.
# It prepares the image latents and text encoder outputs ahead of time,
# reducing memory consumption and step time during training.
# The transformed dataset is saved at dataset_save_location.
dataset_save_location: '/tmp/pokemon-gpt4-captions_sd21'
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 768
@@ -201,4 +211,4 @@ class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
cache_dreambooth_dataset: False
19 changes: 12 additions & 7 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -140,11 +140,21 @@ ici_tensor_parallelism: 1
# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
# saves transformed dataset of dataset_name.
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
# and only to small datasets that fit in memory.
# It prepares the image latents and text encoder outputs ahead of time,
# reducing memory consumption and step time during training.
# The transformed dataset is saved at dataset_save_location.
dataset_save_location: '/tmp/pokemon-gpt4-captions'
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 512
@@ -160,11 +170,6 @@ enable_data_shuffling: True
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
cache_latents_text_encoder_outputs: True


# Training loop
learning_rate: 1.e-7
scale_lr: False
@@ -218,4 +223,4 @@ class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
cache_dreambooth_dataset: False
18 changes: 12 additions & 6 deletions src/maxdiffusion/configs/base_xl.yml
@@ -128,11 +128,21 @@ ici_tensor_parallelism: 1
# Dataset
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
# saves transformed dataset of dataset_name.
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
# and only to small datasets that fit in memory.
# It prepares the image latents and text encoder outputs ahead of time,
# reducing memory consumption and step time during training.
# The transformed dataset is saved at dataset_save_location.
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
train_data_dir: ''
dataset_config_name: ''
jax_cache_dir: ''
hf_data_dir: ''
hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
resolution: 1024
@@ -148,10 +158,6 @@ enable_data_shuffling: True
# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
cache_latents_text_encoder_outputs: True

# Training loop
learning_rate: 4.e-7
scale_lr: False
@@ -204,4 +210,4 @@ enable_mllog: False
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
controlnet_from_pt: True
controlnet_conditioning_scale: 0.5
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
141 changes: 141 additions & 0 deletions src/maxdiffusion/input_pipeline/_hf_data_processing.py
@@ -0,0 +1,141 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import warnings
import datasets
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
import grain.python as grain

from maxdiffusion import max_logging
from maxdiffusion import multihost_dataloading


def make_hf_streaming_iterator(
config,
dataloading_host_index,
dataloading_host_count,
mesh,
global_batch_size,
tokenize_fn=None,
image_transforms_fn=None,
):
"""Streaming data from HF Hub or GCS buckect.
No download regardless of config.cache_latents_text_encoder_outputs"""
ds = load_dataset(
config.dataset_name,
split=config.train_split,
data_dir=config.hf_data_dir,
data_files=config.hf_train_files,
streaming=True,
token=config.hf_access_token,
)

ds = ds.shuffle(seed=config.seed)
ds = ds.select_columns([config.caption_column, config.image_column])

if tokenize_fn:
ds = ds.map(
function=tokenize_fn,
batched=True,
remove_columns=[config.caption_column],
)

if image_transforms_fn:
ds = ds.map(
function=image_transforms_fn,
batched=True,
remove_columns=[config.image_column],
)

ds = HFDataSource(
ds,
dataloading_host_index,
dataloading_host_count,
)
dummy_index_sampler = grain.IndexSampler(
num_records=len(ds),
num_epochs=1,
shard_options=grain.ShardOptions(
shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False
),
shuffle=False,
seed=0,
)
operations = [grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True)]
dataloader = grain.DataLoader(
data_source=ds,
operations=operations,
sampler=dummy_index_sampler,
      worker_count=1,  # only one worker is supported for now; more workers result in duplicated data
worker_buffer_size=1,
read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=64),
)
train_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh)
return train_iter


class HFDataSource(grain.RandomAccessDataSource):
"""A class that makes HuggingFace IterableDataset a grain datasource without random access support"""

def __init__(
self,
dataset: datasets.IterableDataset,
dataloading_host_index: int,
dataloading_host_count: int,
):
self.dataset = dataset
self.dataloading_host_count = dataloading_host_count
self.dataloading_host_index = dataloading_host_index
self.n_shards = dataset.n_shards
self._check_shard_count()
self.current_shard = dataloading_host_index
self.dataset_shard = split_dataset_by_node(dataset, world_size=self.n_shards, rank=self.current_shard)
self.data_iter = None

def _check_shard_count(self):
if self.n_shards < self.dataloading_host_count:
warnings.warn(
f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, "
"smaller than number of host loading data. This is known to lead to inefficient dataloading. "
"see https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/docs/data_README.md#best-practice"
)
self.n_shards = self.dataloading_host_count

def _update_shard(self):
new_shard = (self.current_shard + self.dataloading_host_count) % self.n_shards
max_logging.log(f"Updating host {self.dataloading_host_index} dataset from shard {self.current_shard} to {new_shard}")
self.current_shard = new_shard
self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard)
self.data_iter = iter(self.dataset_shard)

def __len__(self):
"""Return length of the HF dataset. Since HuggingFace IterableDataset does not have length,
a fake length bigger than the dataset is returned"""
return 10_000_000_000

def __getitem__(self, index):
"""Since HuggingFace IterableDataset does not support random access by index.
The next item in the iterator is returned."""
if not self.data_iter:
self.data_iter = iter(self.dataset_shard)

while True:
try:
data = next(self.data_iter)
return data
except StopIteration:
self._update_shard()