diff --git a/models/requirements.txt b/models/requirements.txt
index ead79c1d9..bdd1892e8 100644
--- a/models/requirements.txt
+++ b/models/requirements.txt
@@ -1,5 +1,4 @@
 protobuf
-sentencepiece
 shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
 transformers==4.37.1
 torchsde
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py
index b3250ea35..55cf3b72d 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py
@@ -247,6 +247,12 @@ def is_valid_file(arg):
     default="fp16",
     help="Precision of Stable Diffusion weights and graph.",
 )
+p.add_argument(
+    "--vae_precision",
+    type=str,
+    default=None,
+    help="Precision of Stable Diffusion VAE weights and graph.",
+)
 p.add_argument(
     "--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion"
 )
@@ -257,7 +263,7 @@ def is_valid_file(arg):
 p.add_argument(
     "--vae_decomp_attn",
     type=bool,
-    default=True,
+    default=False,
     help="Decompose attention for VAE decode only at fx graph level",
 )
 p.add_argument(
@@ -340,6 +346,12 @@ def is_valid_file(arg):
     action="store_true",
     help="Just compile attention reproducer for mmdit.",
 )
+p.add_argument(
+    "--vae_input_path",
+    type=str,
+    default=None,
+    help="Path to input latents for VAE inference numerics validation.",
+)


 ##############################################################################
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py
index 9d6ea012d..05d3e00cb 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py
@@ -207,7 +207,7 @@ def export_mmdit_model(
         torch.empty(hidden_states_shape, dtype=dtype),
         torch.empty(encoder_hidden_states_shape, dtype=dtype),
         torch.empty(pooled_projections_shape, dtype=dtype),
-        torch.empty(1, dtype=dtype),
+        torch.empty(init_batch_dim, dtype=dtype),
     ]

     decomp_list = []
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
index a0be81192..06100eab3 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
@@ -154,7 +154,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
         (batch_size, args.max_length * 2, 4096), dtype=dtype
     )
     pooled_projections = torch.randn((batch_size, 2048), dtype=dtype)
-    timestep = torch.tensor([0], dtype=dtype)
+    timestep = torch.tensor([0, 0], dtype=dtype)

     turbine_output = run_mmdit_turbine(
         hidden_states,
@@ -180,6 +180,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
         timestep,
         args,
     )
+    np.save("torch_mmdit_output.npy", torch_output.astype(np.float16))
     print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)

     print("\n(torch (comfy) image latents to iree image latents): ")
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py
index 686e2b453..303ba326e 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py
@@ -17,6 +17,7 @@
 from turbine_models.custom_models.sd_inference import utils
 from turbine_models.model_runner import vmfbRunner
 from transformers import CLIPTokenizer
+from diffusers import FlowMatchEulerDiscreteScheduler
 from PIL import Image
 import os

@@ -44,10 +45,8 @@ class SharkSD3Pipeline:
     def __init__(
         self,
         hf_model_name: str,
-        # scheduler_id: str,
         height: int,
         width: int,
-        shift: float,
         precision: str,
         max_length: int,
         batch_size: int,
@@ -60,9 +59,12 @@ def __init__(
         pipeline_dir: str = "./shark_vmfbs",
         external_weights_dir: str = "./shark_weights",
         external_weights: str = "safetensors",
-        vae_decomp_attn: bool = True,
-        custom_vae: str = "",
+        vae_decomp_attn: bool = False,
         cpu_scheduling: bool = False,
+        vae_precision: str = "fp32",
+        scheduler_id: str = None,  # compatibility only, always uses EulerFlowScheduler
+        shift: float = 1.0,
+
     ):
         self.hf_model_name = hf_model_name
         # self.scheduler_id = scheduler_id
@@ -120,10 +122,11 @@ def __init__(
         self.external_weights_dir = external_weights_dir
         self.external_weights = external_weights
         self.vae_decomp_attn = vae_decomp_attn
-        self.custom_vae = custom_vae
+        self.custom_vae = None
         self.cpu_scheduling = cpu_scheduling
         self.torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16
-        self.vae_dtype = torch.float32
+        self.vae_precision = vae_precision if vae_precision else self.precision
+        self.vae_dtype = torch.float32 if vae_precision == "fp32" else torch.float16
         # TODO: set this based on user-inputted guidance scale and negative prompt.
         self.do_classifier_free_guidance = True
         # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True
@@ -206,7 +209,12 @@ def is_prepared(self, vmfbs, weights):
             )
             if w_key == "clip":
                 default_name = os.path.join(
-                    self.external_weights_dir, f"sd3_clip_fp16.irpa"
+                    self.external_weights_dir, f"sd3_text_encoders_{self.precision}.irpa"
                 )
+            if w_key == "mmdit":
+                default_name = os.path.join(
+                    self.external_weights_dir,
+                    f"sd3_mmdit_{self.precision}." + self.external_weights,
+                )
             if weights[w_key] is None and os.path.exists(default_name):
                 weights[w_key] = os.path.join(default_name)
@@ -357,7 +365,7 @@ def export_submodel(
                 self.batch_size,
                 self.height,
                 self.width,
-                "fp32",
+                self.vae_precision,
                 "vmfb",
                 self.external_weights,
                 vae_external_weight_path,
@@ -419,10 +427,16 @@ def load_pipeline(
         unet_loaded = time.time()
         print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec")

-        runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
-            self.devices["mmdit"]["driver"],
-            vmfbs["scheduler"],
-        )
+        if not self.cpu_scheduling:
+            runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
+                self.devices["mmdit"]["driver"],
+                vmfbs["scheduler"],
+            )
+        else:
+            print("Using torch CPU scheduler.")
+            runners["scheduler"] = FlowMatchEulerDiscreteScheduler.from_pretrained(
+                self.hf_model_name, subfolder="scheduler"
+            )

         sched_loaded = time.time()
         print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec")
@@ -495,11 +509,12 @@ def generate_images(
                 )
             )

