
override Flux transformer #639

Open

dxqbYD wants to merge 1 commit into master

Conversation

dxqbYD (Collaborator) commented Jan 12, 2025

People have started to use and share Flux .safetensors files that contain only the Flux transformer, not the entire model pipeline. The full pipeline is almost twice the size of the transformer alone.

Examples:

  • Finetunes on CivitAI, offered at a smaller download size of around 11 GB instead of 20 GB.
  • I recently tried to train on Flux merged with a LoRA, but this Comfy workflow for merging a LoRA only outputs the transformer: https://civitai.com/models/982277/flux-lora-merge

Using this PR, you can load the model pipeline from black-forest-labs/FLUX.1-dev, but override the transformer from a safetensors file.
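Conceptually, the override amounts to the following diffusers sketch (a minimal illustration, not this PR's exact code; the file path is a placeholder):

    import torch
    from diffusers import FluxPipeline, FluxTransformer2DModel

    # Load only the transformer from a transformer-only .safetensors file.
    transformer = FluxTransformer2DModel.from_single_file(
        "path/to/flux-transformer-only.safetensors",  # placeholder path
        torch_dtype=torch.bfloat16,
    )

    # Load the remaining pipeline components from the base repository and
    # swap in the overridden transformer.
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )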

    @@ -252,7 +253,7 @@ def __create_base_components(
         if allow_override_prior:
             # prior model
             components.label(self.scroll_frame, row, 0, "Prior Model",
    -                         tooltip="Filename, directory or Hugging Face repository of the prior model")
    +                         tooltip="Filename, directory or Hugging Face repository of the prior model. For Flux, it must be a safetensors file that only contains the transformer.")
dxqbYD (Collaborator, Author) commented:

I'm not sure why "prior" is used for the Flux transformer. What is "prior" about the transformer? It seems to come from Stable Cascade, where the prior referred to a different model, not the one being trained.

But "prior" is already used in the UI for the transformer (see "Override Prior Data Type"), so this PR doesn't change that.

Owner commented:

The "prior" name originally came from the Würstchen-v2 model. before that, every model was using a unet architecture, and I didn't want to reuse that name. The train tab already uses the "transformer" name in most places. I guess it makes sense to do the same here. But it still needs to write into the prior config variable to be compatible with the rest of the code.

dxqbYD (Collaborator, Author) commented Jan 12, 2025

Note: this PR only fully works once diffusers has fixed issue huggingface/diffusers#10540 by merging PR huggingface/diffusers#10541.

    if transformer_model_name:
        transformer = FluxTransformer2DModel.from_single_file(
            transformer_model_name,
            torch_dtype=weight_dtypes.prior.torch_dtype(),
dxqbYD (Collaborator, Author) commented Jan 12, 2025:

This line is an issue RAM-wise if the prior dtype is NF4 or FLOAT8. In those cases, torch_dtype() returns None, which causes from_single_file() to load the transformer in float32 before it is converted back to low precision.

That is too much RAM for many systems. I'm not sure how to solve this yet, other than asking the user for a separate "load dtype". As it stands, loading the transformer with prior dtype bfloat16 requires less RAM than loading it with dtype float8.
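One possible mitigation (a sketch only, not part of this PR; it assumes torch_dtype() returns None exactly for the quantized dtypes) is to fall back to an explicit low-precision load dtype instead of letting from_single_file() default to float32:

    import torch

    # Sketch: when the prior dtype has no direct torch equivalent (NF4,
    # FLOAT8), torch_dtype() returns None; fall back to bfloat16 so the
    # checkpoint is never materialized in float32.
    load_dtype = weight_dtypes.prior.torch_dtype() or torch.bfloat16
    transformer = FluxTransformer2DModel.from_single_file(
        transformer_model_name,
        torch_dtype=load_dtype,
    )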

Owner commented:

less? hm

    pipeline = FluxPipeline.from_single_file(
        pretrained_model_link_or_path=base_model_name,
        safety_checker=None,
        transformer=transformer,
Owner commented:

I don't think this will work if transformer is None. I just tested it, and the pipeline it returns doesn't have a transformer:

    pipeline = FluxPipeline.from_single_file(
        pretrained_model_link_or_path="path/to/flux1-fill-dev.safetensors",
        safety_checker=None,
        transformer=None,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
    )
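A possible fix (a sketch under the assumption above, not this PR's code) is to omit the transformer keyword entirely when there is nothing to override, so from_single_file() loads the transformer from the checkpoint as usual:

    # Sketch: only pass the transformer override when one was actually
    # loaded; otherwise let from_single_file() pull it from the checkpoint.
    extra = {"transformer": transformer} if transformer is not None else {}
    pipeline = FluxPipeline.from_single_file(
        pretrained_model_link_or_path=base_model_name,
        safety_checker=None,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        **extra,
    )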
