add support for flux vae. ~ wip
jfacevedo-google committed Jan 14, 2025
1 parent 8a4ea09 commit d5ac715
Showing 13 changed files with 366 additions and 24 deletions.
2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
@@ -372,6 +372,7 @@
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
@@ -451,6 +452,7 @@
from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
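
With both the lazy _import_structure entry and the eager import registered, the class should resolve from the package root as well as from its module path; a quick sanity sketch, assuming the package imports cleanly:

import maxdiffusion  # not part of the commit; illustration only
from maxdiffusion import FluxTransformer2DModel
from maxdiffusion.models.flux.transformers.transformer_flux_flax import (
    FluxTransformer2DModel as DirectImport,
)

# The lazy root-level entry and the direct module path name the same class.
assert FluxTransformer2DModel is DirectImport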
1 change: 1 addition & 0 deletions src/maxdiffusion/models/__init__.py
@@ -31,6 +31,7 @@
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel

else:
import sys
1 change: 0 additions & 1 deletion src/maxdiffusion/models/embeddings_flax.py
@@ -54,7 +54,6 @@ def get_sinusoidal_embeddings(
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
return signal


class FlaxTimestepEmbedding(nn.Module):
r"""
Time step Embedding Module. Learns embeddings for input time steps.
17 changes: 17 additions & 0 deletions src/maxdiffusion/models/flux/__init__.py
@@ -0,0 +1,17 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from .transformers.transformer_flux_flax import FluxTransformer2DModel
15 changes: 15 additions & 0 deletions src/maxdiffusion/models/flux/modules/__init__.py
@@ -0,0 +1,15 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
95 changes: 95 additions & 0 deletions src/maxdiffusion/models/flux/modules/layers.py
@@ -0,0 +1,95 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from chex import Array
from jax.typing import DTypeLike
import flax.linen as nn

def timestep_embedding(
t: Array, dim: int, max_period=10000, time_factor: float = 1000.0
) -> Array:
"""
Generate timestep embeddings.
Args:
t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
dim: the dimension of the output.
max_period: controls the minimum frequency of the embeddings.
time_factor: scale applied to t before computing the embedding.
Returns:
timestep embeddings.
"""
t = time_factor * t
half = dim // 2

freqs = jnp.exp(
-math.log(max_period) * jnp.arange(start=0, stop=half, dtype=jnp.float32) / half
).astype(dtype=t.dtype)

args = t[:, None].astype(jnp.float32) * freqs[None]
embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)

if dim % 2:
embedding = jnp.concatenate(
[embedding, jnp.zeros_like(embedding[:, :1])], axis=-1
)

if jnp.issubdtype(t.dtype, jnp.floating):
embedding = embedding.astype(t.dtype)

return embedding


class MLPEmbedder(nn.Module):
hidden_dim: int
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
precision: jax.lax.Precision = None

@nn.compact
def __call__(self, x: Array) -> Array:

x = nn.Dense(
self.hidden_dim,
use_bias=True,
dtype=self.dtype,
param_dtype=self.weights_dtype,
precision=self.precision,
kernel_init=nn.with_logical_partitioning(
nn.initializers.lecun_normal(),
("embed", "heads")
)
)(x)
x = nn.silu(x)
x = nn.Dense(
self.hidden_dim,
use_bias=True,
dtype=self.dtype,
param_dtype=self.weights_dtype,
precision=self.precision,
kernel_init=nn.with_logical_partitioning(
nn.initializers.lecun_normal(),
("heads", "embed")
)
)(x)

return x
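
For orientation, a minimal sketch of how these two pieces compose; the batch size and the 3072 hidden width are illustrative assumptions, not part of the commit:

import jax
import jax.numpy as jnp

from maxdiffusion.models.flux.modules.layers import MLPEmbedder, timestep_embedding

# Four (possibly fractional) timesteps -> (4, 256) sinusoidal features.
t = jnp.array([0.0, 250.5, 500.0, 999.0])
emb = timestep_embedding(t, dim=256)

# MLPEmbedder projects the features to the model's hidden width.
embedder = MLPEmbedder(hidden_dim=3072)
params = embedder.init(jax.random.PRNGKey(0), emb)
vec = embedder.apply(params, emb)  # shape (4, 3072)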
15 changes: 15 additions & 0 deletions src/maxdiffusion/models/flux/transformers/__init__.py
@@ -0,0 +1,15 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
134 changes: 134 additions & 0 deletions src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py
@@ -0,0 +1,134 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import flax.linen as nn
from chex import Array

from ..modules.layers import timestep_embedding, MLPEmbedder
from ...modeling_flax_utils import FlaxModelMixin
from ....configuration_utils import ConfigMixin, flax_register_to_config
from ....common_types import BlockSizes

class Identity(nn.Module):
def __call__(self, x: Array) -> Array:
return x

class FluxTransformer2DModel(nn.Module, FlaxModelMixin, ConfigMixin):
r"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
implemented for all models (such as downloading or saving).
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
general usage and behavior.
"""
patch_size: int = 1
in_channels: int = 64
num_layers: int = 19
num_single_layers: int = 38
attention_head_dim: int = 128
num_attention_heads: int = 24
joint_attention_dim: int = 4096
pooled_projection_dim: int = 768
guidance_embeds: bool = False
axes_dims_rope: Tuple[int, ...] = (16, 56, 56)
flash_min_seq_length: int = 4096
flash_block_sizes: BlockSizes = None
mesh: jax.sharding.Mesh = None
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
precision: jax.lax.Precision = None

def setup(self):
self.out_channels = self.in_channels
self.inner_dim = self.num_attention_heads * self.attention_head_dim

self.img_in = nn.Dense(
self.inner_dim,
dtype=self.dtype,
param_dtype=self.weights_dtype,
precision=self.precision,
kernel_init=nn.with_logical_partitioning(
nn.initializers.lecun_normal(),
("embed", "heads")
)
)

self.time_in = MLPEmbedder(
hidden_dim=self.inner_dim,
dtype=self.dtype,
weights_dtype=self.weights_dtype,
precision=self.precision
)

self.vector_in = MLPEmbedder(
hidden_dim=self.inner_dim,
dtype=self.dtype,
weights_dtype=self.weights_dtype,
precision=self.precision
)

self.guidance_in = (
MLPEmbedder(
hidden_dim=self.inner_dim,
dtype=self.dtype,
weights_dtype=self.weights_dtype,
precision=self.precision
)
if self.guidance_embeds
else Identity()
)

self.txt_in = nn.Dense(
self.inner_dim,
dtype=self.dtype,
param_dtype=self.weights_dtype,
precision=self.precision
)

def __call__(
self,
img: Array,
img_ids: Array,
txt: Array,
txt_ids: Array,
timesteps: Array,
y: Array,
guidance: Array | None = None,
return_dict: bool = True,
train: bool = False):

img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))

if self.guidance_embeds:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distrilled model."
)

vec = vec + self.guidance_in(timestep_embedding(guidance, 256))

vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
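
As the guard above implies, guidance is only consumed when guidance_embeds=True; otherwise guidance_in collapses to Identity and the modulation vector is just the summed time and pooled-text embeddings. A construction-only sketch of the two variants (no forward call, since __call__ is still incomplete in this wip commit):

from maxdiffusion import FluxTransformer2DModel

# Base variant: guidance_in is Identity, so guidance may remain None.
base = FluxTransformer2DModel(guidance_embeds=False)

# Guidance-distilled variant: __call__ raises ValueError when guidance is None.
distilled = FluxTransformer2DModel(guidance_embeds=True)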
2 changes: 1 addition & 1 deletion src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -243,7 +243,7 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alpha

def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
# Step 1: Convert pytorch tensor to numpy
-pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+pt_state_dict = {k: v.float().numpy() for k, v in pt_state_dict.items()}

# Step 2: Since the model is stateless, run eval_shape to get the pytree structure
random_flax_params = flax_model.init_weights(PRNGKey(init_key), eval_only=True)
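
The added .float() matters for checkpoints stored in half or bfloat16 precision: NumPy has no bfloat16 dtype, so torch.Tensor.numpy() raises for such tensors, and upcasting to float32 first makes the conversion safe. A minimal reproduction, assuming a recent PyTorch:

import torch

t = torch.ones(2, 2, dtype=torch.bfloat16)
# t.numpy() would raise a TypeError: NumPy does not support bfloat16.
arr = t.float().numpy()
print(arr.dtype)  # float32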
3 changes: 2 additions & 1 deletion src/maxdiffusion/models/modeling_flax_utils.py
@@ -298,6 +298,7 @@ def from_pretrained(
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
use_safetensors = kwargs.pop("use_safetensors", None)

user_agent = {
"maxdiffusion": __version__,
@@ -356,7 +357,7 @@
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
-filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
+filename=FLAX_WEIGHTS_NAME if not from_pt else SAFETENSORS_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
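
Downstream, callers loading PyTorch-format weights can now prefer the safetensors file over the pickle checkpoint. A sketch, assuming maxdiffusion keeps diffusers' Flax convention of returning (model, params); the repo id and subfolder are illustrative:

from maxdiffusion import FluxTransformer2DModel

# With from_pt=True and use_safetensors=True, the hub download targets
# SAFETENSORS_WEIGHTS_NAME instead of the pickled WEIGHTS_NAME.
model, params = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    from_pt=True,
    use_safetensors=True,
)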
3 changes: 2 additions & 1 deletion src/maxdiffusion/models/modeling_utils.py
@@ -108,7 +108,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
return torch.load(checkpoint_file, map_location="cpu")
else:
-return safetensors.torch.load_file(checkpoint_file, device="cpu")
+from safetensors import torch as safetensors_torch
+return safetensors_torch.load_file(checkpoint_file, device="cpu")
except Exception as e:
try:
with open(checkpoint_file) as f: