From 8d53ef6b2915ca48f8a8c5bbdccee7ded0354e97 Mon Sep 17 00:00:00 2001
From: ssusie
Date: Thu, 3 Oct 2024 17:56:26 +0000
Subject: [PATCH] add single replica restore and broadcast

---
 README.md                                      | 12 +--
 .../base_stable_diffusion_checkpointer.py      |  8 +-
 .../checkpointing/checkpointing_utils.py       | 81 +++++++++++++++++--
 src/maxdiffusion/configs/base14.yml            | 18 +++--
 src/maxdiffusion/configs/base21.yml            | 18 +++--
 src/maxdiffusion/configs/base_2_base.yml       | 18 +++--
 src/maxdiffusion/configs/base_xl.yml           | 16 ++--
 .../configs/base_xl_lightning.yml              | 15 ++--
 src/maxdiffusion/max_utils.py                  |  7 +-
 9 files changed, 134 insertions(+), 59 deletions(-)

diff --git a/README.md b/README.md
index 84be41de..59456b13 100644
--- a/README.md
+++ b/README.md
@@ -178,18 +178,10 @@ MaxDiffusion started as a fork of [Diffusers](https://github.com/huggingface/dif
 Whether you are forking MaxDiffusion for your own needs or intending to contribute back to the community, a full suite of tests can be found in `tests` and `src/maxdiffusion/tests`.
 
-To run unit tests, simply run:
+To run unit tests and lint, simply run:
 
 ```
 python -m pytest
+ruff check --fix .
 ```
 
-This project uses `pylint` and `pyink` to enforce code style. Before submitting a pull request, please ensure your code passes these checks by running:
-
-```
-bash code_style.sh
-```
-
-This script will automatically format your code with `pyink` and help you identify any remaining style issues.
-
-
 The full suite of
-end-to end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadance.
diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
index d297832e..b9064b2a 100644
--- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py
@@ -224,12 +224,12 @@ def config_to_json(model_or_config):
       "scheduler_config": ocp.args.JsonSave(config_to_json(pipeline.scheduler)),
   }
 
-  items["unet_state"] = ocp.args.StandardSave(train_states["unet_state"])
-  items["vae_state"] = ocp.args.StandardSave(train_states["vae_state"])
-  items["text_encoder_state"] = ocp.args.StandardSave(train_states["text_encoder_state"])
+  items["unet_state"] = ocp.args.PyTreeSave(train_states["unet_state"])
+  items["vae_state"] = ocp.args.PyTreeSave(train_states["vae_state"])
+  items["text_encoder_state"] = ocp.args.PyTreeSave(train_states["text_encoder_state"])
 
   if hasattr(pipeline, "text_encoder_2"):
-    items["text_encoder_2_state"] = ocp.args.StandardSave(train_states["text_encoder_2_state"])
+    items["text_encoder_2_state"] = ocp.args.PyTreeSave(train_states["text_encoder_2_state"])
     items["text_encoder_2_config"] = ocp.args.JsonSave(config_to_json(pipeline.text_encoder_2.config))
 
   tokenizer_config = {"path": self.config.tokenizer_model_name_or_path}
diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py
index 074d642b..65b49c8d 100644
--- a/src/maxdiffusion/checkpointing/checkpointing_utils.py
+++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -18,7 +18,11 @@
 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
 
 from typing import Optional, Any
+import jax
+import numpy as np
 import os
+
+import orbax.checkpoint
 from maxdiffusion import max_logging
 from etils import epath
 from flax.training import train_state
@@ -90,7 +94,7 @@ def load_stable_diffusion_configs(
 ):
   f"""
   Loads Orbax configurations for different stable diffusion models
- 
+
   Args:
     checkpoint_manager (`orbax.checkpoint.checkpoint_manager`)
     checkpoint_type (`str`) : use sd or sdxl
@@ -140,8 +144,37 @@ def load_params_from_path(
   return restored["params"]
 
 
+def _find_idx(array: np.ndarray, replica_axis_idx: int):
+  """Returns the index along given dimension that the current host belongs to."""
+  idx = None
+  for idx, val in np.ndenumerate(array):
+    if val.process_index == jax.process_index():
+      break
+  return idx[replica_axis_idx]
+
+
+def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
+  """Returns the devices from the replica that current host belongs to.
+
+  Replicas are assumed to be restricted to the first axis.
+
+  Args:
+    device_array: devices of the mesh that can be obtained by mesh.devices()
+    replica_axis_idx: axis dimension along which replica is taken
+
+  Returns:
+    devices inside the replica that current host is in
+  """
+  idx = _find_idx(device_array, replica_axis_idx)
+  replica_result = np.take(device_array, idx, axis=replica_axis_idx)
+  return np.expand_dims(replica_result, axis=replica_axis_idx)
+
+
 def load_state_if_possible(
-    checkpoint_manager: CheckpointManager, abstract_unboxed_pre_state: train_state.TrainState, checkpoint_item: str
+    checkpoint_manager: CheckpointManager,
+    abstract_unboxed_pre_state: train_state.TrainState,
+    checkpoint_item: str,
+    enable_single_replica_ckpt_restoring: bool,
 ):
   """Loads TrainState as possible from the inputs.
 
@@ -151,6 +184,8 @@
     abstract_unboxed_pre_state: an unboxed, abstract TrainState that Orbax
       matches type against.
     checkpoint_item: the name of the checkpoint item that is being loaded. Ex: vae_state
+    enable_single_replica_ckpt_restoring: bool flag for restoring checkpoints
+      with SingleReplicaArrayHandler
 
   Returns:
     A tuple of (train_state, train_state_params) where full_train_state captures
@@ -167,9 +202,41 @@
     return None
   else:
     max_logging.log(f"restoring from this run's directory latest step {latest_step}")
-    try:
-      item = {checkpoint_item: orbax.checkpoint.args.StandardRestore(item=abstract_unboxed_pre_state)}
+    # try:
+    if True:
+      if not enable_single_replica_ckpt_restoring:
+        item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
+        return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
+
+      def map_to_pspec(data):
+        pspec = data.sharding.spec
+        mesh = data.sharding.mesh
+        if not enable_single_replica_ckpt_restoring:
+          return ocp.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec)
+        replica_axis_index = 0
+        replica_devices = _replica_devices(mesh.devices, replica_axis_index)
+        replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
+        single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)
+
+        return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
+            sharding=jax.sharding.NamedSharding(mesh, pspec),
+            single_replica_sharding=single_replica_sharding,
+            global_shape=data.shape,
+            dtype=data.dtype,
+        )
+
+      array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
+          replica_axis_index=0,
+          broadcast_memory_limit_bytes=1024 * 1024 * 1000,  # 1000 MB limit
+      )
+      ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
+
+      restore_args = jax.tree_util.tree_map(
+          map_to_pspec,
+          abstract_unboxed_pre_state,
+      )
+      item = {checkpoint_item: ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)}
       return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
-    except:
-      max_logging.log(f"could not load {checkpoint_item} from orbax")
-      return None
+    # except:
+    #   max_logging.log(f"could not load {checkpoint_item} from orbax")
+    #   return None
diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml
index b5f938c1..0909a75b 100644
--- a/src/maxdiffusion/configs/base14.yml
+++ b/src/maxdiffusion/configs/base14.yml
@@ -66,7 +66,7 @@ timestep_bias: {
   begin: 0,
   # when using strategy=range, the final step (inclusive) to bias.
   end: 1000,
-  # portion of timesteps to bias. 
+  # portion of timesteps to bias.
   # 0.5 will bias one half of the timesteps. Value of strategy determines
   # whether the biased portions are in the earlier or later timesteps.
   portion: 0.25
@@ -75,7 +75,7 @@ timestep_bias: {
 # Override parameters from checkpoints's scheduler.
 diffusion_scheduler_config: {
   _class_name: '',
-  # values are v_prediction or leave empty to use scheduler's default. 
+  # values are v_prediction or leave empty to use scheduler's default.
   prediction_type: '',
   rescale_zero_terminal_snr: False,
   timestep_spacing: ''
@@ -87,12 +87,12 @@ base_output_directory: ""
 mesh_axes: ['data', 'fsdp', 'tensor']
 
 # batch : batch dimension of data and activations
-# hidden : 
+# hidden :
 # embed : attention qkv dense layer hidden dim named as embed
 # heads : attention head dim = num_heads * head_dim
 # length : attention sequence length
-# temb_in : dense.shape[0] of resnet dense before conv 
-# out_c : dense.shape[1] of resnet dense before conv 
+# temb_in : dense.shape[0] of resnet dense before conv
+# out_c : dense.shape[1] of resnet dense before conv
 # out_channels : conv.shape[-1] activation
 # keep_1 : conv.shape[0] weight
 # keep_2 : conv.shape[1] weight
@@ -118,7 +118,7 @@ data_sharding: [['data', 'fsdp', 'tensor']]
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
 dcn_tensor_parallelism: 1
-ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e 
+ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
 
@@ -144,6 +144,8 @@ enable_data_shuffling: True
 
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False
 
 # Prepare image latents and text encoder outputs
 # during dataset creation to reduce memory consumption.
@@ -165,7 +167,7 @@ per_device_batch_size: 1
 
 warmup_steps_fraction: 0.0
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
-# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before 
+# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
 
 # AdamW optimizer parameters
@@ -205,4 +207,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
\ No newline at end of file
+cache_dreambooth_dataset: False
diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml
index c8a154a3..6ab2cc30 100644
--- a/src/maxdiffusion/configs/base21.yml
+++ b/src/maxdiffusion/configs/base21.yml
@@ -66,7 +66,7 @@ timestep_bias: {
   begin: 0,
   # when using strategy=range, the final step (inclusive) to bias.
   end: 1000,
-  # portion of timesteps to bias. 
+  # portion of timesteps to bias.
   # 0.5 will bias one half of the timesteps. Value of strategy determines
   # whether the biased portions are in the earlier or later timesteps.
   portion: 0.25
@@ -75,7 +75,7 @@ timestep_bias: {
 # Override parameters from checkpoints's scheduler.
 diffusion_scheduler_config: {
   _class_name: '',
-  # values are v_prediction or leave empty to use scheduler's default. 
+  # values are v_prediction or leave empty to use scheduler's default.
   prediction_type: '',
   rescale_zero_terminal_snr: False,
   timestep_spacing: ''
@@ -89,12 +89,12 @@ base_output_directory: ""
 mesh_axes: ['data', 'fsdp', 'tensor']
 
 # batch : batch dimension of data and activations
-# hidden : 
+# hidden :
 # embed : attention qkv dense layer hidden dim named as embed
 # heads : attention head dim = num_heads * head_dim
 # length : attention sequence length
-# temb_in : dense.shape[0] of resnet dense before conv 
-# out_c : dense.shape[1] of resnet dense before conv 
+# temb_in : dense.shape[0] of resnet dense before conv
+# out_c : dense.shape[1] of resnet dense before conv
 # out_channels : conv.shape[-1] activation
 # keep_1 : conv.shape[0] weight
 # keep_2 : conv.shape[1] weight
@@ -120,7 +120,7 @@ data_sharding: [['data', 'fsdp', 'tensor']]
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
 dcn_tensor_parallelism: 1
-ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e 
+ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
 
@@ -146,6 +146,8 @@ enable_data_shuffling: True
 
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False
 
 # Prepare image latents and text encoder outputs
 # during dataset creation to reduce memory consumption.
@@ -165,7 +167,7 @@ per_device_batch_size: 1
 
 warmup_steps_fraction: 0.0
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
-# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before 
+# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
 
 # AdamW optimizer parameters
@@ -201,4 +203,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
\ No newline at end of file
+cache_dreambooth_dataset: False
diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml
index 2c860109..6f485b19 100644
--- a/src/maxdiffusion/configs/base_2_base.yml
+++ b/src/maxdiffusion/configs/base_2_base.yml
@@ -78,7 +78,7 @@ timestep_bias: {
   begin: 0,
   # when using strategy=range, the final step (inclusive) to bias.
   end: 1000,
-  # portion of timesteps to bias. 
+  # portion of timesteps to bias.
   # 0.5 will bias one half of the timesteps. Value of strategy determines
   # whether the biased portions are in the earlier or later timesteps.
   portion: 0.25
@@ -88,7 +88,7 @@ timestep_bias: {
 # Override parameters from checkpoints's scheduler.
 diffusion_scheduler_config: {
   _class_name: '',
-  # values are v_prediction or leave empty to use scheduler's default. 
+  # values are v_prediction or leave empty to use scheduler's default.
   prediction_type: '',
   rescale_zero_terminal_snr: False,
   timestep_spacing: ''
@@ -102,12 +102,12 @@ base_output_directory: ""
 mesh_axes: ['data', 'fsdp', 'tensor']
 
 # batch : batch dimension of data and activations
-# hidden : 
+# hidden :
 # embed : attention qkv dense layer hidden dim named as embed
 # heads : attention head dim = num_heads * head_dim
 # length : attention sequence length
-# temb_in : dense.shape[0] of resnet dense before conv 
-# out_c : dense.shape[1] of resnet dense before conv 
+# temb_in : dense.shape[0] of resnet dense before conv
+# out_c : dense.shape[1] of resnet dense before conv
 # out_channels : conv.shape[-1] activation
 # keep_1 : conv.shape[0] weight
 # keep_2 : conv.shape[1] weight
@@ -133,7 +133,7 @@ data_sharding: [['data', 'fsdp', 'tensor']]
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
 dcn_tensor_parallelism: 1
-ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e 
+ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1
 
@@ -159,6 +159,8 @@ enable_data_shuffling: True
 
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False
 
 # Prepare image latents and text encoder outputs
 # during dataset creation to reduce memory consumption.
@@ -178,7 +180,7 @@ per_device_batch_size: 1
 
 warmup_steps_fraction: 0.0
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
-# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before 
+# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
 
 # AdamW optimizer parameters
@@ -218,4 +220,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
\ No newline at end of file
+cache_dreambooth_dataset: False
diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml
index 18c4e94b..abce1aee 100644
--- a/src/maxdiffusion/configs/base_xl.yml
+++ b/src/maxdiffusion/configs/base_xl.yml
@@ -67,7 +67,7 @@ timestep_bias: {
   begin: 0,
   # when using strategy=range, the final step (inclusive) to bias.
   end: 1000,
-  # portion of timesteps to bias. 
+  # portion of timesteps to bias.
   # 0.5 will bias one half of the timesteps. Value of strategy determines
   # whether the biased portions are in the earlier or later timesteps.
   portion: 0.25
@@ -76,7 +76,7 @@ timestep_bias: {
 # Override parameters from checkpoints's scheduler.
 diffusion_scheduler_config: {
   _class_name: '',
-  # values are v_prediction or leave empty to use scheduler's default. 
+  # values are v_prediction or leave empty to use scheduler's default.
   prediction_type: '',
   rescale_zero_terminal_snr: False,
   timestep_spacing: ''
@@ -90,12 +90,12 @@ base_output_directory: ""
 mesh_axes: ['data', 'fsdp', 'tensor']
 
 # batch : batch dimension of data and activations
-# hidden : 
+# hidden :
 # embed : attention qkv dense layer hidden dim named as embed
 # heads : attention head dim = num_heads * head_dim
 # length : attention sequence length
-# temb_in : dense.shape[0] of resnet dense before conv 
-# out_c : dense.shape[1] of resnet dense before conv 
+# temb_in : dense.shape[0] of resnet dense before conv
+# out_c : dense.shape[1] of resnet dense before conv
 # out_channels : conv.shape[-1] activation
 # keep_1 : conv.shape[0] weight
 # keep_2 : conv.shape[1] weight
@@ -147,6 +147,8 @@ enable_data_shuffling: True
 
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False
 
 # Prepare image latents and text encoder outputs
 # during dataset creation to reduce memory consumption.
@@ -166,7 +168,7 @@ per_device_batch_size: 2
 
 warmup_steps_fraction: 0.0
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
-# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before 
+# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
 
 # AdamW optimizer parameters
@@ -204,4 +206,4 @@ enable_mllog: False
 controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
 controlnet_from_pt: True
 controlnet_conditioning_scale: 0.5
-controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
\ No newline at end of file
+controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml
index 8973abf9..f8db2f6e 100644
--- a/src/maxdiffusion/configs/base_xl_lightning.yml
+++ b/src/maxdiffusion/configs/base_xl_lightning.yml
@@ -55,7 +55,7 @@ text_encoder_learning_rate: 4.25e-6
 # Override parameters from checkpoints's scheduler.
 diffusion_scheduler_config: {
   _class_name: 'DDIMScheduler',
-  # values are v_prediction or leave empty to use scheduler's default. 
+  # values are v_prediction or leave empty to use scheduler's default.
   prediction_type: '',
   rescale_zero_terminal_snr: False,
   timestep_spacing: 'trailing'
@@ -69,12 +69,12 @@ base_output_directory: ""
 mesh_axes: ['data', 'fsdp', 'tensor']
 
 # batch : batch dimension of data and activations
-# hidden : 
+# hidden :
 # embed : attention qkv dense layer hidden dim named as embed
 # heads : attention head dim = num_heads * head_dim
 # length : attention sequence length
-# temb_in : dense.shape[0] of resnet dense before conv 
-# out_c : dense.shape[1] of resnet dense before conv 
+# temb_in : dense.shape[0] of resnet dense before conv
+# out_c : dense.shape[1] of resnet dense before conv
 # out_channels : conv.shape[-1] activation
 # keep_1 : conv.shape[0] weight
 # keep_2 : conv.shape[1] weight
@@ -116,6 +116,9 @@ resolution: 1024
 center_crop: False
 random_flip: False
+# enables one replica to read the ckpt then broadcast to the rest
+enable_single_replica_ckpt_restoring: False
+
 # Training loop
 learning_rate: 4.e-7
 scale_lr: False
 
@@ -130,7 +133,7 @@ per_device_batch_size: 2
 
 warmup_steps_fraction: 0.0
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
-# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before 
+# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
 
 # AdamW optimizer parameters
@@ -159,4 +162,4 @@ lightning_from_pt: True
 lightning_repo: "ByteDance/SDXL-Lightning"
 lightning_ckpt: "sdxl_lightning_4step_unet.safetensors"
 
-enable_mllog: False
\ No newline at end of file
+enable_mllog: False
diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 66465e8a..62d8a4ee 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -386,7 +386,12 @@ def setup_initial_state(
   with nn_partitioning.axis_rules(config.logical_axis_rules):
     if checkpoint_manager and checkpoint_item:
       max_logging.log(f"setup_initial_state for {checkpoint_item}")
-      state = checkpointing_utils.load_state_if_possible(checkpoint_manager, unboxed_abstract_state, checkpoint_item)
+      state = checkpointing_utils.load_state_if_possible(
+          checkpoint_manager,
+          unboxed_abstract_state,
+          checkpoint_item,
+          config.enable_single_replica_ckpt_restoring,
+      )
       if state:
         state = state[checkpoint_item]
       if not state:
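
The replica-selection helpers added in `checkpointing_utils.py` are the subtle part of this patch, so here is a minimal, self-contained sketch of what they compute. The `FakeDevice` class, the 4x2 device grid, and the hard-coded process id are made-up stand-ins; in the patch the array comes from `mesh.devices` and the process id from `jax.process_index()`.

```
# Illustrative sketch only -- mirrors _find_idx/_replica_devices from the patch
# using fake device objects instead of real jax.Device instances.
import numpy as np


class FakeDevice:
  """Stand-in for a jax.Device; only the process_index attribute is used here."""

  def __init__(self, process_index):
    self.process_index = process_index


def find_idx(array, replica_axis_idx, my_process):
  # Same walk as _find_idx: position along the replica axis of the first
  # device owned by the current host.
  idx = None
  for idx, val in np.ndenumerate(array):
    if val.process_index == my_process:
      break
  return idx[replica_axis_idx]


def replica_devices(device_array, replica_axis_idx, my_process):
  # Same idea as _replica_devices: keep only this host's replica slice,
  # preserving the replica axis with size 1 so mesh axis names still line up.
  idx = find_idx(device_array, replica_axis_idx, my_process)
  replica = np.take(device_array, idx, axis=replica_axis_idx)
  return np.expand_dims(replica, axis=replica_axis_idx)


# 4 replicas along axis 0, 2 devices per replica; process p owns replica p.
devices = np.array([[FakeDevice(p), FakeDevice(p)] for p in range(4)])
print(replica_devices(devices, replica_axis_idx=0, my_process=2).shape)  # (1, 2)
```

A `jax.sharding.Mesh` built from the returned sub-array keeps the global mesh's axis names but spans only the local replica's devices, which is what `SingleReplicaArrayRestoreArgs.single_replica_sharding` expects.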
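
For completeness, a rough caller-side sketch of how the new keyword is threaded through, mirroring the `max_utils.py` hunk above. The `config`, `checkpoint_manager`, and `unboxed_abstract_state` objects are placeholders for whatever the caller already holds, `"unet_state"` is just an example item name, and the import path is assumed from the file layout rather than verified.

```
# Hypothetical call site; argument order follows the new load_state_if_possible signature.
from maxdiffusion.checkpointing import checkpointing_utils


def restore_unet_state(config, checkpoint_manager, unboxed_abstract_state):
  # When enable_single_replica_ckpt_restoring is True, one replica reads the
  # checkpoint from storage and the arrays are broadcast to the other replicas.
  restored = checkpointing_utils.load_state_if_possible(
      checkpoint_manager,
      unboxed_abstract_state,
      "unet_state",
      config.enable_single_replica_ckpt_restoring,
  )
  return restored["unet_state"] if restored else None
```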