Skip to content

Commit

Permalink
add single replica restore with broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
ssusie committed Oct 3, 2024
1 parent 33bb598 commit ce47b67
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 60 deletions.
12 changes: 2 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,10 @@ MaxDiffusion started as a fork of [Diffusers](https://github.com/huggingface/dif

Whether you are forking MaxDiffusion for your own needs or intending to contribute back to the community, a full suite of tests can be found in `tests` and `src/maxdiffusion/tests`.

To run unit tests, simply run:
To run unit tests and lint, simply run:
```
python -m pytest
ruff check --fix .
```

This project uses `pylint` and `pyink` to enforce code style. Before submitting a pull request, please ensure your code passes these checks by running:

```
bash code_style.sh
```

This script will automatically format your code with `pyink` and help you identify any remaining style issues.


The full suite of end-to-end tests is in `tests` and `src/maxdiffusion/tests`. We run them on a nightly cadence.
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,12 @@ def config_to_json(model_or_config):
"scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)),
}

items["unet_state"] = ocp.args.StandardSave(train_states["unet_state"])
items["vae_state"] = ocp.args.StandardSave(train_states["vae_state"])
items["text_encoder_state"] = ocp.args.StandardSave(train_states["text_encoder_state"])
items["unet_state"] = ocp.args.PyTreeSave(train_states["unet_state"])
items["vae_state"] = ocp.args.PyTreeSave(train_states["vae_state"])
items["text_encoder_state"] = ocp.args.PyTreeSave(train_states["text_encoder_state"])

if hasattr(pipeline, "text_encoder_2"):
items["text_encoder_2_state"] = ocp.args.StandardSave(train_states["text_encoder_2_state"])
items["text_encoder_2_state"] = ocp.args.PyTreeSave(train_states["text_encoder_2_state"])
items["text_encoder_2_config"] = ocp.args.JsonSave(config_to_json(pipeline.text_encoder_2.config))

tokenizer_config = {"path": self.config.tokenizer_model_name_or_path}
Expand Down
86 changes: 78 additions & 8 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

from typing import Optional, Any
import jax
import numpy as np
import os

