Skip to content

Commit

Permalink
add text encoder support.
Browse files Browse the repository at this point in the history
  • Loading branch information
entrpn committed Dec 11, 2024
1 parent 718141f commit 115467f
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 18 deletions.
15 changes: 8 additions & 7 deletions src/maxdiffusion/generate_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def run(config):
pipeline.unet, None, config, checkpoint_loader.mesh, weights_init_fn, False
)

# load unet params from orbax checkpoint
unet_params = load_params_from_path(
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "unet_state"
)
Expand Down Expand Up @@ -253,14 +254,14 @@ def run(config):
vae_state, vae_state_shardings = checkpoint_loader.create_vae_state(
pipeline, params, checkpoint_item_name="vae_state", is_training=False
)
text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
)

text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
)
with nn.intercept_methods(lora_interceptor):
text_encoder_state, text_encoder_state_shardings = checkpoint_loader.create_text_encoder_state(
pipeline, params, checkpoint_item_name="text_encoder_state", is_training=False
)

text_encoder_2_state, text_encoder_2_state_shardings = checkpoint_loader.create_text_encoder_2_state(
pipeline, params, checkpoint_item_name="text_encoder_2_state", is_training=False
)
states = {}
state_shardings = {}

Expand Down
28 changes: 21 additions & 7 deletions src/maxdiffusion/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,22 +121,36 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas):

def rename_for_interceptor(params_keys, network_alphas):
new_params_keys = []
new_network_alphas = {}
for layer_lora in params_keys:
if "lora" in layer_lora:
new_layer_lora = layer_lora[: layer_lora.index("lora")]
if new_layer_lora not in new_params_keys:
new_params_keys.append(new_layer_lora)
network_alpha = network_alphas[layer_lora]
del network_alphas[layer_lora]
network_alphas[new_layer_lora] = network_alpha
return new_params_keys, network_alphas
new_network_alphas[new_layer_lora] = network_alpha
return new_params_keys, new_network_alphas

@classmethod
def make_lora_interceptor(cls, params, rank, network_alphas):
# Only unet interceptor supported for now.
network_alphas_for_interceptor = {}

unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys()
unet_lora_keys, network_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)

lora_keys, unet_alphas = cls.rename_for_interceptor(unet_lora_keys, network_alphas)
network_alphas_for_interceptor.update(unet_alphas)

text_encoder_keys = flax.traverse_util.flatten_dict(params["text_encoder"]).keys()
text_encoder_keys, text_encoder_alphas = cls.rename_for_interceptor(text_encoder_keys, network_alphas)
lora_keys.extend(text_encoder_keys)
network_alphas_for_interceptor.update(text_encoder_alphas)

if "text_encoder_2" in params.keys():
text_encoder_2_keys = flax.traverse_util.flatten_dict(params["text_encoder_2"]).keys()
text_encoder_2_keys, text_encoder_2_alphas = cls.rename_for_interceptor(text_encoder_2_keys, network_alphas)
lora_keys.extend(text_encoder_2_keys)
network_alphas_for_interceptor.update(text_encoder_2_alphas)

def _intercept(next_fn, args, kwargs, context):
mod = context.module
while mod is not None:
Expand All @@ -146,8 +160,8 @@ def _intercept(next_fn, args, kwargs, context):
h = next_fn(*args, **kwargs)
if context.method_name == "__call__":
module_path = context.module.path
if module_path in unet_lora_keys:
lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas)
if module_path in lora_keys:
lora_layer = cls._get_lora_layer(module_path, context.module, rank, network_alphas_for_interceptor)
return lora_layer(h, *args, **kwargs)
return h

Expand Down
23 changes: 19 additions & 4 deletions src/maxdiffusion/models/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,15 @@ def create_flax_params_from_pytorch_state(
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
network_alpha_value = get_network_alpha_value(pt_key, network_alphas)


# rename text encoders fc1 lora layers.
pt_key = pt_key.replace("lora_linear_layer","lora")

# only rename the unet keys, text encoders are already correct.
if "unet" in pt_key:
renamed_pt_key = rename_key(pt_key)

else:
renamed_pt_key = pt_key
pt_tuple_key = tuple(renamed_pt_key.split("."))
# conv
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
Expand All @@ -151,13 +155,24 @@ def create_flax_params_from_pytorch_state(
flax_tensor = pt_tensor
else:
flax_key_list = [*pt_tuple_key]
for rename_from, rename_to in (
if "text_encoder" in pt_tuple_key or "text_encoder_2" in pt_tuple_key:
rename_from_to = (
("to_k_lora", ("k_proj", "lora")),
("to_q_lora", ("q_proj", "lora")),
("to_v_lora", ("v_proj", "lora")),
("to_out_lora", ("out_proj", "lora")),
("weight", "kernel"),
)
# the unet
else:
rename_from_to = (
("to_k_lora", ("to_k", "lora")),
("to_q_lora", ("to_q", "lora")),
("to_v_lora", ("to_v", "lora")),
("to_out_lora", ("to_out_0", "lora")),
("weight", "kernel"),
):
)
for rename_from, rename_to in rename_from_to:
tmp = []
for s in flax_key_list:
if s == rename_from:
Expand Down

0 comments on commit 115467f

Please sign in to comment.