From 05660c00ccd2d60df8912991240c8a49f26564c0 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 17 Jun 2024 11:57:24 -0500 Subject: [PATCH 1/6] Add scheduler_id to sd3 api for unified signature --- .../turbine_models/custom_models/sd3_inference/sd3_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py index 686e2b453..5a80b2633 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py @@ -44,7 +44,6 @@ class SharkSD3Pipeline: def __init__( self, hf_model_name: str, - # scheduler_id: str, height: int, width: int, shift: float, @@ -63,6 +62,7 @@ def __init__( vae_decomp_attn: bool = True, custom_vae: str = "", cpu_scheduling: bool = False, + scheduler_id: str = None, #compatibility only, always uses EulerFlowScheduler ): self.hf_model_name = hf_model_name # self.scheduler_id = scheduler_id From 4692e11339a295a16e963181bbb6d2841714388e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 17 Jun 2024 18:08:06 -0500 Subject: [PATCH 2/6] Remove sentencepiece from reqs. --- models/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/models/requirements.txt b/models/requirements.txt index ead79c1d9..bdd1892e8 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -1,5 +1,4 @@ protobuf -sentencepiece shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main transformers==4.37.1 torchsde From 8b775aae35a5a0b13bc4081115327ef74747a077 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 18 Jun 2024 01:12:24 -0500 Subject: [PATCH 3/6] Temporarily comment out create_hal_driver usage for old iree version compat (DNM) --- models/turbine_models/model_runner.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/turbine_models/model_runner.py b/models/turbine_models/model_runner.py index 1b27ca83b..41dc8746e 100644 --- a/models/turbine_models/model_runner.py +++ b/models/turbine_models/model_runner.py @@ -1,7 +1,7 @@ import argparse import sys from iree import runtime as ireert -from iree.runtime._binding import create_hal_driver +#from iree.runtime._binding import create_hal_driver class vmfbRunner: @@ -11,14 +11,14 @@ def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=No # If an extra plugin is requested, add a global flag to load the plugin # and create the driver using the non-caching creation function, as # the caching creation function may ignore the flag. - if extra_plugin: - ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") - haldriver = create_hal_driver(device) + # if extra_plugin: + # ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") + # haldriver = create_hal_driver(device) # No plugin requested: create the driver with the caching create # function. 
- else: - haldriver = ireert.get_driver(device) + #else: + haldriver = ireert.get_driver(device) if "://" in device: try: device_idx = int(device.split("://")[-1]) From fd2a2ba40e2110e666f91847c1c2c5a19a4126fd Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 18 Jun 2024 16:00:32 -0500 Subject: [PATCH 4/6] Fixes for vae precision/attn decomposition, numerics validation --- .../sd3_inference/sd3_cmd_opts.py | 8 ++++- .../sd3_inference/sd3_pipeline.py | 32 +++++++++++-------- .../sd3_inference/sd3_vae_runner.py | 25 ++++++++++----- .../sd3_inference/text_encoder_impls.py | 4 ++- .../sdxl_inference/unet_runner.py | 6 ++-- models/turbine_models/model_runner.py | 12 +++---- 6 files changed, 54 insertions(+), 33 deletions(-) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py index b3250ea35..ac97d77e4 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py @@ -247,6 +247,12 @@ def is_valid_file(arg): default="fp16", help="Precision of Stable Diffusion weights and graph.", ) +p.add_argument( + "--vae_precision", + type=str, + default=None, + help="Precision of Stable Diffusion VAE weights and graph.", +) p.add_argument( "--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion" ) @@ -257,7 +263,7 @@ def is_valid_file(arg): p.add_argument( "--vae_decomp_attn", type=bool, - default=True, + default=False, help="Decompose attention for VAE decode only at fx graph level", ) p.add_argument( diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py index 5a80b2633..7f1ec7022 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py @@ -46,7 +46,6 @@ def __init__( hf_model_name: str, height: int, width: int, - shift: float, precision: str, max_length: int, batch_size: int, @@ -59,10 +58,12 @@ def __init__( pipeline_dir: str = "./shark_vmfbs", external_weights_dir: str = "./shark_weights", external_weights: str = "safetensors", - vae_decomp_attn: bool = True, - custom_vae: str = "", + vae_decomp_attn: bool = False, cpu_scheduling: bool = False, + vae_precision: str = "fp32", scheduler_id: str = None, #compatibility only, always uses EulerFlowScheduler + shift: float = 1.0, + ): self.hf_model_name = hf_model_name # self.scheduler_id = scheduler_id @@ -120,10 +121,11 @@ def __init__( self.external_weights_dir = external_weights_dir self.external_weights = external_weights self.vae_decomp_attn = vae_decomp_attn - self.custom_vae = custom_vae + self.custom_vae = None self.cpu_scheduling = cpu_scheduling self.torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16 - self.vae_dtype = torch.float32 + self.vae_precision = vae_precision if vae_precision else self.precision + self.vae_dtype = torch.float32 if vae_precision == "fp32" else torch.float16 # TODO: set this based on user-inputted guidance scale and negative prompt. 
self.do_classifier_free_guidance = True # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True @@ -206,7 +208,12 @@ def is_prepared(self, vmfbs, weights): ) if w_key == "clip": default_name = os.path.join( - self.external_weights_dir, f"sd3_clip_fp16.irpa" + self.external_weights_dir, f"sd3_text_encoders_{self.precision}.irpa" + ) + if w_key == "mmdit": + default_name = os.path.join( + self.external_weights_dir, + f"sd3_mmdit_{self.precision}." + self.external_weights, ) if weights[w_key] is None and os.path.exists(default_name): weights[w_key] = os.path.join(default_name) @@ -357,7 +364,7 @@ def export_submodel( self.batch_size, self.height, self.width, - "fp32", + self.vae_precision, "vmfb", self.external_weights, vae_external_weight_path, @@ -586,7 +593,8 @@ def generate_images( dtype=self.vae_dtype, ) else: - latents = sample.astype("float32") + vae_numpy_dtype = np.float32 if self.vae_precision == "fp32" else np.float16 + latents = sample.astype(vae_numpy_dtype) vae_start = time.time() vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents) @@ -634,7 +642,7 @@ def generate_images( out_image = Image.fromarray(image) images.extend([[out_image]]) if return_imgs: - return images + return images[0] for idx_batch, image_batch in enumerate(images): for idx, image in enumerate(image_batch): img_path = ( @@ -767,7 +775,6 @@ def run_diffusers_cpu( args.hf_model_name, args.height, args.width, - args.shift, args.precision, args.max_length, args.batch_size, @@ -779,9 +786,8 @@ def run_diffusers_cpu( args.decomp_attn, args.pipeline_dir, args.external_weights_dir, - args.external_weights, - args.vae_decomp_attn, - custom_vae=None, + external_weights=args.external_weights, + vae_decomp_attn=args.vae_decomp_attn, cpu_scheduling=args.cpu_scheduling, vae_precision=args.vae_precision, ) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py index 23db4ab73..31b23b429 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py @@ -15,8 +15,8 @@ def run_vae( ): runner = vmfbRunner(device, vmfb_path, external_weight_path) inputs = [ireert.asdevicearray(runner.config.device, example_input)] - results = runner.ctx.modules.compiled_vae["decode"](*inputs) - + results = runner.ctx.modules.compiled_vae["decode"](*inputs).to_host() + results = imagearray_from_vae_out(results) return results @@ -32,11 +32,19 @@ def run_torch_vae(hf_model_name, variant, example_input): elif variant == "encode": results = vae_model.encode(example_input) np_torch_output = results.detach().cpu().numpy() + np_torch_output = imagearray_from_vae_out(np_torch_output) return np_torch_output +def imagearray_from_vae_out(image): + if image.ndim == 4: + image = image[0] + image = torch.from_numpy(image).cpu().permute(1, 2, 0).float().numpy() + image = (image * 255).round().astype("uint8") + return image if __name__ == "__main__": from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + import numpy as np dtype = torch.float16 if args.precision == "fp16" else torch.float32 if args.vae_variant == "decode": @@ -57,9 +65,9 @@ def run_torch_vae(hf_model_name, variant, example_input): ) print( "TURBINE OUTPUT:", - turbine_results.to_host(), - turbine_results.to_host().shape, - turbine_results.to_host().dtype, + turbine_results, + turbine_results.shape, + turbine_results.dtype, ) if args.compare_vs_torch: 
print("generating torch output: ") @@ -69,9 +77,10 @@ def run_torch_vae(hf_model_name, variant, example_input): args.hf_model_name, args.vae_variant, example_input.float() ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) - err = utils.largest_error(torch_output, turbine_results) - print("Largest Error: ", err) - assert err < 2e-3 + # Allow a small amount of wiggle room for rounding errors (1) + np.testing.assert_allclose( + turbine_results, torch_output, rtol=1, atol=1 + ) # TODO: Figure out why we occasionally segfault without unlinking output variables turbine_results = None diff --git a/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py index 29b9d2f80..747b60d9b 100644 --- a/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py +++ b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py @@ -341,8 +341,10 @@ def __init__(self): self.clip_g = SDXLClipGTokenizer(clip_tokenizer) self.t5xxl = T5XXLTokenizer() - def tokenize_with_weights(self, text: str): + def tokenize_with_weights(self, text: str | list[str]): out = {} + if isinstance(text, list): + text = text[0] out["g"] = self.clip_g.tokenize_with_weights(text) out["l"] = self.clip_l.tokenize_with_weights(text) out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py index 4437b9eae..9d0b405c3 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py @@ -31,9 +31,8 @@ def run_unet( ireert.asdevicearray(runner.config.device, prompt_embeds), ireert.asdevicearray(runner.config.device, text_embeds), ireert.asdevicearray(runner.config.device, time_ids), - ireert.asdevicearray(runner.config.device, guidance_scale), ] - results = runner.ctx.modules.compiled_unet["main"](*inputs) + results = runner.ctx.modules.compiled_unet["run_forward"](*inputs) return results @@ -57,7 +56,6 @@ def run_unet_steps( ireert.asdevicearray(runner.config.device, prompt_embeds), ireert.asdevicearray(runner.config.device, text_embeds), ireert.asdevicearray(runner.config.device, time_ids), - ireert.asdevicearray(runner.config.device, (guidance_scale,)), ] for i, t in tqdm(enumerate(scheduler.timesteps)): timestep = t @@ -69,7 +67,7 @@ def run_unet_steps( inputs[1] = timestep = ireert.asdevicearray( runner.config.device, (timestep,), dtype="int64" ) - noise_pred = runner.ctx.modules.compiled_unet["main"](*inputs).to_host() + noise_pred = runner.ctx.modules.compiled_unet["run_forward"](*inputs).to_host() sample = scheduler.step( torch.from_numpy(noise_pred).cpu(), timestep, diff --git a/models/turbine_models/model_runner.py b/models/turbine_models/model_runner.py index 41dc8746e..1b27ca83b 100644 --- a/models/turbine_models/model_runner.py +++ b/models/turbine_models/model_runner.py @@ -1,7 +1,7 @@ import argparse import sys from iree import runtime as ireert -#from iree.runtime._binding import create_hal_driver +from iree.runtime._binding import create_hal_driver class vmfbRunner: @@ -11,14 +11,14 @@ def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=No # If an extra plugin is requested, add a global flag to load the plugin # and create the driver using the non-caching creation function, as # the caching creation function may ignore the flag. 
- # if extra_plugin: - # ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") - # haldriver = create_hal_driver(device) + if extra_plugin: + ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") + haldriver = create_hal_driver(device) # No plugin requested: create the driver with the caching create # function. - #else: - haldriver = ireert.get_driver(device) + else: + haldriver = ireert.get_driver(device) if "://" in device: try: device_idx = int(device.split("://")[-1]) From b1f20f1c07e6b75d358fffbfc3672f9dd6a49fc8 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 18 Jun 2024 20:28:35 -0500 Subject: [PATCH 5/6] Fix numerics, add some features to VAE runner, add cpu scheduling options --- .../sd3_inference/sd3_cmd_opts.py | 6 + .../custom_models/sd3_inference/sd3_mmdit.py | 2 +- .../sd3_inference/sd3_mmdit_runner.py | 3 +- .../sd3_inference/sd3_pipeline.py | 109 +++++++++++++----- .../sd3_inference/sd3_schedulers.py | 41 ++++++- .../custom_models/sd3_inference/sd3_vae.py | 1 + .../sd3_inference/sd3_vae_runner.py | 16 ++- 7 files changed, 141 insertions(+), 37 deletions(-) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py index ac97d77e4..55cf3b72d 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py @@ -346,6 +346,12 @@ def is_valid_file(arg): action="store_true", help="Just compile attention reproducer for mmdit.", ) +p.add_argument( + "--vae_input_path", + type=str, + default=None, + help="Path to input latents for VAE inference numerics validation.", +) ############################################################################## diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py index 9d6ea012d..05d3e00cb 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py @@ -207,7 +207,7 @@ def export_mmdit_model( torch.empty(hidden_states_shape, dtype=dtype), torch.empty(encoder_hidden_states_shape, dtype=dtype), torch.empty(pooled_projections_shape, dtype=dtype), - torch.empty(1, dtype=dtype), + torch.empty(init_batch_dim, dtype=dtype), ] decomp_list = [] diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py index a0be81192..06100eab3 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py @@ -154,7 +154,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]): (batch_size, args.max_length * 2, 4096), dtype=dtype ) pooled_projections = torch.randn((batch_size, 2048), dtype=dtype) - timestep = torch.tensor([0], dtype=dtype) + timestep = torch.tensor([0, 0], dtype=dtype) turbine_output = run_mmdit_turbine( hidden_states, @@ -180,6 +180,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]): timestep, args, ) + np.save("torch_mmdit_output.npy", torch_output.astype(np.float16)) print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) print("\n(torch (comfy) image latents to iree image latents): ") diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py 
b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py index 7f1ec7022..303ba326e 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py @@ -17,6 +17,7 @@ from turbine_models.custom_models.sd_inference import utils from turbine_models.model_runner import vmfbRunner from transformers import CLIPTokenizer +from diffusers import FlowMatchEulerDiscreteScheduler from PIL import Image import os @@ -426,10 +427,16 @@ def load_pipeline( unet_loaded = time.time() print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec") - runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper( - self.devices["mmdit"]["driver"], - vmfbs["scheduler"], - ) + if not self.cpu_scheduling: + runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper( + self.devices["mmdit"]["driver"], + vmfbs["scheduler"], + ) + else: + print("Using torch CPU scheduler.") + runners["scheduler"] = FlowMatchEulerDiscreteScheduler.from_pretrained( + self.hf_model_name, subfolder="scheduler" + ) sched_loaded = time.time() print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec") @@ -502,11 +509,12 @@ def generate_images( ) ) - guidance_scale = ireert.asdevicearray( - self.runners["pipe"].config.device, - np.asarray([guidance_scale]), - dtype=iree_dtype, - ) + if not self.cpu_scheduling: + guidance_scale = ireert.asdevicearray( + self.runners["pipe"].config.device, + np.asarray([guidance_scale]), + dtype=iree_dtype, + ) tokenize_start = time.time() text_input_ids_dict = self.tokenizer.tokenize_with_weights(prompt) @@ -540,12 +548,23 @@ def generate_images( "clip" ].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs) encode_prompts_end = time.time() + if self.cpu_scheduling: + timesteps, num_inference_steps = sd3_schedulers.retrieve_timesteps( + self.runners["scheduler"], + num_inference_steps=self.num_inference_steps, + timesteps=None, + ) + steps = num_inference_steps + for i in range(batch_count): unet_start = time.time() - sample, steps, timesteps = self.runners["scheduler"].initialize(samples[i]) + if not self.cpu_scheduling: + latents, steps, timesteps = self.runners["scheduler"].initialize(samples[i]) + else: + latents = torch.tensor(samples[i].to_host(), dtype=self.torch_dtype) iree_inputs = [ - sample, + latents, ireert.asdevicearray( self.runners["pipe"].config.device, prompt_embeds, dtype=iree_dtype ), @@ -560,41 +579,71 @@ def generate_images( # print(f"step {s}") if self.cpu_scheduling: step_index = s + t = timesteps[s] + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + timestep = ireert.asdevicearray( + self.runners["pipe"].config.device, + t.expand(latent_model_input.shape[0]), + dtype=iree_dtype, + ) + latent_model_input = ireert.asdevicearray( + self.runners["pipe"].config.device, + latent_model_input, + dtype=iree_dtype, + ) else: step_index = ireert.asdevicearray( self.runners["scheduler"].runner.config.device, torch.tensor([s]), "int64", ) - latents, t = self.runners["scheduler"].prep( - sample, - step_index, - timesteps, - ) + latent_model_input, timestep = self.runners["scheduler"].prep( + latents, + step_index, + timesteps, + ) + t = ireert.asdevicearray( + self.runners["scheduler"].runner.config.device, + timestep.to_host()[0] + ) noise_pred = self.runners["pipe"].ctx.modules.compiled_mmdit[ "run_forward" ]( - latents, + latent_model_input, iree_inputs[1], iree_inputs[2], - t, - ) - sample = self.runners["scheduler"].step( - 
noise_pred, - t, - sample, - guidance_scale, - step_index, + timestep, ) - if isinstance(sample, torch.Tensor): + if not self.cpu_scheduling: + latents = self.runners["scheduler"].step( + noise_pred, + t, + latents, + guidance_scale, + step_index, + ) + else: + noise_pred = torch.tensor(noise_pred.to_host(), dtype=self.torch_dtype) + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + latents = self.runners["scheduler"].step( + noise_pred, + t, + latents, + return_dict=False, + )[0] + + if isinstance(latents, torch.Tensor): + latents = latents.type(self.vae_dtype) latents = ireert.asdevicearray( self.runners["vae"].config.device, - sample, - dtype=self.vae_dtype, + latents, ) else: vae_numpy_dtype = np.float32 if self.vae_precision == "fp32" else np.float16 - latents = sample.astype(vae_numpy_dtype) + latents = latents.astype(vae_numpy_dtype) vae_start = time.time() vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents) @@ -791,10 +840,10 @@ def run_diffusers_cpu( cpu_scheduling=args.cpu_scheduling, vae_precision=args.vae_precision, ) - vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights) if args.cpu_scheduling: vmfbs.pop("scheduler") weights.pop("scheduler") + vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights) if args.npu_delegate_path: extra_device_args = {"npu_delegate_path": args.npu_delegate_path} else: diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py index 86179746a..0fe4ae0d8 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -5,9 +5,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import os +import inspect from typing import List import torch +from typing import Any, Callable, Dict, List, Optional, Union from shark_turbine.aot import * import shark_turbine.ops.iree as ops from iree.compiler.ir import Context @@ -75,11 +77,12 @@ def initialize(self, sample): def prepare_model_input(self, sample, t, timesteps): t = timesteps[t] - t = t.expand(sample.shape[0]) + if self.do_classifier_free_guidance: latent_model_input = torch.cat([sample] * 2) else: latent_model_input = sample + t = t.expand(sample.shape[0]) return latent_model_input.type(self.dtype), t.type(self.dtype) def step(self, noise_pred, t, sample, guidance_scale, i): @@ -146,6 +149,42 @@ def step(self, noise_pred, t, latents, guidance_scale, i): return_dict=False, )[0] +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +# Only used for cpu scheduling. +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps @torch.no_grad() def export_scheduler_model( diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py index a70c19882..5bd6f0f5b 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py @@ -33,6 +33,7 @@ def __init__( ) def decode(self, inp): + inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(inp, return_dict=False)[0] image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py index 31b23b429..9cb435bde 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py @@ -45,12 +45,17 @@ def imagearray_from_vae_out(image): if __name__ == "__main__": from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args import numpy as np + from PIL import Image dtype = torch.float16 if args.precision == "fp16" else torch.float32 if args.vae_variant == "decode": example_input = torch.rand( args.batch_size, 16, args.height // 8, args.width // 8, dtype=dtype ) + if args.vae_input_path: + example_input = np.load(args.vae_input_path) + if example_input.shape[0] == 2: + example_input = np.split(example_input, 2)[0] elif args.vae_variant == "encode": example_input = torch.rand( args.batch_size, 3, args.height, args.width, dtype=dtype @@ -74,13 +79,16 @@ def imagearray_from_vae_out(image): from turbine_models.custom_models.sd_inference import utils torch_output = run_torch_vae( - args.hf_model_name, args.vae_variant, example_input.float() + args.hf_model_name, args.vae_variant, torch.tensor(example_input).float() ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + if args.vae_input_path: + out_image_torch = Image.fromarray(torch_output) + out_image_torch.save("vae_test_output_torch.png") + out_image_turbine = Image.fromarray(turbine_results) + out_image_turbine.save("vae_test_output_turbine.png") # Allow a small amount of wiggle room for rounding errors (1) + np.testing.assert_allclose( turbine_results, torch_output, rtol=1, atol=1 ) - - # TODO: Figure out why we occasionally segfault without unlinking output variables - turbine_results = None From 618d01f9b725d1c60e0c6da4302bf0976792ea3c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 18 Jun 2024 21:12:41 -0500 Subject: [PATCH 6/6] Point to azure links for specs and fix timesteps dim in gpu scheduler. 
--- .../custom_models/sd3_inference/sd3_schedulers.py | 2 +- models/turbine_models/custom_models/sd_inference/utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py index 0fe4ae0d8..2efb13aa9 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py @@ -82,7 +82,7 @@ def prepare_model_input(self, sample, t, timesteps): latent_model_input = torch.cat([sample] * 2) else: latent_model_input = sample - t = t.expand(sample.shape[0]) + t = t.expand(latent_model_input.shape[0]) return latent_model_input.type(self.dtype), t.type(self.dtype) def step(self, noise_pred, t, sample, guidance_scale, i): diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index e4b755131..0931a4028 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -35,7 +35,7 @@ "--iree-codegen-gpu-native-math-precision=true", "--iree-rocm-waves-per-eu=2", "--iree-flow-inline-constants-max-byte-length=1", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))", ], "unet": [ "--iree-flow-enable-aggressive-fusion", @@ -275,7 +275,7 @@ def create_safe_name(hf_model_name, model_name_str): def get_mfma_spec_path(target_chip, save_dir): - url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir" + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir" attn_spec = urlopen(url).read().decode("utf-8") spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir") if os.path.exists(spec_path): @@ -287,9 +287,9 @@ def get_mfma_spec_path(target_chip, save_dir): def get_wmma_spec_path(target_chip, save_dir): if target_chip == "gfx1100": - url = "https://github.com/iree-org/iree/raw/shared/tresleches-united/scripts/attention_gfx1100.spec.mlir" + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1100.mlir" elif target_chip in ["gfx1103", "gfx1150"]: - url = "https://github.com/iree-org/iree/raw/shared/tresleches-united/scripts/attention_gfx1103.spec.mlir" + url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1150.mlir" else: return None attn_spec = urlopen(url).read().decode("utf-8")
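
Reviewer note (not part of the applied series): PATCH 5/6 adds a CPU-scheduling path that doubles the latents for classifier-free guidance, runs the compiled MMDiT once on the batched input, splits the prediction, and steps a torch FlowMatchEulerDiscreteScheduler. A minimal standalone sketch of that loop is below for reference. The `run_mmdit` stub, the `model_id` default, and all shapes are illustrative assumptions only; the scheduler calls (`set_timesteps`, `timesteps`, `step(..., return_dict=False)[0]`) and the guidance arithmetic mirror what the patch adds to sd3_pipeline.py.

    # Sketch of the CPU-scheduling denoising loop introduced in PATCH 5/6.
    # Assumption: run_mmdit stands in for the compiled MMDiT invocation
    # (runners["pipe"].ctx.modules.compiled_mmdit["run_forward"]).
    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler

    def run_mmdit(latent_model_input, prompt_embeds, pooled_projections, timestep):
        # Hypothetical stand-in: returns a noise prediction with the latents' shape.
        return torch.randn_like(latent_model_input)

    def cpu_denoise(
        latents,
        prompt_embeds,
        pooled_projections,
        num_inference_steps=28,
        guidance_scale=5.0,
        model_id="stabilityai/stable-diffusion-3-medium-diffusers",  # placeholder
    ):
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        scheduler.set_timesteps(num_inference_steps)
        for t in scheduler.timesteps:
            # Classifier-free guidance: run conditional and unconditional branches
            # in one batched forward pass, as the patched pipeline does.
            latent_model_input = torch.cat([latents] * 2)
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = run_mmdit(
                latent_model_input, prompt_embeds, pooled_projections, timestep
            )
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            # The flow-match Euler step advances the latents toward the image.
            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        return latents

The GPU path keeps the existing compiled SharkSchedulerWrapper; only when cpu_scheduling is set does the pipeline fall back to this torch loop and skip building the scheduler vmfb.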