import orbax.checkpoint
from maxdiffusion import max_logging
from etils import epath
from flax.training import train_state
Expand Down Expand Up @@ -90,7 +94,7 @@ def load_stable_diffusion_configs(
):
f"""
Loads Orbax configurations for different stable diffusion models
Args:
checkpoint_manager (`orbax.checkpoint.checkpoint_manager`)
checkpoint_type (`str`) : use sd or sdxl
Expand Down Expand Up @@ -140,8 +144,37 @@ def load_params_from_path(
return restored["params"]


def _find_idx(array: np.ndarray, replica_axis_idx: int) -> int:
  """Returns the index along the given axis that the current host belongs to.

  Scans `array` in `np.ndenumerate` order and stops at the first element whose
  `process_index` equals `jax.process_index()`.

  Args:
    array: array of device-like objects exposing a `process_index` attribute.
    replica_axis_idx: axis of `array` whose coordinate should be returned.

  Returns:
    The coordinate along `replica_axis_idx` of the first matching element.

  Raises:
    ValueError: if no element of `array` belongs to the current process
      (including when `array` is empty). Previously this case silently
      returned the coordinate of the last element, or failed with a
      TypeError on an empty array.
  """
  # The process index does not change during the scan; look it up once.
  current_process = jax.process_index()
  for position, device in np.ndenumerate(array):
    if device.process_index == current_process:
      return position[replica_axis_idx]
  raise ValueError(
      f"No device belonging to process {current_process} was found in the device array."
  )


def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
  """Selects the devices of the replica that contains the current host.

  Args:
    device_array: array of mesh devices (for example, `mesh.devices`).
    replica_axis_idx: axis of `device_array` along which the replica is taken.

  Returns:
    The slice of `device_array` at the current host's position along
    `replica_axis_idx`, with that axis kept as a length-1 dimension so the
    result has the same rank as `device_array`.
  """
  host_replica = _find_idx(device_array, replica_axis_idx)
  # Taking with a one-element index list keeps the replica axis (length 1),
  # so no separate expand_dims step is needed.
  return np.take(device_array, [host_replica], axis=replica_axis_idx)


def load_state_if_possible(
checkpoint_manager: CheckpointManager, abstract_unboxed_pre_state: train_state.TrainState, checkpoint_item: str
checkpoint_manager: CheckpointManager,
abstract_unboxed_pre_state: train_state.TrainState,
checkpoint_item: str,
enable_single_replica_ckpt_restoring: bool,
):
"""Loads TrainState as possible from the inputs.
Expand All @@ -151,6 +184,8 @@ def load_state_if_possible(
abstract_unboxed_pre_state: an unboxed, abstract TrainState that Orbax
matches type against.
checkpoint_item: the name of the checkpoint item that is being loaded. Ex: vae_state
enable_single_replica_ckpt_restoring: bool flag for restoring checkpoints
with SingleReplicaArrayHandler
Returns:
A tuple of (train_state, train_state_params) where full_train_state captures
Expand All @@ -167,9 +202,44 @@ def load_state_if_possible(
return None
else:
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
try:
item = {checkpoint_item: orbax.checkpoint.args.StandardRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
except:
max_logging.log(f"could not load {checkpoint_item} from orbax")
return None
# try:
if True:
if not enable_single_replica_ckpt_restoring:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))

def map_to_pspec(data):
pspec = data.sharding.spec
mesh = data.sharding.mesh
if not enable_single_replica_ckpt_restoring:
return ocp.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec)
replica_axis_index = 0
replica_devices = _replica_devices(mesh.devices, replica_axis_index)
replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)

return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
sharding=jax.sharding.NamedSharding(mesh, pspec),
single_replica_sharding=single_replica_sharding,
global_shape=data.shape,
dtype=data.dtype,
)

