Stable Diffusion 3.x and Flux Optimization (#22986)
### Description

This change depends on the following PR:
- #23297

Optimize the ONNX pipeline for Stable Diffusion 3.x and Flux 1.0 models
(fp32 or fp16).
- [x] Update optimize_pipeline script
- [x] Update benchmark script
- [x] Update documentation for Stable Diffusion 3.x and Flux 1.0 models
- [x] Add graph optimizations for the MMDiT model
  - [x] FastGelu fusion
  - [x] RMSNorm fusion
  - [x] MultiHeadAttention fusion
- [x] Add graph optimizations for Flux transformer models
  - [x] MultiHeadAttention fusion
- [x] Update graph optimizations for T5
- [x] Add tests

Example: optimizing the Flux 1.0 Schnell ONNX pipeline to FP16:
```
python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp16 --float16

  Optimize flux1_schnell_onnx/fp32/transformer/model.onnx ...
  Fused LayerNormalization: 115
  Fused SimplifiedLayerNormalization: 152
  Fused FastGelu: 76
  Fused MultiHeadAttention: 57
```
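
A quick way to sanity-check the result (illustrative only, not part of this PR) is to count the fused contrib operators in the optimized graph with the `onnx` package; the counts should line up with the numbers printed by `optimize_pipeline.py`. The model path below assumes the output location from the command above.

```
from collections import Counter

import onnx

# Load the graph structure only; external weight files are not needed for counting nodes.
model = onnx.load("./flux1_schnell_onnx/fp16/transformer/model.onnx", load_external_data=False)

op_counts = Counter(node.op_type for node in model.graph.node)
for op in ("LayerNormalization", "SimplifiedLayerNormalization", "FastGelu", "MultiHeadAttention"):
    print(f"{op}: {op_counts.get(op, 0)}")
```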

### H100 Benchmark Results

* GPU: NVIDIA H100 80GB HBM3
* Image Size: 1024x1024
* Batch Size: 1

Model | Steps | Precision | Engine | Latency (Seconds) | GPU Memory (MB)
-- | -- | -- | -- | -- | --
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (compile) | 8.198 | 37,603
Flux 1.0 Dev | 50 | FP16+BF16 | Optimum (ORT) | 10.762 | 41,469
Flux 1.0 Dev | 50 | FP16+FP32 | Optimum (ORT) | 10.891 | 43,545
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (eager) | 12.339 | 36,651
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (compile) | 0.775 | 37,857
Flux 1.0 Schnell | 4 | FP16+BF16 | Optimum (ORT) | 0.931 | 41,433
Flux 1.0 Schnell | 4 | FP16+FP32 | Optimum (ORT) | 0.939 | 43,809
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (eager) | 1.120 | 36,629
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (compile) | 7.466 | 32,217
SD 3.5 Large | 50 | FP16+BF16 | Optimum (ORT) | 10.275 | 36,609
SD 3.5 Large | 50 | FP16+FP32 | Optimum (ORT) | 10.283 | 36,729
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (eager) | 11.615 | 31,517
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (compile) | 3.240 | 21,143
SD 3.5 Medium | 50 | FP16+BF16 | Optimum (ORT) | 4.799 | 25,097
SD 3.5 Medium | 50 | FP16+FP32 | Optimum (ORT) | 4.838 | 25,109
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (eager) | 5.582 | 20,489

### A100 Benchmark Results

* GPU: A100-SXM4-80GB
* Image Size: 1024x1024
* Batch Size: 1

Model | Steps | Precision | Engine | Latency (Seconds) | GPU Memory (MB)
-- | -- | -- | -- | -- | --
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (compile) | 17.593 | 37,723
Flux 1.0 Dev | 50 | FP16+BF16 | Optimum (ORT) | 21.918 | 41,348
Flux 1.0 Dev | 50 | FP16+FP32 | Optimum (ORT) | 22.060 | 44,860
Flux 1.0 Dev | 50 | BF16 | Torch 2.5.1 (eager) | 24.267 | 36,847
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (compile) | 1.627 | 37,881
Flux 1.0 Schnell | 4 | FP16+BF16 | Optimum (ORT) | 1.884 | 41,537
Flux 1.0 Schnell | 4 | FP16+FP32 | Optimum (ORT) | 1.902 | 44,858
Flux 1.0 Schnell | 4 | BF16 | Torch 2.5.1 (eager) | 2.162 | 36,831
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (compile) | 15.881 | 32,307
SD 3.5 Large | 50 | FP16+FP32 | Optimum (ORT) | 19.837 | 36,451
SD 3.5 Large | 50 | FP16+BF16 | Optimum (ORT) | 19.964 | 36,461
SD 3.5 Large | 50 | BF16 | Torch 2.5.1 (eager) | 22.477 | 31,513
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (compile) | 6.476 | 21,341
SD 3.5 Medium | 50 | FP16+FP32 | Optimum (ORT) | 8.775 | 25,183
SD 3.5 Medium | 50 | BF16 | Torch 2.5.1 (eager) | 10.057 | 20,433

### Future Work

* Triton kernels for matrix multiplication with auto-tuning.
* FP8/INT8 quantization.

### Motivation and Context

SD 3.5 Architecture:

https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/resolve/main/mmdit-x.png
tianleiwu authored Jan 14, 2025
1 parent 04030f6 commit 6550f4b
Showing 19 changed files with 2,089 additions and 525 deletions.
61 changes: 25 additions & 36 deletions onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -125,42 +125,31 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);

if (data.bias == nullptr) {
assert(nullptr == fused_runner);
// For quantized attention, bias has been added so only need transpose here.
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
assert(qk_head_size == v_head_size);
int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.gemm_buffer, qkv, 3));
data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
} else {
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
// For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
// For fused causal kernel, use format 1 since we need have K and V to update present state,
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
data.qkv_format = use_fused_kernel
? AttentionQkvFormat::QKV_BSN3H
: (use_flash_or_efficient_attention
? AttentionQkvFormat::Q_K_V_BSNH
: (use_fused_causal
? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
: AttentionQkvFormat::Q_K_V_BNSH));

// For fused causal, we will update gemm_buffer with bias directly.
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;

int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
// format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
// format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
3, parameters.do_rotary, parameters.rotary_embedding,
parameters.past_sequence_length);
}
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
// For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
// For fused causal kernel, use format 1 since we need have K and V to update present state,
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
data.qkv_format = use_fused_kernel
? AttentionQkvFormat::QKV_BSN3H
: (use_flash_or_efficient_attention
? AttentionQkvFormat::Q_K_V_BSNH
: (use_fused_causal
? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
: AttentionQkvFormat::Q_K_V_BNSH));

// For fused causal, we will update gemm_buffer with bias directly.
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;

int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
// format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
// format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
3, parameters.do_rotary, parameters.rotary_embedding,
parameters.past_sequence_length);
return Status::OK();
}

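For reference, the format selection in the refactored block above can be read as the following plain-Python summary (a restatement for readability, not code from the change):

```
def select_qkv_format(use_fused_kernel, use_flash_or_efficient_attention, use_fused_causal):
    """Mirror of the nested ternary above; returns the AttentionQkvFormat name."""
    if use_fused_kernel:
        return "QKV_BSN3H"             # packed BxSxNx3xH for the fused TRT kernel (format 2)
    if use_flash_or_efficient_attention:
        return "Q_K_V_BSNH"            # 3 x BxSxNxH for flash / memory-efficient attention (format 3)
    if use_fused_causal:
        return "Q_K_V_BNSH_QKV_BS3NH"  # format 1, with bias also added into the BxSx3xNxH gemm buffer
    return "Q_K_V_BNSH"                # format 1 for the unfused kernel
```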
12 changes: 10 additions & 2 deletions onnxruntime/python/tools/transformers/compare_bert_results.py
@@ -37,16 +37,23 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
# Validate the output of baseline and treatment, to make sure the results are similar.
diff_count = 0
max_abs_diff = 0
max_diff_percentage = 0
case_passed = True
for test_case_id, results in enumerate(baseline_results):
case_passed = True
for i in range(len(results)):
treatment_output = treatment_results[test_case_id][i]
abs_diff = np.amax(np.abs(treatment_output - results[i]))
abs_diff_tensor = np.abs(treatment_output - results[i])
abs_diff = np.amax(abs_diff_tensor)
if verbose and abs_diff > atol:
print("abs_diff", abs_diff)
print("treatment", treatment_output)
print("baseline", results[i])

count_exceeding = np.sum(abs_diff_tensor > atol)
total_elements = abs_diff_tensor.size
percentage_exceeding = (count_exceeding / total_elements) * 100
max_diff_percentage = max(max_diff_percentage, percentage_exceeding)

max_abs_diff = max(max_abs_diff, abs_diff)
if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol):
if case_passed:
@@ -66,6 +73,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
)

print(f"maximum absolute difference={max_abs_diff}")
print(f"maximum percentage of elements that exceeds atol={atol} is {max_diff_percentage:.3f}%")
return max_abs_diff, case_passed


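The new metric can be illustrated in isolation (hypothetical arrays, same arithmetic as the diff above): in addition to the maximum absolute difference, the script now reports the largest percentage of output elements whose difference exceeds `atol` across test cases.

```
import numpy as np

atol = 1e-3
baseline = np.array([0.100, 0.200, 0.300, 0.400])
treatment = np.array([0.100, 0.202, 0.300, 0.401])

abs_diff_tensor = np.abs(treatment - baseline)
max_abs_diff = np.amax(abs_diff_tensor)
percentage_exceeding = np.sum(abs_diff_tensor > atol) / abs_diff_tensor.size * 100

print(f"maximum absolute difference={max_abs_diff}")
print(f"maximum percentage of elements that exceeds atol={atol} is {percentage_exceeding:.3f}%")
# Only the 0.002 difference exceeds atol, so 25% of the elements are reported.
```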
39 changes: 0 additions & 39 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -355,45 +355,6 @@ def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
self.node_name_to_graph_name[gather_v_name] = self.this_graph_name

def transpose_kv(self, past_k: str, past_v: str):
"""Transpose past_k and past_v from (B,N,P,H) to (B,P,N,H)
Args:
past_k (str): name of past K value of shape (B,N,P,H)
past_v (str): name of past V value of shape (B,N,P,H)
Returns:
past_k_transpose (str): name of past K value of shape (B,P,N,H)
past_v_transpose (str): name of past V value of shape (B,P,N,H)
"""
past_k_transpose = (past_k + "_transposed").replace(".", "_")
past_v_transpose = (past_v + "_transposed").replace(".", "_")
transpose_k_name = self.model.create_node_name("Transpose")
transpose_v_name = self.model.create_node_name("Transpose")

transpose_k = helper.make_node(
"Transpose",
inputs=[past_k],
outputs=[past_k_transpose],
name=transpose_k_name,
perm=[0, 2, 1, 3],
)
transpose_v = helper.make_node(
"Transpose",
inputs=[past_v],
outputs=[past_v_transpose],
name=transpose_v_name,
perm=[0, 2, 1, 3],
)

# Add reshape nodes to graph
self.nodes_to_add.append(transpose_k)
self.nodes_to_add.append(transpose_v)
self.node_name_to_graph_name[transpose_k_name] = self.this_graph_name
self.node_name_to_graph_name[transpose_v_name] = self.this_graph_name

return past_k_transpose, past_v_transpose

def create_combined_qkv_bias(
self,
q_add: NodeProto,
122 changes: 122 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_fastgelu.py
@@ -26,6 +26,9 @@ def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node):
return

if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node):
return

def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
"""
Fuse Gelu with tanh into one node:
@@ -358,3 +361,122 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True

def fuse_4(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
"""
This pattern is from stable diffusion 3.5 model.
Fuse Gelu with tanh into one node:
+-----------------+------------------+
| | |
| v v
[root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul -->
| (A=0.0447) (A=0.7978) (A=1) ^ (A=0.5)
| |
+-------------------------------------------------------------------------+
Note that constant input for Add and Mul could be first or second input.
"""
if tanh_node.output[0] not in input_name_to_nodes:
return

children = input_name_to_nodes[tanh_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return
add_after_tanh = children[0]

if not self.model.has_constant_input(add_after_tanh, 1.0):
return

if add_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[add_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_after_tanh = children[0]

if mul_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[mul_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_half = children[0]
if not self.model.has_constant_input(mul_half, 0.5):
return

root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1]

mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
if mul_before_tanh is None:
return

i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
if i < 0:
return

add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
if add_before_tanh is None:
return

if add_before_tanh.input[0] == root_input:
another = 1
elif add_before_tanh.input[1] == root_input:
another = 0
else:
return

mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node)
if mul_after_pow is None:
return

i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
if i < 0:
return

mul = self.model.match_parent(mul_after_pow, "Mul", 0 if i == 1 else 1, output_name_to_node)
if mul is None:
return

if mul.input[0] == root_input:
another = 1
elif mul.input[1] == root_input:
another = 0
else:
return

mul2 = self.model.match_parent(mul, "Mul", another, output_name_to_node)
if mul2 is None:
return

if mul2.input[0] != root_input or mul2.input[1] != root_input:
return

subgraph_nodes = [
mul2,
mul,
mul_after_pow,
add_before_tanh,
mul_before_tanh,
tanh_node,
add_after_tanh,
mul_after_tanh,
mul_half,
]

if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
[mul_half.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
"FastGelu",
inputs=[root_input],
outputs=mul_half.output,
name=self.model.create_node_name("FastGelu"),
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True
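
The magic numbers matched by `fuse_4` come from the tanh approximation of GELU: 0.7978 is sqrt(2/pi) and 0.0447 is the cubic coefficient 0.044715, both within the `delta=0.0001` tolerance used above. A small numeric reference (illustrative, not part of the change):

```
import math

import numpy as np

def gelu_tanh(x):
    # FastGelu computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))),
    # which is the Mul/Add/Tanh subgraph that fuse_4 collapses into one node.
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * np.power(x, 3))))

x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
print(gelu_tanh(x))
```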
4 changes: 3 additions & 1 deletion onnxruntime/python/tools/transformers/fusion_group_norm.py
@@ -84,6 +84,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
if instance_norm_scale is None or len(instance_norm_scale.shape) != 1:
return
num_groups = int(instance_norm_scale.shape[0])

instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape:
@@ -156,7 +157,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
)

new_node.attribute.extend(instance_norm.attribute)
new_node.attribute.extend([helper.make_attribute("groups", 32)])

new_node.attribute.extend([helper.make_attribute("groups", num_groups)])
new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)])

if not self.channels_last:
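Background for the `num_groups` change (a hedged note, not taken from the diff itself): when GroupNorm appears in the exported graph as Reshape + InstanceNormalization, the InstanceNorm scale has one entry per group, so `instance_norm_scale.shape[0]` gives the group count directly; the previously hardcoded `32` only matched UNet-style models. A minimal numpy sketch of that decomposition:

```
import numpy as np

def group_norm_via_instance_norm(x, num_groups, eps=1e-5):
    """GroupNorm over (N, C, H, W) computed as InstanceNorm over (N, num_groups, C//num_groups * H * W)."""
    n, c, h, w = x.shape
    grouped = x.reshape(n, num_groups, -1)
    # The InstanceNormalization node in this pattern carries scale/bias of length num_groups,
    # which is why the fusion can read num_groups from instance_norm_scale.shape[0].
    mean = grouped.mean(axis=-1, keepdims=True)
    var = grouped.var(axis=-1, keepdims=True)
    return ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

x = np.random.randn(1, 64, 8, 8).astype(np.float32)
print(group_norm_via_instance_norm(x, num_groups=32).shape)  # (1, 64, 8, 8)
```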
