diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
index 50a5fb285..f547b59bb 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
@@ -110,15 +110,17 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
     if args.precision == "fp16":
         dtype = torch.float16
+        np_dtype = np.float16
     else:
         dtype = torch.float32
+        np_dtype = np.float32
     if args.attn_repro:
         qkv_shape = (2, 24, 4250, 64)
         example_qkv = [
-            np.load("q.npy").astype(np.float16),
-            np.load("k.npy").astype(np.float16),
-            np.load("v.npy").astype(np.float16),
+            np.load("q.npy").astype(np_dtype),
+            np.load("k.npy").astype(np_dtype),
+            np.load("v.npy").astype(np_dtype),
         ]
         turbine_output = run_attn_turbine(
             *example_qkv,
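Note: a minimal sketch (not part of the patch) of the torch/numpy dtype pairing the runner change above introduces, assuming only fp16 and fp32 precisions are supported; DTYPE_MAP and its use are hypothetical:

    import numpy as np
    import torch

    # Hypothetical helper: pair each --precision value with matched
    # torch/numpy dtypes so repro inputs are cast consistently.
    DTYPE_MAP = {
        "fp16": (torch.float16, np.float16),
        "fp32": (torch.float32, np.float32),
    }

    dtype, np_dtype = DTYPE_MAP["fp16"]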
+ self.devices["clip"] = { + "device": device["clip"], + "target": iree_target_triple["clip"] + } + self.devices["mmdit"] = { + "device": device["mmdit"], + "target": iree_target_triple["mmdit"] + } + self.devices["vae"] = { + "device": device["vae"], + "target": iree_target_triple["vae"] + } + else: + self.devices["clip"] = device + self.devices["mmdit"] = device + self.devices["vae"] = device self.iree_target_triple = iree_target_triple self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS self.attn_spec = attn_spec @@ -291,8 +309,8 @@ def export_submodel( "vmfb", self.external_weights, mmdit_external_weight_path, - self.device, - self.iree_target_triple, + self.devices["mmdit"]["device"], + self.devices["mmdit"]["target"], self.ireec_flags["mmdit"], self.decomp_attn, exit_on_vmfb=False, @@ -313,8 +331,8 @@ def export_submodel( self.num_inference_steps, self.precision, "vmfb", - self.device, - self.iree_target_triple, + self.devices["mmdit"]["device"], + self.devices["mmdit"]["target"], self.ireec_flags["scheduler"], exit_on_vmfb=False, pipeline_dir=self.pipeline_dir, @@ -336,8 +354,8 @@ def export_submodel( "vmfb", self.external_weights, vae_external_weight_path, - self.device, - self.iree_target_triple, + self.devices["vae"]["device"], + self.devices["vae"]["target"], self.ireec_flags["vae"], self.vae_decomp_attn, exit_on_vmfb=False, @@ -357,8 +375,8 @@ def export_submodel( "vmfb", self.external_weights, text_encoders_external_weight_path, - self.device, - self.iree_target_triple, + self.devices["clip"]["device"], + self.devices["clip"]["target"], self.ireec_flags["clip"], exit_on_vmfb=False, pipeline_dir=self.pipeline_dir, @@ -374,10 +392,15 @@ def load_pipeline( self, vmfbs: dict, weights: dict, - rt_device: str = "local-task", + rt_device: str | dict[str], compiled_pipeline: bool = False, split_scheduler: bool = True, + extra_device_args: dict = {}, ): + if "npu_delegate_path" in extra_device_args.keys(): + delegate = extra_device_args["npu_delegate_path"] + else: + delegate = None self.runners = {} runners = {} load_start = time.time() @@ -399,7 +422,7 @@ def load_pipeline( runners["vae"] = vmfbRunner( rt_device, vmfbs["vae"], - weights["vae"], + weights["vae"], ) vae_loaded = time.time() print("\n[LOG] VAE Decode loaded in ", vae_loaded - sched_loaded, "sec") diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 52c980903..4489141d6 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -66,7 +66,6 @@ } znver4_flags = { "all": [ - # "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))", "--iree-llvmcpu-target-cpu=znver4", "--iree-opt-const-eval=false", "--iree-llvmcpu-enable-ukernels=mmt4d,pack,unpack", @@ -74,6 +73,12 @@ "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000", "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops", ], + "bf16": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-demote-contraction-inputs-to-bf16))", + ], + "winograd": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))" + ], } @@ -182,10 +187,12 @@ def compile_to_vmfb( if attn_spec in ["default", "mfma"]: attn_spec = 
diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index 52c980903..4489141d6 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -66,7 +66,6 @@
 }
 znver4_flags = {
     "all": [
-        # "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))",
         "--iree-llvmcpu-target-cpu=znver4",
         "--iree-opt-const-eval=false",
         "--iree-llvmcpu-enable-ukernels=mmt4d,pack,unpack",
@@ -74,6 +73,12 @@
         "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000",
         "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops",
     ],
+    "bf16": [
+        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-demote-contraction-inputs-to-bf16))",
+    ],
+    "winograd": [
+        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))",
+    ],
 }
@@ -182,10 +187,12 @@ def compile_to_vmfb(
         if attn_spec in ["default", "mfma"]:
             attn_spec = get_mfma_spec_path(target_triple, os.path.dirname(safe_name))
             flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
-        elif attn_spec in ["wmma"] or "gfx11" in target_triple:
+        elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec):
             attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name))
             if attn_spec:
                 flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
+        elif attn_spec and attn_spec != "None":
+            flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])

     for i, flag in enumerate(ireec_flags):
         k = flag.strip().split("=")[0]
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
index 514c73118..2edc2866c 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py
@@ -773,6 +773,7 @@ def generate_images(
             ](samples[i], prompt_embeds, add_text_embeds, guidance_scale)

             vae_start = time.time()
+            np.save("latents_winter_cat.npy", latents.to_host().astype(np.float32))
             vae_out = self.runners["vae_decode"].ctx.modules.compiled_vae["main"](
                 latents
             )
@@ -780,7 +781,7 @@ def generate_images(
             pipe_end = time.time()

             image = vae_out.to_host()
-
+            np.save("image_winter_cat.npy", image.astype(np.float32))
             numpy_images.append(image)
             print("Batch #", i + 1, "\n")
             print(
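Note: a standalone sketch (not part of the patch) of the attn_spec branching in compile_to_vmfb above; the spec-path helpers are passed in as stand-ins for the real get_mfma_spec_path / get_wmma_spec_path in sd_inference/utils.py:

    import os

    def select_spec_flags(attn_spec, target_triple, safe_name,
                          get_mfma_spec_path, get_wmma_spec_path):
        flags = []
        if attn_spec in ["default", "mfma"]:
            attn_spec = get_mfma_spec_path(target_triple, os.path.dirname(safe_name))
            flags.append("--iree-codegen-transform-dialect-library=" + attn_spec)
        elif attn_spec == "wmma" or ("gfx11" in target_triple and not attn_spec):
            # gfx11 targets fall back to a WMMA spec only when no explicit
            # spec was requested; the helper may return None if none exists.
            attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name))
            if attn_spec:
                flags.append("--iree-codegen-transform-dialect-library=" + attn_spec)
        elif attn_spec and attn_spec != "None":
            # Any other non-empty value is treated as a user-supplied spec path.
            flags.append("--iree-codegen-transform-dialect-library=" + attn_spec)
        return flags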