array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(
map_to_pspec,
abstract_unboxed_pre_state,
)
item = {checkpoint_item: ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)}
return checkpoint_manager.restore(
latest_step,
args=orbax.checkpoint.args.Composite(**item)
)
# except:
# max_logging.log(f"could not load {checkpoint_item} from orbax")
# return None
18 changes: 10 additions & 8 deletions src/maxdiffusion/configs/base14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ timestep_bias: {
begin: 0,
# when using strategy=range, the final step (inclusive) to bias.
end: 1000,
# portion of timesteps to bias.
# portion of timesteps to bias.
# 0.5 will bias one half of the timesteps. Value of strategy determines
# whether the biased portions are in the earlier or later timesteps.
portion: 0.25
Expand All @@ -75,7 +75,7 @@ timestep_bias: {
# Override parameters from the checkpoint's scheduler.
diffusion_scheduler_config: {
_class_name: '',
# values are v_prediction or leave empty to use scheduler's default.
# values are v_prediction or leave empty to use scheduler's default.
prediction_type: '',
rescale_zero_terminal_snr: False,
timestep_spacing: ''
Expand All @@ -87,12 +87,12 @@ base_output_directory: ""
mesh_axes: ['data', 'fsdp', 'tensor']

# batch : batch dimension of data and activations
# hidden :
# hidden :
# embed : attention qkv dense layer hidden dim named as embed
# heads : attention head dim = num_heads * head_dim
# length : attention sequence length
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1 : conv.shape[0] weight
# keep_2 : conv.shape[1] weight
Expand All @@ -118,7 +118,7 @@ data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

Expand All @@ -144,6 +144,8 @@ enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
Expand All @@ -165,7 +167,7 @@ per_device_batch_size: 1
warmup_steps_fraction: 0.0
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.

# AdamW optimizer parameters
Expand Down Expand Up @@ -205,4 +207,4 @@ class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
cache_dreambooth_dataset: False
18 changes: 10 additions & 8 deletions src/maxdiffusion/configs/base21.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ timestep_bias: {
begin: 0,
# when using strategy=range, the final step (inclusive) to bias.
end: 1000,
# portion of timesteps to bias.
# portion of timesteps to bias.
# 0.5 will bias one half of the timesteps. Value of strategy determines
# whether the biased portions are in the earlier or later timesteps.
portion: 0.25
Expand All @@ -75,7 +75,7 @@ timestep_bias: {
# Override parameters from the checkpoint's scheduler.
diffusion_scheduler_config: {
_class_name: '',
# values are v_prediction or leave empty to use scheduler's default.
# values are v_prediction or leave empty to use scheduler's default.
prediction_type: '',
rescale_zero_terminal_snr: False,
timestep_spacing: ''
Expand All @@ -89,12 +89,12 @@ base_output_directory: ""
mesh_axes: ['data', 'fsdp', 'tensor']

# batch : batch dimension of data and activations
# hidden :
# hidden :
# embed : attention qkv dense layer hidden dim named as embed
# heads : attention head dim = num_heads * head_dim
# length : attention sequence length
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1 : conv.shape[0] weight
# keep_2 : conv.shape[1] weight
Expand All @@ -120,7 +120,7 @@ data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

Expand All @@ -146,6 +146,8 @@ enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
Expand All @@ -165,7 +167,7 @@ per_device_batch_size: 1
warmup_steps_fraction: 0.0
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.

# AdamW optimizer parameters
Expand Down Expand Up @@ -201,4 +203,4 @@ class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
cache_dreambooth_dataset: False
18 changes: 10 additions & 8 deletions src/maxdiffusion/configs/base_2_base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ timestep_bias: {
begin: 0,
# when using strategy=range, the final step (inclusive) to bias.
end: 1000,
# portion of timesteps to bias.
# portion of timesteps to bias.
# 0.5 will bias one half of the timesteps. Value of strategy determines
# whether the biased portions are in the earlier or later timesteps.
portion: 0.25
Expand All @@ -88,7 +88,7 @@ timestep_bias: {
# Override parameters from the checkpoint's scheduler.
diffusion_scheduler_config: {
_class_name: '',
# values are v_prediction or leave empty to use scheduler's default.
# values are v_prediction or leave empty to use scheduler's default.
prediction_type: '',
rescale_zero_terminal_snr: False,
timestep_spacing: ''
Expand All @@ -102,12 +102,12 @@ base_output_directory: ""
mesh_axes: ['data', 'fsdp', 'tensor']

# batch : batch dimension of data and activations
# hidden :
# hidden :
# embed : attention qkv dense layer hidden dim named as embed
# heads : attention head dim = num_heads * head_dim
# length : attention sequence length
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# temb_in : dense.shape[0] of resnet dense before conv
# out_c : dense.shape[1] of resnet dense before conv
# out_channels : conv.shape[-1] activation
# keep_1 : conv.shape[0] weight
# keep_2 : conv.shape[1] weight
Expand All @@ -133,7 +133,7 @@ data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

Expand All @@ -159,6 +159,8 @@ enable_data_shuffling: True

# checkpoint every number of samples, -1 means don't checkpoint.
checkpoint_every: -1
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

# Prepare image latents and text encoder outputs
# during dataset creation to reduce memory consumption.
Expand All @@ -178,7 +180,7 @@ per_device_batch_size: 1
warmup_steps_fraction: 0.0
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.

# AdamW optimizer parameters
Expand Down Expand Up @@ -218,4 +220,4 @@ class_prompt: ''
prior_loss_weight: 1.0
num_class_images: 100
# If true, set dataset_save_location.
cache_dreambooth_dataset: False
cache_dreambooth_dataset: False
Loading

0 comments on commit ce47b67

Please sign in to comment.