override Flux transformer #639
base: master
Conversation
@@ -252,7 +253,7 @@ def __create_base_components(
         if allow_override_prior:
             # prior model
             components.label(self.scroll_frame, row, 0, "Prior Model",
-                             tooltip="Filename, directory or Hugging Face repository of the prior model")
+                             tooltip="Filename, directory or Hugging Face repository of the prior model. For Flux, it must be a safetensors file that only contains the transformer.")
I'm not sure why `prior` is used for the Flux transformer. What is "prior" about the transformer? It seems to be something from Stable Cascade, where `prior` referred to a different model, not the one that is trained.
But `prior` is already used in the UI for the transformer, see "Override Prior Data Type", so this PR doesn't change that.
The "prior" name originally came from the Würstchen-v2 model. before that, every model was using a unet architecture, and I didn't want to reuse that name. The train tab already uses the "transformer" name in most places. I guess it makes sense to do the same here. But it still needs to write into the prior config variable to be compatible with the rest of the code.
Note: This PR only fully works once diffusers have fixed this issue huggingface/diffusers#10540 by merging this PR huggingface/diffusers#10541.
if transformer_model_name:
    transformer = FluxTransformer2DModel.from_single_file(
        transformer_model_name,
        torch_dtype=weight_dtypes.prior.torch_dtype(),
This line is an issue RAM-wise if the prior dtype is NF4 or FLOAT8. In that case, torch_dtype() returns None, which causes from_single_file() to load the transformer in float32 before it is converted back to low precision.
That is too much RAM for many systems. I'm not sure how to solve this yet, except to ask the user for a "load dtype" (a possible fallback is sketched after this thread).
As it is, loading a transformer with prior dtype bfloat16 requires less RAM than loading it with dtype float8.
less? hm
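A minimal sketch of such a "load dtype" fallback, building on the snippet above (which defines transformer_model_name and weight_dtypes); the bfloat16 default is an assumption, not something this PR implements:

import torch
from diffusers import FluxTransformer2DModel

# Assumption: fall back to bfloat16 whenever the prior dtype has no torch
# equivalent (NF4/FLOAT8 make torch_dtype() return None), so that
# from_single_file() never materializes the transformer in float32.
load_dtype = weight_dtypes.prior.torch_dtype() or torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file(
    transformer_model_name,
    torch_dtype=load_dtype,
)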
pipeline = FluxPipeline.from_single_file(
    pretrained_model_link_or_path=base_model_name,
    safety_checker=None,
    transformer=transformer,
I don't think this will work if transformer is None. I just tested this, and the pipeline it returns doesn't have a transformer:
pipeline = FluxPipeline.from_single_file(
pretrained_model_link_or_path="path/to/flux1-fill-dev.safetensors",
safety_checker=None,
transformer=None,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
vae=vae,
)
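One possible way around that, sketched as an assumption rather than something this PR currently does: only pass the transformer keyword when an override was actually loaded, so from_single_file() still loads the transformer from the base checkpoint otherwise.

# Only forward `transformer` when an override exists; if it is None,
# let from_single_file() load the transformer from the base checkpoint.
extra_kwargs = {"transformer": transformer} if transformer is not None else {}

pipeline = FluxPipeline.from_single_file(
    pretrained_model_link_or_path=base_model_name,
    safety_checker=None,
    **extra_kwargs,
)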
People have started to use and share Flux .safetensors files that only contain the Flux transformer, not the entire model pipeline. The entire pipeline is almost twice the size of the transformer alone.
Examples:
Using this PR, you can load the model pipeline from black-forest-labs/FLUX.1-dev, but override the transformer from a safetensors file.
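A rough sketch of that workflow with plain diffusers calls (the file path is a placeholder, and this illustrates the idea rather than the PR's exact code):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Load only the transformer from a transformer-only safetensors file.
transformer = FluxTransformer2DModel.from_single_file(
    "path/to/flux-transformer-only.safetensors",  # placeholder path
    torch_dtype=torch.bfloat16,
)

# Load the rest of the pipeline from the Hugging Face repository and
# inject the overridden transformer.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)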