-        guidance_scale = ireert.asdevicearray(
-            self.runners["pipe"].config.device,
-            np.asarray([guidance_scale]),
-            dtype=iree_dtype,
-        )
+        if not self.cpu_scheduling:
+            guidance_scale = ireert.asdevicearray(
+                self.runners["pipe"].config.device,
+                np.asarray([guidance_scale]),
+                dtype=iree_dtype,
+            )

         tokenize_start = time.time()
         text_input_ids_dict = self.tokenizer.tokenize_with_weights(prompt)
@@ -533,12 +548,23 @@ def generate_images(
             "clip"
         ].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
         encode_prompts_end = time.time()

+        if self.cpu_scheduling:
+            timesteps, num_inference_steps = sd3_schedulers.retrieve_timesteps(
+                self.runners["scheduler"],
+                num_inference_steps=self.num_inference_steps,
+                timesteps=None,
+            )
+            steps = num_inference_steps
+
         for i in range(batch_count):
             unet_start = time.time()
-            sample, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
+            if not self.cpu_scheduling:
+                latents, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
+            else:
+                latents = torch.tensor(samples[i].to_host(), dtype=self.torch_dtype)
             iree_inputs = [
-                sample,
+                latents,
                 ireert.asdevicearray(
                     self.runners["pipe"].config.device, prompt_embeds, dtype=iree_dtype
                 ),
@@ -553,40 +579,71 @@ def generate_images(
                 # print(f"step {s}")
                 if self.cpu_scheduling:
                     step_index = s
+                    t = timesteps[s]
+                    if self.do_classifier_free_guidance:
+                        latent_model_input = torch.cat([latents] * 2)
+                    timestep = ireert.asdevicearray(
+                        self.runners["pipe"].config.device,
+                        t.expand(latent_model_input.shape[0]),
+                        dtype=iree_dtype,
+                    )
+                    latent_model_input = ireert.asdevicearray(
+                        self.runners["pipe"].config.device,
+                        latent_model_input,
+                        dtype=iree_dtype,
+                    )
                 else:
                     step_index = ireert.asdevicearray(
                         self.runners["scheduler"].runner.config.device,
                         torch.tensor([s]),
                         "int64",
                     )
-                latents, t = self.runners["scheduler"].prep(
-                    sample,
-                    step_index,
-                    timesteps,
-                )
+                    latent_model_input, timestep = self.runners["scheduler"].prep(
+                        latents,
+                        step_index,
+                        timesteps,
+                    )
+                    t = ireert.asdevicearray(
+                        self.runners["scheduler"].runner.config.device,
+                        timestep.to_host()[0]
+                    )
                 noise_pred = self.runners["pipe"].ctx.modules.compiled_mmdit[
                     "run_forward"
                 ](
-                    latents,
+                    latent_model_input,
                     iree_inputs[1],
                     iree_inputs[2],
-                    t,
+                    timestep,
                 )
-                sample = self.runners["scheduler"].step(
-                    noise_pred,
-                    t,
-                    sample,
-                    guidance_scale,
-                    step_index,
-                )
-                if isinstance(sample, torch.Tensor):
+                if not self.cpu_scheduling:
+                    latents = self.runners["scheduler"].step(
+                        noise_pred,
+                        t,
+                        latents,
+                        guidance_scale,
+                        step_index,
+                    )
+                else:
+                    noise_pred = torch.tensor(noise_pred.to_host(), dtype=self.torch_dtype)
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    latents = self.runners["scheduler"].step(
+                        noise_pred,
+                        t,
+                        latents,
+                        return_dict=False,
+                    )[0]
+
+                if isinstance(latents, torch.Tensor):
+                    latents = latents.type(self.vae_dtype)
                     latents = ireert.asdevicearray(
                         self.runners["vae"].config.device,
-                        sample,
-                        dtype=self.vae_dtype,
+                        latents,
                     )
                 else:
-                    latents = sample.astype("float32")
+                    vae_numpy_dtype = np.float32 if self.vae_precision == "fp32" else np.float16
+                    latents = latents.astype(vae_numpy_dtype)

             vae_start = time.time()
             vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents)
@@ -634,7 +691,7 @@ def generate_images(
             out_image = Image.fromarray(image)
             images.extend([[out_image]])
         if return_imgs:
-            return images
+            return images[0]
         for idx_batch, image_batch in enumerate(images):
             for idx, image in enumerate(image_batch):
                 img_path = (
@@ -767,7 +824,6 @@ def run_diffusers_cpu(
         args.hf_model_name,
         args.height,
         args.width,
-        args.shift,
         args.precision,
         args.max_length,
         args.batch_size,
@@ -779,16 +835,15 @@ def run_diffusers_cpu(
         args.decomp_attn,
         args.pipeline_dir,
         args.external_weights_dir,
-        args.external_weights,
-        args.vae_decomp_attn,
-        custom_vae=None,
+        external_weights=args.external_weights,
+        vae_decomp_attn=args.vae_decomp_attn,
         cpu_scheduling=args.cpu_scheduling,
         vae_precision=args.vae_precision,
     )
-    vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
     if args.cpu_scheduling:
         vmfbs.pop("scheduler")
         weights.pop("scheduler")
+    vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
     if args.npu_delegate_path:
         extra_device_args = {"npu_delegate_path": args.npu_delegate_path}
     else:
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py
index 86179746a..2efb13aa9 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py
@@ -5,9 +5,11 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import os
+import inspect
 from typing import List
 import torch
+from typing import Any, Callable, Dict, List, Optional, Union
 from shark_turbine.aot import *
 import shark_turbine.ops.iree as ops
 from iree.compiler.ir import Context
@@ -75,11 +77,12 @@ def initialize(self, sample):

     def prepare_model_input(self, sample, t, timesteps):
         t = timesteps[t]
-        t = t.expand(sample.shape[0])
+
         if self.do_classifier_free_guidance:
             latent_model_input = torch.cat([sample] * 2)
         else:
             latent_model_input = sample
+        t = t.expand(latent_model_input.shape[0])
         return latent_model_input.type(self.dtype), t.type(self.dtype)

     def step(self, noise_pred, t, sample, guidance_scale, i):
@@ -146,6 +149,42 @@ def step(self, noise_pred, t, latents, guidance_scale, i):
             return_dict=False,
         )[0]
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+# Only used for cpu scheduling.
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps


 @torch.no_grad()
 def export_scheduler_model(
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py
index a70c19882..5bd6f0f5b 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_vae.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae.py
@@ -33,6 +33,7 @@ def __init__(
         )

     def decode(self, inp):
+        inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor
         image = self.vae.decode(inp, return_dict=False)[0]
         image = image.float()
         image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py
index 23db4ab73..9cb435bde 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py
@@ -15,8 +15,8 @@ def run_vae(
 ):
     runner = vmfbRunner(device, vmfb_path, external_weight_path)
     inputs = [ireert.asdevicearray(runner.config.device, example_input)]
-    results = runner.ctx.modules.compiled_vae["decode"](*inputs)
-
+    results = runner.ctx.modules.compiled_vae["decode"](*inputs).to_host()
+    results = imagearray_from_vae_out(results)
     return results

@@ -32,17 +32,30 @@ def run_torch_vae(hf_model_name, variant, example_input):
     elif variant == "encode":
         results = vae_model.encode(example_input)
     np_torch_output = results.detach().cpu().numpy()
+    np_torch_output = imagearray_from_vae_out(np_torch_output)
     return np_torch_output

+def imagearray_from_vae_out(image):
+    if image.ndim == 4:
+        image = image[0]
+    image = torch.from_numpy(image).cpu().permute(1, 2, 0).float().numpy()
+    image = (image * 255).round().astype("uint8")
+    return image

 if __name__ == "__main__":
     from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
+    import numpy as np
+    from PIL import Image

     dtype = torch.float16 if args.precision == "fp16" else torch.float32
     if args.vae_variant == "decode":
         example_input = torch.rand(
             args.batch_size, 16, args.height // 8, args.width // 8, dtype=dtype
         )
+        if args.vae_input_path:
+            example_input = np.load(args.vae_input_path)
+            if example_input.shape[0] == 2:
+                example_input = np.split(example_input, 2)[0]
     elif args.vae_variant == "encode":
         example_input = torch.rand(
             args.batch_size, 3, args.height, args.width, dtype=dtype
@@ -57,21 +70,25 @@ def run_torch_vae(hf_model_name, variant, example_input):
     )
     print(
         "TURBINE OUTPUT:",
-        turbine_results.to_host(),
-        turbine_results.to_host().shape,
-        turbine_results.to_host().dtype,
+        turbine_results,
+        turbine_results.shape,
+        turbine_results.dtype,
     )
     if args.compare_vs_torch:
         print("generating torch output: ")
         from turbine_models.custom_models.sd_inference import utils

         torch_output = run_torch_vae(
-            args.hf_model_name, args.vae_variant, example_input.float()
+            args.hf_model_name, args.vae_variant, torch.tensor(example_input).float()
         )
         print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
-        err = utils.largest_error(torch_output, turbine_results)
-        print("Largest Error: ", err)
-        assert err < 2e-3
+        if args.vae_input_path:
+            out_image_torch = Image.fromarray(torch_output)
+            out_image_torch.save("vae_test_output_torch.png")
+            out_image_turbine = Image.fromarray(turbine_results)
+            out_image_turbine.save("vae_test_output_turbine.png")
+        # Allow a small amount of wiggle room for rounding errors (1)

-    # TODO: Figure out why we occasionally segfault without unlinking output variables
-    turbine_results = None
+    np.testing.assert_allclose(
+        turbine_results, torch_output, rtol=1, atol=1
+    )
diff --git a/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py
index 29b9d2f80..747b60d9b 100644
--- a/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py
+++ b/models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py
@@ -341,8 +341,10 @@ def __init__(self):
         self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
         self.t5xxl = T5XXLTokenizer()

-    def tokenize_with_weights(self, text: str):
+    def tokenize_with_weights(self, text: str | list[str]):
         out = {}
+        if isinstance(text, list):
+            text = text[0]
         out["g"] = self.clip_g.tokenize_with_weights(text)
         out["l"] = self.clip_l.tokenize_with_weights(text)
         out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index e4b755131..0931a4028 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -35,7 +35,7 @@
         "--iree-codegen-gpu-native-math-precision=true",
         "--iree-rocm-waves-per-eu=2",
         "--iree-flow-inline-constants-max-byte-length=1",
-        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))",
     ],
     "unet": [
         "--iree-flow-enable-aggressive-fusion",
@@ -275,7 +275,7 @@ def create_safe_name(hf_model_name, model_name_str):


 def get_mfma_spec_path(target_chip, save_dir):
-    url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
+    url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"
     attn_spec = urlopen(url).read().decode("utf-8")
     spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
     if os.path.exists(spec_path):
@@ -287,9 +287,9 @@ def get_mfma_spec_path(target_chip, save_dir):

 def get_wmma_spec_path(target_chip, save_dir):
     if target_chip == "gfx1100":
-        url = "https://github.com/iree-org/iree/raw/shared/tresleches-united/scripts/attention_gfx1100.spec.mlir"
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1100.mlir"
     elif target_chip in ["gfx1103", "gfx1150"]:
-        url = "https://github.com/iree-org/iree/raw/shared/tresleches-united/scripts/attention_gfx1103.spec.mlir"
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1150.mlir"
     else:
         return None
     attn_spec = urlopen(url).read().decode("utf-8")
diff --git a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py
index 4437b9eae..9d0b405c3 100644
--- a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py
+++ b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py
@@ -31,9 +31,8 @@ def run_unet(
         ireert.asdevicearray(runner.config.device, prompt_embeds),
         ireert.asdevicearray(runner.config.device, text_embeds),
         ireert.asdevicearray(runner.config.device, time_ids),
-        ireert.asdevicearray(runner.config.device, guidance_scale),
     ]
-    results = runner.ctx.modules.compiled_unet["main"](*inputs)
+    results = runner.ctx.modules.compiled_unet["run_forward"](*inputs)
     return results

@@ -57,7 +56,6 @@ def run_unet_steps(
         ireert.asdevicearray(runner.config.device, prompt_embeds),
         ireert.asdevicearray(runner.config.device, text_embeds),
         ireert.asdevicearray(runner.config.device, time_ids),
-        ireert.asdevicearray(runner.config.device, (guidance_scale,)),
     ]
     for i, t in tqdm(enumerate(scheduler.timesteps)):
         timestep = t
@@ -69,7 +67,7 @@ def run_unet_steps(
         inputs[1] = timestep = ireert.asdevicearray(
             runner.config.device, (timestep,), dtype="int64"
         )
-        noise_pred = runner.ctx.modules.compiled_unet["main"](*inputs).to_host()
+        noise_pred = runner.ctx.modules.compiled_unet["run_forward"](*inputs).to_host()
         sample = scheduler.step(
             torch.from_numpy(noise_pred).cpu(),
             timestep,