Commit

code style formatting
entrpn committed Dec 13, 2024
1 parent 0bfe06a · commit 7d8d883
Showing 4 changed files with 11 additions and 5 deletions.
src/maxdiffusion/generate_sdxl.py (2 changes: 0 additions & 2 deletions)

@@ -236,13 +236,11 @@ def run(config):
   # maybe load lora and create interceptor
   params, lora_interceptors = maybe_load_lora(config, pipeline, params)
 
-
   if config.lightning_repo:
     pipeline, params = load_sdxllightning_unet(config, pipeline, params)
 
   # Don't restore the full train state, instead, just restore params
   # and create an inference state.
-  #with nn.intercept_methods(lora_interceptor):
   with ExitStack() as stack:
     _ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
     unet_state, unet_state_shardings = max_utils.setup_initial_state(
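
For readers skimming the hunk: the ExitStack pattern above enters a variable number of flax.linen.intercept_methods contexts, one per loaded LoRA, so every interceptor is active for the module calls inside the with block. A minimal runnable sketch of that pattern (toy Dense model and no-op interceptors, not the repo's code):

from contextlib import ExitStack

import flax.linen as nn
import jax
import jax.numpy as jnp

def noop_interceptor(next_fn, args, kwargs, context):
  # Pass-through; a real LoRA interceptor would add a low-rank update
  # to the intercepted module's output here.
  return next_fn(*args, **kwargs)

lora_interceptors = [noop_interceptor, noop_interceptor]  # stand-ins

model = nn.Dense(features=4)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

with ExitStack() as stack:
  _ = [stack.enter_context(nn.intercept_methods(i)) for i in lora_interceptors]
  y = model.apply(params, jnp.ones((1, 8)))  # all interceptors active here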
src/maxdiffusion/loaders/lora_pipeline.py (4 changes: 3 additions & 1 deletion)

@@ -147,7 +147,9 @@ def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name):
       network_alphas_for_interceptor.update(text_encoder_alphas)
     if "text_encoder_2" in params.keys():
       text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
-      text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas, adapter_name)
+      text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(
+          text_encoder_2_keys, network_alphas, adapter_name
+      )
       lora_keys.extend(text_encoder_2_keys)
       network_alphas_for_interceptor.update(text_encoder_2_alphas)
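
Context for the hunk above: flax.traverse_util.flatten_dict converts a nested params tree into a flat dict keyed by path tuples, and those tuple keys are what rename_for_interceptor consumes. A small illustration with toy params (not real text-encoder weights):

import flax

params = {"text_encoder_2": {"dense": {"kernel": [1.0], "bias": [0.0]}}}
keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
print(list(keys))  # [('dense', 'kernel'), ('dense', 'bias')]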
src/maxdiffusion/maxdiffusion_utils.py (2 changes: 1 addition & 1 deletion)

@@ -51,7 +51,7 @@ def _noop_interceptor(next_fn, args, kwargs, context):
   # before being loaded.
   # TODO - merge LoRAs here.
   interceptors = []
-  for i in range (len(lora_config["lora_model_name_or_path"])):
+  for i in range(len(lora_config["lora_model_name_or_path"])):
     params, rank, network_alphas = pipeline.load_lora_weights(
         lora_config["lora_model_name_or_path"][i],
         weight_name=lora_config["weight_name"][i],
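
Beyond the range( spacing fix, the context lines show that lora_config holds parallel lists: entry i of lora_model_name_or_path pairs with entry i of weight_name. A sketch of that assumed config shape, with hypothetical values:

# Hypothetical values; the field names come from the hunk above.
lora_config = {
  "lora_model_name_or_path": ["org/lora-a", "org/lora-b"],
  "weight_name": ["lora_a.safetensors", "lora_b.safetensors"],
}
for i in range(len(lora_config["lora_model_name_or_path"])):
  # The real loop calls pipeline.load_lora_weights(path, weight_name=...).
  print(lora_config["lora_model_name_or_path"][i], lora_config["weight_name"][i])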
src/maxdiffusion/models/modeling_flax_pytorch_utils.py (8 changes: 7 additions & 1 deletion)

@@ -130,7 +130,13 @@ def get_network_alpha_value(pt_key, network_alphas):
 
 
 def create_flax_params_from_pytorch_state(
-    pt_state_dict, unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, network_alphas, adapter_name, is_lora=False
+    pt_state_dict,
+    unet_state_dict,
+    text_encoder_state_dict,
+    text_encoder_2_state_dict,
+    network_alphas,
+    adapter_name,
+    is_lora=False,
 ):
   rank = None
   renamed_network_alphas = {}
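
The reflowed signature uses the one-parameter-per-line, trailing-comma shape that black-style formatters emit once a signature exceeds the line limit; a later parameter addition then touches exactly one diff line. A hypothetical function (not from the repo) showing the same layout:

def convert_weights(
    source_state,
    target_state,
    adapter_name,
    is_lora=False,  # trailing comma keeps a future added parameter a one-line diff
):
  # Placeholder body; a real converter would remap keys here.
  return (source_state, target_state, adapter_name, is_lora)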
