From 718141fb5c2f4afb2f1dbe7a93306c7a82bb92d7 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 10 Dec 2024 16:50:02 +0000 Subject: [PATCH] don't rename key for text encoders so that pytree matches original. --- src/maxdiffusion/models/modeling_flax_pytorch_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 5f02ec8..9da8646 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -137,7 +137,11 @@ def create_flax_params_from_pytorch_state( # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): network_alpha_value = get_network_alpha_value(pt_key, network_alphas) - renamed_pt_key = rename_key(pt_key) + + # only rename the unet keys, text encoders are already correct. + if "unet" in pt_key: + renamed_pt_key = rename_key(pt_key) + pt_tuple_key = tuple(renamed_pt_key.split(".")) # conv if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: