diff --git a/docs/README.md b/docs/README.md index 24a6c8e6..aaad1f4e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -15,3 +15,7 @@ This folder contains documentation for getting started with and using MaxDiffusi ## Training * **[Common Training Guide](train_README.md)** - Provides a comprehensive guide to training MaxDiffusion models, including script usage, configuration options, and sharding strategies. + +## Data Input + +* **[Common Data Input Guide](data_README.md)** - Provides a comprehensive guide to data input pipelines. diff --git a/docs/data_README.md b/docs/data_README.md new file mode 100644 index 00000000..1677138b --- /dev/null +++ b/docs/data_README.md @@ -0,0 +1,50 @@ +# Data Input Guide + +## Overview +Currently MaxDiffusion supports 3 data input pipelines, controlled by the flag `dataset_type` +| Pipeline | Dataset Location | Dataset formats | Features or limitations | +| -------- | ---------------- | --------------- | ----------------------- | +| HuggingFace (hf) | datasets in HuggingFace Hub or local/Cloud Storage | Formats supported in HF Hub: parquet, arrow, json, csv, txt | data are not loaded in memory but streamed from the saved location, good for big dataset | +| tf | dataset will be downaloaded form HuggingFace Hub to disk | Formats supported in HF Hub: parquet, arrow, json, csv, txt | Will read the whole dataset into memory, works for small dataset | +| tfrecord | local/Cloud Storage | tfrecord | data are not loaded in memory but streamed from the saved location, good for big dataset | + +## Usage examples + +### HuggingFace Streaming (dataset_type=hf) +#### Example config for streaming from HuggingFace Hub (no download needed): +``` +dataset_type: hf +dataset_name: BleachNick/UltraEdit_500k # for using https://huggingface.co/datasets/BleachNick/UltraEdit_500k +image_column: source_image +caption_column: source_caption +train_split: FreeForm +hf_access_token: '' # provide token if using gated dataset or tokenizer +``` +#### Example config for streaming from downloaded data in a GCS bucket: +``` +dataset_type: hf +dataset_name: parquet # or json, arrow, etc. +hf_train_files: gs:////*-train-*.parquet # match the train files +``` + +### tf.data in-memory dataset (dataset_type=tf) +#### Example config +``` +dataset_type: tf +dataset_name: diffusers/pokemon-gpt4-captions # will download https://huggingface.co/datasets/diffusers/pokemon-gpt4-captions +dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' +# If cache_latents_text_encoder_outputs=True, apply vae to images and encode text when downloading dataset, +# the saved dataset contains latents and text encoder outputs. +cache_latents_text_encoder_outputs: True +``` + +### tf.data.TFRecordDataset (dataset_type=tfrecord) +#### Example config +``` +dataset_type: tfrecord +train_data_dir: gs:/// # will use all TFRecord files under the directory +``` + +## Best Practice +### Multihost Dataloading +In multihost environment, if use a streaming type of input pipeline and the data format only supports sequential reads (dataset_type in (hf, tfrecord in MaxDiffusion)), the most performant way is to have each data file only accessed by one host, and each host access a subset of data files (shuffle is within the subset of files). This requires (# of data files) > (# of hosts loading data). We recommand users to reshard the dataset if this requirement is not met. diff --git a/requirements.txt b/requirements.txt index 50be802e..3e5c285c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ jax>=0.4.30 jaxlib>=0.4.30 +grain-nightly google-cloud-storage==2.17.0 absl-py datasets @@ -22,6 +23,6 @@ tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 git+https://github.com/mlperf/logging.git -opencv-python==4.10.0.84 +opencv-python-headless==4.10.0.84 orbax-checkpoint>=0.5.20 tokenizers==0.20.0 diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index d297832e..2bdf1872 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -54,7 +54,7 @@ def __init__(self, config, checkpoint_type): self.rng = jax.random.PRNGKey(self.config.seed) devices_array = max_utils.create_device_mesh(config) self.mesh = Mesh(devices_array, self.config.mesh_axes) - self.total_train_batch_size = max_utils.get_global_batch_size(self.config) + self.total_train_batch_size = self.config.total_train_batch_size self.checkpoint_manager = create_orbax_checkpoint_manager( self.config.checkpoint_dir, enable_checkpointing=True, save_interval_steps=1, checkpoint_type=checkpoint_type diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index b5f938c1..154877b2 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -125,11 +125,21 @@ ici_tensor_parallelism: 1 # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' -# saves transformed dataset of dataset_name. +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location dataset_save_location: '/tmp/pokemon-gpt4-captions_sd15' train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' image_column: 'image' caption_column: 'text' resolution: 512 @@ -145,11 +155,6 @@ enable_data_shuffling: True # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 -# Prepare image latents and text encoder outputs -# during dataset creation to reduce memory consumption. -cache_latents_text_encoder_outputs: True - - # Training loop learning_rate: 1.e-7 scale_lr: False @@ -205,4 +210,4 @@ class_prompt: '' prior_loss_weight: 1.0 num_class_images: 100 # If true, set dataset_save_location. -cache_dreambooth_dataset: False \ No newline at end of file +cache_dreambooth_dataset: False diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index c8a154a3..a5b753f4 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -127,11 +127,21 @@ ici_tensor_parallelism: 1 # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' -# saves transformed dataset of dataset_name. +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location dataset_save_location: '/tmp/pokemon-gpt4-captions_sd21' train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' image_column: 'image' caption_column: 'text' resolution: 768 @@ -201,4 +211,4 @@ class_prompt: '' prior_loss_weight: 1.0 num_class_images: 100 # If true, set dataset_save_location. -cache_dreambooth_dataset: False \ No newline at end of file +cache_dreambooth_dataset: False diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index 2c860109..2bddfd15 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -140,11 +140,21 @@ ici_tensor_parallelism: 1 # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' -# saves transformed dataset of dataset_name. +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location dataset_save_location: '/tmp/pokemon-gpt4-captions' train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' image_column: 'image' caption_column: 'text' resolution: 512 @@ -160,11 +170,6 @@ enable_data_shuffling: True # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 -# Prepare image latents and text encoder outputs -# during dataset creation to reduce memory consumption. -cache_latents_text_encoder_outputs: True - - # Training loop learning_rate: 1.e-7 scale_lr: False @@ -218,4 +223,4 @@ class_prompt: '' prior_loss_weight: 1.0 num_class_images: 100 # If true, set dataset_save_location. -cache_dreambooth_dataset: False \ No newline at end of file +cache_dreambooth_dataset: False diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 18c4e94b..01ea710b 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -128,11 +128,21 @@ ici_tensor_parallelism: 1 # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' -# saves transformed dataset of dataset_name. +train_split: 'train' +dataset_type: 'tf' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location dataset_save_location: '/tmp/pokemon-gpt4-captions_xl' train_data_dir: '' dataset_config_name: '' jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' image_column: 'image' caption_column: 'text' resolution: 1024 @@ -148,10 +158,6 @@ enable_data_shuffling: True # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 -# Prepare image latents and text encoder outputs -# during dataset creation to reduce memory consumption. -cache_latents_text_encoder_outputs: True - # Training loop learning_rate: 4.e-7 scale_lr: False @@ -204,4 +210,4 @@ enable_mllog: False controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' controlnet_from_pt: True controlnet_conditioning_scale: 0.5 -controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' \ No newline at end of file +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' diff --git a/src/maxdiffusion/input_pipeline/_hf_data_processing.py b/src/maxdiffusion/input_pipeline/_hf_data_processing.py new file mode 100644 index 00000000..ed9af119 --- /dev/null +++ b/src/maxdiffusion/input_pipeline/_hf_data_processing.py @@ -0,0 +1,141 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import warnings +import datasets +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node +import grain.python as grain + +from maxdiffusion import max_logging +from maxdiffusion import multihost_dataloading + + +def make_hf_streaming_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + tokenize_fn=None, + image_transforms_fn=None, +): + """Streaming data from HF Hub or GCS buckect. + No download regardless of config.cache_latents_text_encoder_outputs""" + ds = load_dataset( + config.dataset_name, + split=config.train_split, + data_dir=config.hf_data_dir, + data_files=config.hf_train_files, + streaming=True, + token=config.hf_access_token, + ) + + ds = ds.shuffle(seed=config.seed) + ds = ds.select_columns([config.caption_column, config.image_column]) + + if tokenize_fn: + ds = ds.map( + function=tokenize_fn, + batched=True, + remove_columns=[config.caption_column], + ) + + if image_transforms_fn: + ds = ds.map( + function=image_transforms_fn, + batched=True, + remove_columns=[config.image_column], + ) + + ds = HFDataSource( + ds, + dataloading_host_index, + dataloading_host_count, + ) + dummy_index_sampler = grain.IndexSampler( + num_records=len(ds), + num_epochs=1, + shard_options=grain.ShardOptions( + shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False + ), + shuffle=False, + seed=0, + ) + operations = [grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True)] + dataloader = grain.DataLoader( + data_source=ds, + operations=operations, + sampler=dummy_index_sampler, + worker_count=1, # only supports one worker for now, more workers results in duplicated data + worker_buffer_size=1, + read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=64), + ) + train_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh) + return train_iter + + +class HFDataSource(grain.RandomAccessDataSource): + """A class that makes HuggingFace IterableDataset a grain datasource without random access support""" + + def __init__( + self, + dataset: datasets.IterableDataset, + dataloading_host_index: int, + dataloading_host_count: int, + ): + self.dataset = dataset + self.dataloading_host_count = dataloading_host_count + self.dataloading_host_index = dataloading_host_index + self.n_shards = dataset.n_shards + self._check_shard_count() + self.current_shard = dataloading_host_index + self.dataset_shard = split_dataset_by_node(dataset, world_size=self.n_shards, rank=self.current_shard) + self.data_iter = None + + def _check_shard_count(self): + if self.n_shards < self.dataloading_host_count: + warnings.warn( + f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, " + "smaller than number of host loading data. This is known to lead to inefficient dataloading. " + "see https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/docs/data_README.md#best-practice" + ) + self.n_shards = self.dataloading_host_count + + def _update_shard(self): + new_shard = (self.current_shard + self.dataloading_host_count) % self.n_shards + max_logging.log(f"Updating host {self.dataloading_host_index} dataset from shard {self.current_shard} to {new_shard}") + self.current_shard = new_shard + self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard) + self.data_iter = iter(self.dataset_shard) + + def __len__(self): + """Return length of the HF dataset. Since HuggingFace IterableDataset does not have length, + a fake length bigger than the dataset is returned""" + return 10_000_000_000 + + def __getitem__(self, index): + """Since HuggingFace IterableDataset does not support random access by index. + The next item in the iterator is returned.""" + if not self.data_iter: + self.data_iter = iter(self.dataset_shard) + + while True: + try: + data = next(self.data_iter) + return data + except StopIteration: + self._update_shard() diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py new file mode 100644 index 00000000..667bfd9d --- /dev/null +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -0,0 +1,115 @@ +""" + Copyright 2024 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import os +import tensorflow as tf +import tensorflow.experimental.numpy as tnp +from datasets import load_dataset, load_from_disk + +from maxdiffusion import multihost_dataloading + +AUTOTUNE = tf.data.experimental.AUTOTUNE + + +def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count): + dataset = dataset.with_format("tensorflow")[:] + tf_dataset = tf.data.Dataset.from_tensor_slices(dataset) + + if shuffle: + tf_dataset = tf_dataset.shuffle(len(tf_dataset)) + tf_dataset = tf_dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True) + tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE) + tf_dataset = tf_dataset.repeat(-1) + + return tf_dataset + + +def make_tf_iterator( + config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, tokenize_fn, image_transforms_fn +): + + if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location): + train_ds = load_from_disk(config.dataset_save_location) + else: + train_ds = load_dataset(config.dataset_name, split=config.train_split) + train_ds = train_ds.map( + function=tokenize_fn, + batched=True, + remove_columns=[config.caption_column], + num_proc=1 if config.cache_latents_text_encoder_outputs else 4, + desc="Running tokenizer on train dataset", + ) + # need to do it before load_as_tf_dataset + # since raw images are different sizes + # will break from_tensor_slices + train_ds = train_ds.map( + function=image_transforms_fn, + batched=True, + remove_columns=[config.image_column], + num_proc=1 if config.cache_latents_text_encoder_outputs else config.transform_images_num_proc, + desc="Transforming images", + ) + if config.cache_latents_text_encoder_outputs: + train_ds.save_to_disk(config.dataset_save_location) + train_ds.cleanup_cache_files() + + train_ds = load_as_tf_dataset(train_ds, global_batch_size, True, dataloading_host_count) + train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) + + train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) + return train_iter + + +# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py +def make_tfrecord_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, +): + """Iterator for TFRecord format. For Laion dataset, + check out preparation script + maxdiffusion/pedagogical_examples/to_tfrecords.py + """ + feature_description = { + "latents": tf.io.FixedLenFeature([], tf.string), + "hidden_states": tf.io.FixedLenFeature([], tf.string), + } + + def _parse_tfrecord_fn(example): + return tf.io.parse_single_example(example, feature_description) + + def prepare_sample(features): + latents = tf.io.parse_tensor(tnp.asarray(features["latents"]), out_type=tf.float32) + hidden_states = tf.io.parse_tensor(tnp.asarray(features["hidden_states"]), out_type=tf.float32) + return {"pixel_values": latents, "input_ids": hidden_states} + + filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) + train_ds = ( + tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) + .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) + .map(prepare_sample, num_parallel_calls=AUTOTUNE) + .shuffle(global_batch_size * 10) + .batch(global_batch_size // dataloading_host_count, drop_remainder=True) + .prefetch(AUTOTUNE) + .repeat(-1) + ) + + train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) + + train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) + return train_iter diff --git a/src/maxdiffusion/input_pipeline/_tfsd_data_processing.py b/src/maxdiffusion/input_pipeline/_tfsd_data_processing.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index 478ca9dc..00bfcd9b 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -17,10 +17,11 @@ import os from functools import partial import tensorflow as tf -import tensorflow.experimental.numpy as tnp -from datasets import load_dataset, load_from_disk, Dataset +from datasets import load_from_disk, Dataset import jax +from maxdiffusion.input_pipeline import _hf_data_processing +from maxdiffusion.input_pipeline import _tfds_data_processing from maxdiffusion import multihost_dataloading from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply from maxdiffusion.dreambooth.dreambooth_constants import ( @@ -40,95 +41,46 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE -# taken from https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/examples/tensorflow/contrastive-image-text/run_clip.py#L225 -def load_as_tf_dataset(dataset, batch_size, shuffle, config): - dataset = dataset.with_format("tensorflow")[:] - tf_dataset = tf.data.Dataset.from_tensor_slices(dataset) - - if shuffle: - tf_dataset = tf_dataset.shuffle(len(tf_dataset)) - tf_dataset = tf_dataset.batch(batch_size // jax.process_count(), drop_remainder=True) - tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE) - tf_dataset = tf_dataset.repeat(-1) - - return tf_dataset - - -# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py -def make_laion400m_train_iterator( +def make_data_iterator( config, + dataloading_host_index, + dataloading_host_count, mesh, global_batch_size, + tokenize_fn=None, + image_transforms_fn=None, ): - """Iterator for Laion dataset. - To see how to prepare this dataset, look at - maxdiffusion/pedagogical_examples/to_tfrecords.py - """ - feature_description = { - "latents": tf.io.FixedLenFeature([], tf.string), - "hidden_states": tf.io.FixedLenFeature([], tf.string), - } - - def _parse_tfrecord_fn(example): - return tf.io.parse_single_example(example, feature_description) - - def prepare_sample(features): - latents = tf.io.parse_tensor(tnp.asarray(features["latents"]), out_type=tf.float32) - hidden_states = tf.io.parse_tensor(tnp.asarray(features["hidden_states"]), out_type=tf.float32) - return {"pixel_values": latents, "input_ids": hidden_states} - - filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) - train_ds = ( - tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) - .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) - .map(prepare_sample, num_parallel_calls=AUTOTUNE) - .shuffle(global_batch_size * 10) - .batch(global_batch_size // jax.process_count(), drop_remainder=True) - .prefetch(AUTOTUNE) - .repeat(-1) - ) - - train_ds = train_ds.shard(num_shards=jax.process_count(), index=jax.process_index()) - - train_iter = multihost_dataloading.get_batch_sharded_data_pipeline(train_ds, mesh) - return train_iter - - -def make_pokemon_train_iterator(config, mesh, global_batch_size, tokenize_fn, image_transforms_fn): - dataset_save_location = config.dataset_save_location - if os.path.isdir(dataset_save_location): - train_ds = load_from_disk(dataset_save_location) - else: - train_ds = load_dataset(config.dataset_name, split="train") - - captions_column = config.caption_column - image_column = config.image_column - cache_latents_text_encoder_outputs = config.cache_latents_text_encoder_outputs - train_ds = train_ds.map( - function=tokenize_fn, - batched=True, - remove_columns=[captions_column], - num_proc=1 if cache_latents_text_encoder_outputs else 4, - desc="Running tokenizer on train dataset", + """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)""" + if config.dataset_type == "hf": + return _hf_data_processing.make_hf_streaming_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + tokenize_fn=tokenize_fn, + image_transforms_fn=image_transforms_fn, ) - # need to do it before load_as_tf_dataset - # since raw images are different sizes - # will break from_tensor_slices - train_ds = train_ds.map( - function=image_transforms_fn, - batched=True, - remove_columns=[image_column], - num_proc=1 if cache_latents_text_encoder_outputs else config.transform_images_num_proc, - desc="Transforming images", + elif config.dataset_type == "tf": + return _tfds_data_processing.make_tf_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + tokenize_fn=tokenize_fn, + image_transforms_fn=image_transforms_fn, ) - train_ds.save_to_disk(dataset_save_location) - train_ds.cleanup_cache_files() - - train_ds = load_as_tf_dataset(train_ds, global_batch_size, True, config) - train_ds = train_ds.shard(num_shards=jax.process_count(), index=jax.process_index()) - - train_iter = multihost_dataloading.get_batch_sharded_data_pipeline(train_ds, mesh) - return train_iter + elif config.dataset_type == "tfrecord": + return _tfds_data_processing.make_tfrecord_iterator( + config, + dataloading_host_index, + dataloading_host_count, + mesh, + global_batch_size, + ) + else: + assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf)" def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params): @@ -230,10 +182,12 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v instance_train_ds.save_to_disk(instance_dataset_full_path) class_train_ds.save_to_disk(class_dataset_full_path) - instance_train_ds = load_as_tf_dataset(instance_train_ds, global_batch_size, True, config) - class_train_ds = load_as_tf_dataset(class_train_ds, global_batch_size, True, config) + instance_train_ds = _tfds_data_processing.load_as_tf_dataset( + instance_train_ds, global_batch_size, True, jax.process_count() + ) + class_train_ds = _tfds_data_processing.load_as_tf_dataset(class_train_ds, global_batch_size, True, jax.process_count()) train_ds = tf.data.Dataset.zip((instance_train_ds, class_train_ds)) train_ds = train_ds.shard(num_shards=jax.process_count(), index=jax.process_index()) - train_iter = multihost_dataloading.get_batch_sharded_data_pipeline(train_ds, mesh) + train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) return train_iter diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 66465e8a..e3efcaa7 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -550,8 +550,8 @@ def calculate_num_params_from_pytree(params): return total_parameters -def get_global_batch_size(config): - return config.per_device_batch_size * jax.device_count() +def get_global_batch_size(per_device_batch_size): + return per_device_batch_size * jax.device_count() def maybe_initialize_jax_distributed_system(raw_keys): diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index 1b22ce6a..7a0337d4 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -50,7 +50,13 @@ def vae_apply(images, sample_rng, vae, vae_params): def transform_images( - examples, image_column, image_resolution, rng, global_batch_size, pixel_ids_key="pixel_values", p_vae_apply=None + examples, + image_column, + image_resolution, + rng=None, + global_batch_size=None, + pixel_ids_key="pixel_values", + p_vae_apply=None, ): """Preprocess images to latents.""" images = list(examples[image_column]) @@ -210,7 +216,9 @@ def calculate_unet_tflops(config, pipeline, batch_size, rngs, train): def tokenize_captions(examples, caption_column, tokenizer, input_ids_key="input_ids", p_encode=None): """Tokenize captions for sd1.x,sd2.x models.""" captions = list(examples[caption_column]) - text_inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True) + text_inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="np" + ) if p_encode: encoder_hidden_states = p_encode(np.stack(text_inputs.input_ids)) @@ -222,13 +230,32 @@ def tokenize_captions(examples, caption_column, tokenizer, input_ids_key="input_ return examples +def tokenize_captions_xl(examples, caption_column, tokenizers, p_encode=None): + inputs = [] + captions = list(examples[caption_column]) + for _tokenizer in tokenizers: + text_inputs = _tokenizer( + captions, padding="max_length", max_length=_tokenizer.model_max_length, truncation=True, return_tensors="np" + ) + inputs.append(text_inputs.input_ids) + inputs = np.stack(inputs, axis=1) + + if p_encode: + prompt_embeds, text_embeds = p_encode(inputs) + # pyarrow dataset doesn't support bf16, so cast to float32. + examples["prompt_embeds"] = np.float32(prompt_embeds) + examples["text_embeds"] = np.float32(text_embeds) + examples["input_ids"] = inputs + return examples + + def get_shaped_batch(config, pipeline): """Return the shape of the batch - this is what eval_shape would return for the output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078. This function works with sd1.x and 2.x. """ vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - total_train_batch_size = config.per_device_batch_size * jax.device_count() + total_train_batch_size = config.total_train_batch_size if config.cache_latents_text_encoder_outputs: batch_image_shape = ( total_train_batch_size, @@ -253,3 +280,18 @@ def get_shaped_batch(config, pipeline): def encode(input_ids, text_encoder, text_encoder_params): return text_encoder(input_ids, params=text_encoder_params, train=False)[0] + + +def encode_xl(input_ids, text_encoders, text_encoder_params): + te_1_inputs = input_ids[:, 0, :] + te_2_inputs = input_ids[:, 1, :] + + prompt_embeds = text_encoders[0](te_1_inputs, params=text_encoder_params[0], output_hidden_states=True) + prompt_embeds = prompt_embeds["hidden_states"][-2] + + prompt_embeds_2_out = text_encoders[1](te_2_inputs, params=text_encoder_params[1], output_hidden_states=True) + prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2] + text_embeds = prompt_embeds_2_out["text_embeds"] + prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1) + + return prompt_embeds, text_embeds diff --git a/src/maxdiffusion/multihost_dataloading.py b/src/maxdiffusion/multihost_dataloading.py index 2de182e8..4be0ba8d 100644 --- a/src/maxdiffusion/multihost_dataloading.py +++ b/src/maxdiffusion/multihost_dataloading.py @@ -21,7 +21,8 @@ https://github.com/sholtodouglas/multihost_dataloading """ from functools import partial # pylint: disable=g-importing-member -from typing import Callable +from typing import Union +from collections.abc import Iterator, Iterable import tensorflow as tf # pylint: disable=g-import-not-at-top import time import numpy as np @@ -61,22 +62,7 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers) -def get_batch_sharded_data_pipeline(dataset: tf.data.Dataset, global_mesh: Mesh) -> Callable[[], jax.Array]: - """Each device loads batch_size/num_devices, - To do this, each host first loads batch_size/num_hosts, then shards that - equally across it's devices. - Args: - dataset: tf dataset over all files - Returns: - sharded_dataset: per_host dataset - """ - dataset = iter(dataset.as_numpy_iterator()) - multihost_generator = partial(get_next_batch_sharded, dataset, global_mesh) - - return multihost_generator - - -def get_next_batch_sharded(local_dataset: tf.data.Dataset, global_mesh: Mesh) -> jax.Array: +def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array: """Splits the host loaded data equally over all devices.""" SLEEP_TIME = 10 @@ -87,7 +73,7 @@ def get_next_batch_sharded(local_dataset: tf.data.Dataset, global_mesh: Mesh) -> while not loaded_data_success and data_load_attempts < MAX_DATA_LOAD_ATTEMPTS: data_load_attempts += 1 try: - local_data = local_dataset.next() + local_data = next(local_dataset) loaded_data_success = True except tf.errors.FailedPreconditionError: max_logging.log("Failed to get next data batch, retrying") @@ -100,3 +86,32 @@ def get_next_batch_sharded(local_dataset: tf.data.Dataset, global_mesh: Mesh) -> input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data) return input_gdas + + +class MultiHostDataLoadIterator: + """fold get_next_batch_sharded into a iterator class""" + + def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh): + self.global_mesh = global_mesh + self.dataloader = dataloader + if isinstance(self.dataloader, tf.data.Dataset): + self.local_iterator = self.dataloader.as_numpy_iterator() + elif isinstance(self.dataloader, Iterable): + self.local_iterator = iter(self.dataloader) + else: + raise ValueError("Type error: dataloader should be either tf.data.Dataset or Iterable.") + + def reset(self): + if isinstance(self.dataloader, tf.data.Dataset): + self.local_iterator = self.dataloader.as_numpy_iterator() + elif isinstance(self.dataloader, Iterable): + self.local_iterator = iter(self.dataloader) + else: + raise ValueError("Type error: dataloader should be either tf.data.Dataset or grain.DataLoader.") + + def __iter__(self): + self.reset() + return self + + def __next__(self): + return get_next_batch_sharded(self.local_iterator, self.global_mesh) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 89869f6a..f96ce837 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -103,9 +103,11 @@ def __init__(self, argv: list[str], **kwargs): _HyperParameters.user_init(raw_keys) self.keys = raw_keys + for k in sorted(raw_keys.keys()): + max_logging.log(f"Config param {k}: {raw_keys[k]}") def _load_kwargs(self, argv: list[str]): - args_dict = dict(a.split("=") for a in argv[2:]) + args_dict = dict(a.split("=", 1) for a in argv[2:]) return args_dict @staticmethod @@ -144,6 +146,11 @@ def user_init(raw_keys): raw_keys["dataset_name"] = max_utils.download_blobs(raw_keys["dataset_name"], raw_keys["dataset_save_location"]) raw_keys["dataset_save_location"] = raw_keys["dataset_name"] + if "hf_train_files" in raw_keys and not raw_keys["hf_train_files"]: + raw_keys["hf_train_files"] = None + + raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) + def get_num_target_devices(raw_keys): return len(jax.devices()) diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py index 5a646eb4..4c69505e 100644 --- a/src/maxdiffusion/tests/input_pipeline_interface_test.py +++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py @@ -31,8 +31,7 @@ from .. import pyconfig from .. import max_utils from maxdiffusion.input_pipeline.input_pipeline_interface import ( - make_laion400m_train_iterator, - make_pokemon_train_iterator, + make_data_iterator, make_dreambooth_train_iterator, ) @@ -42,8 +41,7 @@ from maxdiffusion import (FlaxStableDiffusionPipeline, FlaxStableDiffusionXLPipeline) from maxdiffusion.models import FlaxAutoencoderKL -from maxdiffusion.maxdiffusion_utils import (encode, tokenize_captions) -from maxdiffusion.trainers.sdxl_trainer import (encode_xl, tokenize_captions_xl) +from maxdiffusion.maxdiffusion_utils import (encode, tokenize_captions, encode_xl, tokenize_captions_xl) from maxdiffusion.maxdiffusion_utils import vae_apply, transform_images @@ -109,7 +107,7 @@ def test_make_dreambooth_train_iterator(self): config, mesh, global_batch_size, pipeline.tokenizer, pipeline.vae, params["vae"] ) - data = train_iterator() + data = next(train_iterator) device_count = jax.device_count() vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) @@ -135,13 +133,115 @@ def test_make_dreambooth_train_iterator(self): cleanup(instance_class_local_dir) cleanup(class_class_local_dir) - def test_make_pokemon_iterator_cache(self): + def test_make_pokemon_hf_iterator(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"), + "pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-2-base", + "dataset_name=diffusers/pokemon-gpt4-captions", + "from_pt=False", + "dataset_type=hf", + ], + unittest=True, + ) + config = pyconfig.config + + global_batch_size = config.per_device_batch_size * jax.device_count() + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + p_encode = None + p_vae_apply = None + rng = None + tokenize_fn = partial( + tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) + + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() + + assert data["input_ids"].shape == (device_count, 77) + assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) + + def test_make_pokemon_hf_iterator_sdxl(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_xl.yml"), + "pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0", + "per_device_batch_size=1", + "dataset_name=diffusers/pokemon-gpt4-captions", + "dataset_type=hf", + ], + unittest=True, + ) + config = pyconfig.config + + global_batch_size = config.per_device_batch_size * jax.device_count() + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + p_encode = None + p_vae_apply = None + rng = None + tokenize_fn = partial( + tokenize_captions_xl, + caption_column=config.caption_column, + tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], + p_encode=p_encode, + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) + + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() + + assert data["input_ids"].shape == (device_count, 2, 77) + assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) + + def test_make_pokemon_tf_iterator_cache(self): pyconfig.initialize( [ None, os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"), "cache_latents_text_encoder_outputs=True", "dataset_name=diffusers/pokemon-gpt4-captions", + "dataset_type=tf", ], unittest=True, ) @@ -178,8 +278,10 @@ def test_make_pokemon_iterator_cache(self): p_vae_apply=p_vae_apply, ) - train_iterator = make_pokemon_train_iterator(config, mesh, global_batch_size, tokenize_fn, image_transforms_fn) - data = train_iterator() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) device_count = jax.device_count() vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) @@ -202,6 +304,7 @@ def test_make_pokemon_iterator_no_cache(self): "tokenize_captions_num_proc=1", "transform_images_num_proc=1", "dataset_name=diffusers/pokemon-gpt4-captions", + "dataset_type=tf", ], unittest=True, ) @@ -238,8 +341,10 @@ def test_make_pokemon_iterator_no_cache(self): p_vae_apply=p_vae_apply, ) - train_iterator = make_pokemon_train_iterator(config, mesh, global_batch_size, tokenize_fn, image_transforms_fn) - data = train_iterator() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) device_count = jax.device_count() encoder_hidden_states = data["input_ids"] @@ -255,6 +360,7 @@ def test_make_pokemon_iterator_sdxl_cache(self): "cache_latents_text_encoder_outputs=True", "per_device_batch_size=1", "dataset_name=diffusers/pokemon-gpt4-captions", + "dataset_type=tf", ], unittest=True, ) @@ -300,8 +406,10 @@ def test_make_pokemon_iterator_sdxl_cache(self): p_vae_apply=p_vae_apply, ) - train_iterator = make_pokemon_train_iterator(config, mesh, global_batch_size, tokenize_fn, image_transforms_fn) - data = train_iterator() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) device_count = jax.device_count() vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) @@ -317,13 +425,78 @@ def test_make_pokemon_iterator_sdxl_cache(self): config.resolution // vae_scale_factor, ) - def test_make_laion_iterator(self): + def test_make_pokemon_tf_iterator_sdxl_no_cache(self): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_xl.yml"), + "pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0", + "cache_latents_text_encoder_outputs=False", + "per_device_batch_size=1", + "dataset_name=diffusers/pokemon-gpt4-captions", + "dataset_type=tf", + ], + unittest=True, + ) + config = pyconfig.config + + cleanup(config.dataset_save_location) + + global_batch_size = config.per_device_batch_size * jax.device_count() + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + rng = jax.random.PRNGKey(config.seed) + p_encode = None + p_vae_apply = None + if config.cache_latents_text_encoder_outputs: + p_encode = jax.jit( + partial( + encode_xl, + text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2], + text_encoder_params=[params["text_encoder"], params["text_encoder_2"]], + ) + ) + p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) + tokenize_fn = partial( + tokenize_captions_xl, + caption_column=config.caption_column, + tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], + p_encode=p_encode, + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) + + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() + + assert data["input_ids"].shape == (device_count, 2, 77) + assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) + + def test_make_laion_tfrecord_iterator(self): pyconfig.initialize( [ None, os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"), "cache_latents_text_encoder_outputs=True", "train_data_dir=gs://jfacevedo-maxdiffusion/laion400m/processed/laion400m_tfrec", + "dataset_type=tfrecord", ], unittest=True, ) @@ -341,12 +514,8 @@ def test_make_laion_iterator(self): from_pt=config.from_pt, ) - train_iterator = make_laion400m_train_iterator( - config, - mesh, - global_batch_size, - ) - data = train_iterator() + train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size) + data = next(train_iterator) device_count = jax.device_count() vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 42ba60d8..648c10d7 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -31,7 +31,7 @@ def load_next_batch(train_iter, example_batch, config): if config.reuse_example_batch and example_batch is not None: return example_batch else: - return train_iter() + return next(train_iter) def validate_train_config(config): diff --git a/src/maxdiffusion/trainers/dreambooth_trainer.py b/src/maxdiffusion/trainers/dreambooth_trainer.py index 2decc008..bf10dc65 100644 --- a/src/maxdiffusion/trainers/dreambooth_trainer.py +++ b/src/maxdiffusion/trainers/dreambooth_trainer.py @@ -74,7 +74,7 @@ def __init__(self, config): def get_shaped_batch(self, config, pipeline): vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - total_train_batch_size = config.per_device_batch_size * jax.device_count() + total_train_batch_size = config.total_train_batch_size batch_image_shape = ( total_train_batch_size, 4, diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index 03264223..88c9733f 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -25,7 +25,7 @@ from flax.linen import partitioning as nn_partitioning from maxdiffusion.trainers.stable_diffusion_trainer import (StableDiffusionTrainer) -from maxdiffusion.input_pipeline.input_pipeline_interface import (make_pokemon_train_iterator) +from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion import (max_utils, maxdiffusion_utils, max_logging) @@ -59,80 +59,92 @@ def get_shaped_batch(self, config, pipeline): output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078.""" vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - total_train_batch_size = config.per_device_batch_size * jax.device_count() + total_train_batch_size = config.total_train_batch_size + shaped_batch = {} - batch_latent_shape = ( - total_train_batch_size, - pipeline.unet.config.in_channels, - config.resolution // vae_scale_factor, - config.resolution // vae_scale_factor, - ) + if self.config.dataset_type == "tf" and self.config.cache_latents_text_encoder_outputs: + batch_image_shape = ( + total_train_batch_size, + pipeline.unet.config.in_channels, + config.resolution // vae_scale_factor, + config.resolution // vae_scale_factor, + ) - text_embeds_dim = pipeline.unet.config.projection_class_embeddings_input_dim - ( - 6 * pipeline.unet.config.addition_time_embed_dim - ) - text_embeds_shape = (total_train_batch_size, text_embeds_dim) - input_ids_shape = (total_train_batch_size, 2, pipeline.text_encoder.config.max_position_embeddings) - prompt_embeds_shape = (total_train_batch_size, pipeline.text_encoder.config.max_position_embeddings, 2048) + text_embeds_dim = pipeline.unet.config.projection_class_embeddings_input_dim - ( + 6 * pipeline.unet.config.addition_time_embed_dim + ) + text_embeds_shape = (total_train_batch_size, text_embeds_dim) + prompt_embeds_shape = (total_train_batch_size, pipeline.text_encoder.config.max_position_embeddings, 2048) + shaped_batch["prompt_embeds"] = jax.ShapeDtypeStruct(prompt_embeds_shape, jnp.float32) + shaped_batch["text_embeds"] = jax.ShapeDtypeStruct(text_embeds_shape, jnp.float32) + else: + batch_image_shape = (total_train_batch_size, 3, self.config.resolution, self.config.resolution) - shaped_batch = {} - shaped_batch["pixel_values"] = jax.ShapeDtypeStruct(batch_latent_shape, jnp.float32) - shaped_batch["prompt_embeds"] = jax.ShapeDtypeStruct(prompt_embeds_shape, jnp.float32) - shaped_batch["text_embeds"] = jax.ShapeDtypeStruct(text_embeds_shape, jnp.float32) + input_ids_shape = (total_train_batch_size, 2, pipeline.text_encoder.config.max_position_embeddings) + shaped_batch["pixel_values"] = jax.ShapeDtypeStruct(batch_image_shape, jnp.float32) shaped_batch["input_ids"] = jax.ShapeDtypeStruct(input_ids_shape, jnp.int32) return shaped_batch def get_data_shardings(self): data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - data_sharding = { - "input_ids": data_sharding, - "pixel_values": data_sharding, - "prompt_embeds": data_sharding, - "text_embeds": data_sharding, - } + if self.config.dataset_type == "tf" and self.config.cache_latents_text_encoder_outputs: + data_sharding = { + "input_ids": data_sharding, + "pixel_values": data_sharding, + "prompt_embeds": data_sharding, + "text_embeds": data_sharding, + } + else: + data_sharding = { + "input_ids": data_sharding, + "pixel_values": data_sharding, + } return data_sharding def load_dataset(self, pipeline, params, train_states): config = self.config total_train_batch_size = self.total_train_batch_size mesh = self.mesh - - # ideally : diffusers/pokemon-gpt4-captions, but if loading from gcs, make sure the folder has pokemon in the name. - if "pokemon" in self.config.dataset_name: - p_encode = None - p_vae_apply = None - if config.cache_latents_text_encoder_outputs: - p_encode = jax.jit( - partial( - encode_xl, - text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2], - text_encoder_params=[train_states["text_encoder_state"].params, train_states["text_encoder_2_state"].params], - ) - ) - p_vae_apply = jax.jit( - partial(maxdiffusion_utils.vae_apply, vae=pipeline.vae, vae_params=train_states["vae_state"].params) - ) - else: - raise ValueError("cache_latents_text_encoder_outputs = False currently not supported!") - - tokenize_fn = partial( - tokenize_captions_xl, - caption_column=config.caption_column, - tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], - p_encode=p_encode, + p_encode = None + p_vae_apply = None + rng = None + if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs: + p_encode = jax.jit( + partial( + maxdiffusion_utils.encode_xl, + text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2], + text_encoder_params=[train_states["text_encoder_state"].params, train_states["text_encoder_2_state"].params], + ) ) - image_transforms_fn = partial( - maxdiffusion_utils.transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=self.rng, - global_batch_size=total_train_batch_size, - p_vae_apply=p_vae_apply, + p_vae_apply = jax.jit( + partial(maxdiffusion_utils.vae_apply, vae=pipeline.vae, vae_params=train_states["vae_state"].params) ) + rng = self.rng - data_iterator = make_pokemon_train_iterator(config, mesh, total_train_batch_size, tokenize_fn, image_transforms_fn) - else: - raise ValueError(f"{config.dataset_name} is currently not supported in this pipeline.") + tokenize_fn = partial( + maxdiffusion_utils.tokenize_captions_xl, + caption_column=config.caption_column, + tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], + p_encode=p_encode, + ) + image_transforms_fn = partial( + maxdiffusion_utils.transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=total_train_batch_size, + p_vae_apply=p_vae_apply, + ) + + data_iterator = make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + total_train_batch_size, + tokenize_fn=tokenize_fn, + image_transforms_fn=image_transforms_fn, + ) return data_iterator @@ -142,14 +154,28 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): p_train_step = jax.jit( partial(_train_step, pipeline=pipeline, params=params, config=self.config), - in_shardings=(state_shardings["unet_state_shardings"], data_shardings, None), + in_shardings=( + state_shardings["unet_state_shardings"], + state_shardings["vae_state_shardings"], + None, + None, + data_shardings, + None, + ), out_shardings=(state_shardings["unet_state_shardings"], None, None), donate_argnums=(0,), ) max_logging.log("Precompiling...") s = time.time() dummy_batch = self.get_shaped_batch(self.config, pipeline) - p_train_step = p_train_step.lower(train_states["unet_state"], dummy_batch, train_rngs) + p_train_step = p_train_step.lower( + train_states["unet_state"], + train_states["vae_state"], + train_states["text_encoder_state"], + train_states["text_encoder_2_state"], + dummy_batch, + train_rngs, + ) p_train_step = p_train_step.compile() max_logging.log(f"Compile time: {(time.time() - s )}") return p_train_step @@ -158,6 +184,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera writer = max_utils.initialize_summary_writer(self.config) unet_state = train_states["unet_state"] + vae_state = train_states["vae_state"] text_encoder_state = train_states["text_encoder_state"] text_encoder_2_state = train_states["text_encoder_2_state"] num_model_parameters = max_utils.calculate_num_params_from_pytree(unet_state.params) @@ -187,8 +214,15 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera _, train_rngs = jax.random.split(self.rng) for step in np.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + example_batch = load_next_batch(data_iterator, example_batch, self.config) - (unet_state, train_metric, train_rngs) = p_train_step(unet_state, example_batch, train_rngs) + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + (unet_state, train_metric, train_rngs) = p_train_step( + unet_state, vae_state, text_encoder_state, text_encoder_2_state, example_batch, train_rngs + ) samples_count = self.total_train_batch_size * (step + 1) new_time = datetime.datetime.now() @@ -199,28 +233,30 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera if self.config.write_metrics: write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) last_step_completion = new_time - if step == first_profiling_step: - max_utils.activate_profiler(self.config) - if step == last_profiling_step: - max_utils.deactivate_profiler(self.config) if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0: train_states["unet_state"] = unet_state + train_states["vae_state"] = vae_state train_states["text_encoder_state"] = text_encoder_state train_states["text_encoder_2_state"] = text_encoder_2_state self.save_checkpoint(step, pipeline, params, train_states) + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) + if self.config.write_metrics: - write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + write_metrics( + writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config + ) train_states["unet_state"] = unet_state train_states["text_encoder_state"] = text_encoder_state train_states["text_encoder_2_state"] = text_encoder_2_state - self.save_checkpoint(step, pipeline, params, train_states) + self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states) self.checkpoint_manager.wait_until_finished() -def _train_step(unet_state, batch, train_rng, pipeline, params, config): +def _train_step(unet_state, vae_state, text_encoder_state, text_encoder_2_state, batch, train_rng, pipeline, params, config): _, gen_dummy_rng = jax.random.split(train_rng) sample_rng, timestep_bias_rng, new_train_rng = jax.random.split(gen_dummy_rng, 3) @@ -230,12 +266,22 @@ def _train_step(unet_state, batch, train_rng, pipeline, params, config): state_params = {"unet": unet_state.params} def compute_loss(state_params): - if config.cache_latents_text_encoder_outputs: + if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs: latents = batch["pixel_values"] prompt_embeds = batch["prompt_embeds"] text_embeds = batch["text_embeds"] else: - raise ValueError("cache_latents_text_encoder_outputs = False currently not supported!") + latents = maxdiffusion_utils.vae_apply( + images=batch["pixel_values"], + sample_rng=sample_rng, + vae=pipeline.vae, + vae_params=vae_state.params, + ) + prompt_embeds, text_embeds = maxdiffusion_utils.encode_xl( + input_ids=batch["input_ids"], + text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2], + text_encoder_params=[text_encoder_state.params, text_encoder_2_state.params], + ) # Sample noise that we'll add to the latents noise_rng, timestep_rng = jax.random.split(sample_rng) @@ -308,37 +354,3 @@ def compute_loss(state_params): metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} return new_state, metrics, new_train_rng - - -def encode_xl(input_ids, text_encoders, text_encoder_params): - te_1_inputs = input_ids[:, 0, :] - te_2_inputs = input_ids[:, 1, :] - - prompt_embeds = text_encoders[0](te_1_inputs, params=text_encoder_params[0], output_hidden_states=True) - prompt_embeds = prompt_embeds["hidden_states"][-2] - - prompt_embeds_2_out = text_encoders[1](te_2_inputs, params=text_encoder_params[1], output_hidden_states=True) - prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2] - text_embeds = prompt_embeds_2_out["text_embeds"] - prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1) - - return prompt_embeds, text_embeds - - -def tokenize_captions_xl(examples, caption_column, tokenizers, p_encode=None): - inputs = [] - captions = list(examples[caption_column]) - for _tokenizer in tokenizers: - text_inputs = _tokenizer( - captions, padding="max_length", max_length=_tokenizer.model_max_length, truncation=True, return_tensors="np" - ) - inputs.append(text_inputs.input_ids) - inputs = np.stack(inputs, axis=1) - - if p_encode: - prompt_embeds, text_embeds = p_encode(inputs) - # pyarrow dataset doesn't support bf16, so cast to float32. - examples["prompt_embeds"] = np.float32(prompt_embeds) - examples["text_embeds"] = np.float32(text_embeds) - examples["input_ids"] = inputs - return examples diff --git a/src/maxdiffusion/trainers/stable_diffusion_trainer.py b/src/maxdiffusion/trainers/stable_diffusion_trainer.py index 01c497db..c46ea38e 100644 --- a/src/maxdiffusion/trainers/stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/stable_diffusion_trainer.py @@ -28,7 +28,7 @@ from maxdiffusion import (FlaxDDPMScheduler, maxdiffusion_utils, train_utils, max_utils, max_logging) -from maxdiffusion.input_pipeline.input_pipeline_interface import (make_pokemon_train_iterator, make_laion400m_train_iterator) +from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion.checkpointing.base_stable_diffusion_checkpointer import (STABLE_DIFFUSION_CHECKPOINT) @@ -53,7 +53,7 @@ def get_shaped_batch(self, config, pipeline): """ vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) total_train_batch_size = self.total_train_batch_size - if config.cache_latents_text_encoder_outputs: + if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs: batch_image_shape = ( total_train_batch_size, 4, @@ -89,41 +89,45 @@ def get_data_shardings(self): return data_sharding def load_dataset(self, pipeline, params, train_states): - # ideally : diffusers/pokemon-gpt4-captions, but if loading from gcs, make sure the folder has pokemon in the name. - if "pokemon" in self.config.dataset_name: - p_encode = None - p_vae_apply = None - if self.config.cache_latents_text_encoder_outputs: - p_encode = jax.jit( - partial( - maxdiffusion_utils.encode, - text_encoder=pipeline.text_encoder, - text_encoder_params=train_states["text_encoder_state"].params, - ) - ) - p_vae_apply = jax.jit( - partial(maxdiffusion_utils.vae_apply, vae=pipeline.vae, vae_params=train_states["vae_state"].params) - ) - tokenize_fn = partial( - maxdiffusion_utils.tokenize_captions, - caption_column=self.config.caption_column, - tokenizer=pipeline.tokenizer, - p_encode=p_encode, - ) - image_transforms_fn = partial( - maxdiffusion_utils.transform_images, - image_column=self.config.image_column, - image_resolution=self.config.resolution, - rng=self.rng, - global_batch_size=self.total_train_batch_size, - p_vae_apply=p_vae_apply, + p_encode = None + p_vae_apply = None + rng = None + if self.config.dataset_type == "tf" and self.config.cache_latents_text_encoder_outputs: + p_encode = jax.jit( + partial( + maxdiffusion_utils.encode, + text_encoder=pipeline.text_encoder, + text_encoder_params=train_states["text_encoder_state"].params, + ) ) - data_iterator = make_pokemon_train_iterator( - self.config, self.mesh, self.total_train_batch_size, tokenize_fn, image_transforms_fn + p_vae_apply = jax.jit( + partial(maxdiffusion_utils.vae_apply, vae=pipeline.vae, vae_params=train_states["vae_state"].params) ) - else: - data_iterator = make_laion400m_train_iterator(self.config, self.mesh, self.total_train_batch_size) + rng = self.rng + tokenize_fn = partial( + maxdiffusion_utils.tokenize_captions, + caption_column=self.config.caption_column, + tokenizer=pipeline.tokenizer, + p_encode=p_encode, + ) + image_transforms_fn = partial( + maxdiffusion_utils.transform_images, + image_column=self.config.image_column, + image_resolution=self.config.resolution, + rng=rng, + global_batch_size=self.total_train_batch_size, + p_vae_apply=p_vae_apply, + ) + data_iterator = make_data_iterator( + self.config, + jax.process_index(), + jax.process_count(), + self.mesh, + self.total_train_batch_size, + tokenize_fn=tokenize_fn, + image_transforms_fn=image_transforms_fn, + ) return data_iterator def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings): @@ -185,10 +189,15 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera _, train_rngs = jax.random.split(self.rng) for step in np.arange(start_step, self.config.max_train_steps): + if self.config.enable_profiler and step == first_profiling_step: + max_utils.activate_profiler(self.config) + example_batch = train_utils.load_next_batch(data_iterator, example_batch, self.config) - unet_state, text_encoder_state, train_metric, train_rngs = p_train_step( - unet_state, vae_state, text_encoder_state, example_batch, train_rngs - ) + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + unet_state, text_encoder_state, train_metric, train_rngs = p_train_step( + unet_state, vae_state, text_encoder_state, example_batch, train_rngs + ) samples_count = self.total_train_batch_size * (step + 1) new_time = datetime.datetime.now() @@ -198,10 +207,6 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera if self.config.write_metrics: train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) last_step_completion = new_time - if step == first_profiling_step: - max_utils.activate_profiler(self.config) - if step == last_profiling_step: - max_utils.deactivate_profiler(self.config) if step != 0 and self.config.checkpoint_every != -1 and samples_count % self.config.checkpoint_every == 0: train_states["unet_state"] = unet_state @@ -209,14 +214,19 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera train_states["text_encoder"] = text_encoder_state self.save_checkpoint(step, pipeline, params, train_states) + if self.config.enable_profiler and step == last_profiling_step: + max_utils.deactivate_profiler(self.config) + if self.config.write_metrics: - train_utils.write_metrics(writer, local_metrics_file, running_gcs_metrics, train_metric, step, self.config) + train_utils.write_metrics( + writer, local_metrics_file, running_gcs_metrics, train_metric, self.config.max_train_steps - 1, self.config + ) train_states["unet_state"] = unet_state train_states["vae_state"] = vae_state train_states["text_encoder"] = text_encoder_state # save the inference states of the last checkpoint so they can be easily loaded during gen. - self.save_checkpoint(step, pipeline, params, train_states) + self.save_checkpoint(self.config.max_train_steps - 1, pipeline, params, train_states) self.checkpoint_manager.wait_until_finished() @@ -231,7 +241,7 @@ def _train_step(unet_state, vae_state, text_encoder_state, batch, train_rng, pip def compute_loss(state_params): - if config.cache_latents_text_encoder_outputs: + if config.dataset_type == "tf" and config.cache_latents_text_encoder_outputs: latents = batch["pixel_values"] encoder_hidden_states = batch["input_ids"] else: @@ -250,7 +260,9 @@ def compute_loss(state_params): batch["input_ids"], pipeline.text_encoder, state_params["text_encoder"] ) else: - encoder_hidden_states = maxdiffusion_utils.encode(batch["input_ids"], pipeline.text_encoder, params["text_encoder"]) + encoder_hidden_states = maxdiffusion_utils.encode( + batch["input_ids"], pipeline.text_encoder, text_encoder_state.params + ) # Sample noise that we'll add to the latents noise_rng, timestep_rng = jax.random.split(sample_rng)