Attn debugging, piping for multi-device in sd3
monorimet committed Jun 17, 2024
1 parent 81ee093 commit b793686
Showing 4 changed files with 52 additions and 19 deletions.
@@ -110,15 +110,17 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):

     if args.precision == "fp16":
         dtype = torch.float16
+        np_dtype = np.float16
     else:
         dtype = torch.float32
+        np_dtype = np.float32

     if args.attn_repro:
         qkv_shape = (2, 24, 4250, 64)
         example_qkv = [
-            np.load("q.npy").astype(np.float16),
-            np.load("k.npy").astype(np.float16),
-            np.load("v.npy").astype(np.float16),
+            np.load("q.npy").astype(np_dtype),
+            np.load("k.npy").astype(np_dtype),
+            np.load("v.npy").astype(np_dtype),
         ]
         turbine_output = run_attn_turbine(
             *example_qkv,
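Note: the hunk above makes the attention-repro inputs follow the selected precision instead of hard-coding float16. A minimal standalone sketch of the same repro path, assuming the q.npy/k.npy/v.npy dumps referenced in the diff exist; the softmax helper and NumPy reference computation below are our own, not part of turbine_models:

import numpy as np

np_dtype = np.float16  # mirrors args.precision == "fp16"; use np.float32 otherwise

# Load the dumped attention inputs at the precision under test.
q = np.load("q.npy").astype(np_dtype)
k = np.load("k.npy").astype(np_dtype)
v = np.load("v.npy").astype(np_dtype)

def softmax(x, axis=-1):
    # Subtract the running max for numerical stability; accumulate in fp32.
    x = x.astype(np.float32)
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Scaled dot-product attention as a reference to diff against the turbine
# output, e.g. with find_errs(). Shapes follow qkv_shape = (2, 24, 4250, 64).
scale = 1.0 / np.sqrt(q.shape[-1])
scores = q.astype(np.float32) @ k.astype(np.float32).transpose(0, 1, 3, 2)
ref = softmax(scores * scale) @ v.astype(np.float32)
print("reference attention output:", ref.shape)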
models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py: 36 additions & 13 deletions
@@ -68,8 +68,8 @@ def __init__(
         max_length: int,
         batch_size: int,
         num_inference_steps: int,
-        device: str,
-        iree_target_triple: str,
+        device: str | dict[str, str],
+        iree_target_triple: str | dict[str, str],
         ireec_flags: dict = EMPTY_FLAGS,
         attn_spec: str = None,
         decomp_attn: bool = False,
@@ -89,7 +89,25 @@ def __init__(
         self.max_length = max_length
         self.batch_size = batch_size
         self.num_inference_steps = num_inference_steps
-        self.device = device
+        self.devices = {}
+        if isinstance(device, dict):
+            assert isinstance(iree_target_triple, dict), "Device and target triple must be both dicts or both strings."
+            self.devices["clip"] = {
+                "device": device["clip"],
+                "target": iree_target_triple["clip"],
+            }
+            self.devices["mmdit"] = {
+                "device": device["mmdit"],
+                "target": iree_target_triple["mmdit"],
+            }
+            self.devices["vae"] = {
+                "device": device["vae"],
+                "target": iree_target_triple["vae"],
+            }
+        else:
+            self.devices["clip"] = {"device": device, "target": iree_target_triple}
+            self.devices["mmdit"] = {"device": device, "target": iree_target_triple}
+            self.devices["vae"] = {"device": device, "target": iree_target_triple}
         self.iree_target_triple = iree_target_triple
         self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS
         self.attn_spec = attn_spec
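Note: with this change the constructor accepts either a single device string or parallel dicts keyed by submodel. A hedged usage sketch; the device and target strings below are illustrative only, and the constructor's other arguments are elided:

# Place CLIP on CPU and MMDiT/VAE on a ROCm GPU (example values only).
device = {
    "clip": "local-task",
    "mmdit": "rocm",
    "vae": "rocm",
}
iree_target_triple = {
    "clip": "x86_64-linux-gnu",
    "mmdit": "gfx942",
    "vae": "gfx942",
}
# pipe = <pipeline class in this file>(..., device=device, iree_target_triple=iree_target_triple, ...)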
@@ -291,8 +309,8 @@ def export_submodel(
             "vmfb",
             self.external_weights,
             mmdit_external_weight_path,
-            self.device,
-            self.iree_target_triple,
+            self.devices["mmdit"]["device"],
+            self.devices["mmdit"]["target"],
             self.ireec_flags["mmdit"],
             self.decomp_attn,
             exit_on_vmfb=False,
@@ -313,8 +331,8 @@
             self.num_inference_steps,
             self.precision,
             "vmfb",
-            self.device,
-            self.iree_target_triple,
+            self.devices["mmdit"]["device"],
+            self.devices["mmdit"]["target"],
             self.ireec_flags["scheduler"],
             exit_on_vmfb=False,
             pipeline_dir=self.pipeline_dir,
@@ -336,8 +354,8 @@
             "vmfb",
             self.external_weights,
             vae_external_weight_path,
-            self.device,
-            self.iree_target_triple,
+            self.devices["vae"]["device"],
+            self.devices["vae"]["target"],
             self.ireec_flags["vae"],
             self.vae_decomp_attn,
             exit_on_vmfb=False,
@@ -357,8 +375,8 @@
             "vmfb",
             self.external_weights,
             text_encoders_external_weight_path,
-            self.device,
-            self.iree_target_triple,
+            self.devices["clip"]["device"],
+            self.devices["clip"]["target"],
             self.ireec_flags["clip"],
             exit_on_vmfb=False,
             pipeline_dir=self.pipeline_dir,
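Note: the four export hunks above apply one and the same substitution: each submodel resolves its own (device, target) pair from self.devices instead of reusing a single pipeline-wide setting. The hypothetical helper below (not present in the diff) captures that lookup:

def device_and_target(devices: dict, submodel: str) -> tuple[str, str]:
    # devices is the dict built in __init__; submodel is "clip", "mmdit", or "vae".
    entry = devices[submodel]
    return entry["device"], entry["target"]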
@@ -374,10 +392,15 @@ def load_pipeline(
         self,
         vmfbs: dict,
         weights: dict,
-        rt_device: str = "local-task",
+        rt_device: str | dict[str, str],
         compiled_pipeline: bool = False,
         split_scheduler: bool = True,
+        extra_device_args: dict = {},
     ):
+        if "npu_delegate_path" in extra_device_args.keys():
+            delegate = extra_device_args["npu_delegate_path"]
+        else:
+            delegate = None
         self.runners = {}
         runners = {}
         load_start = time.time()
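Note: load_pipeline now takes rt_device as a string or per-submodel dict and threads NPU-specific options through extra_device_args. A hedged call sketch, assuming vmfbs and weights came from export_submodel; the delegate path below is purely illustrative:

pipe.load_pipeline(
    vmfbs,
    weights,
    rt_device="local-task",
    extra_device_args={"npu_delegate_path": "/path/to/npu_delegate.so"},  # example path
)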
@@ -399,7 +422,7 @@
         runners["vae"] = vmfbRunner(
             rt_device,
             vmfbs["vae"],
-            weights["vae"],
+            weights["vae"],
         )
         vae_loaded = time.time()
         print("\n[LOG] VAE Decode loaded in ", vae_loaded - sched_loaded, "sec")
(whitespace-only change on the weights["vae"] line)
models/turbine_models/custom_models/sd_inference/utils.py: 9 additions & 2 deletions
@@ -66,14 +66,19 @@
 }
 znver4_flags = {
     "all": [
-        # "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))",
         "--iree-llvmcpu-target-cpu=znver4",
         "--iree-opt-const-eval=false",
         "--iree-llvmcpu-enable-ukernels=mmt4d,pack,unpack",
         "--iree-flow-collapse-reduction-dims",
         "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000",
         "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops",
     ],
+    "bf16": [
+        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-demote-contraction-inputs-to-bf16))",
+    ],
+    "winograd": [
+        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))"
+    ],
 }


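Note: the preprocessing pipeline that was commented out inside "all" is now split into opt-in "bf16" and "winograd" entries. Both set --iree-preprocessing-pass-pipeline, and "winograd" already includes the bf16 demotion pass, so a caller should pick at most one. A hypothetical caller-side selection sketch (the toggles are ours, not in the diff):

use_winograd, demote_to_bf16 = True, False  # hypothetical toggles

flags = list(znver4_flags["all"])
if use_winograd:
    flags += znver4_flags["winograd"]  # already includes the bf16 demotion
elif demote_to_bf16:
    flags += znver4_flags["bf16"]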
@@ -182,10 +187,12 @@ def compile_to_vmfb(
         if attn_spec in ["default", "mfma"]:
             attn_spec = get_mfma_spec_path(target_triple, os.path.dirname(safe_name))
             flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
-        elif attn_spec in ["wmma"] or "gfx11" in target_triple:
+        elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec):
             attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name))
             if attn_spec:
                 flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
+        elif attn_spec and attn_spec != "None":
+            flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])

     for i, flag in enumerate(ireec_flags):
         k = flag.strip().split("=")[0]
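Note: after this rework, a bare gfx11 target only falls back to the generated wmma spec when no attn_spec was supplied, and any other non-empty, non-"None" value is passed through as a user-provided spec file. Restated as a standalone sketch with a hypothetical helper name; get_mfma_spec_path and get_wmma_spec_path are the helpers used in the diff:

def resolve_attn_spec(attn_spec, target_triple, dirname):
    if attn_spec in ["default", "mfma"]:
        return get_mfma_spec_path(target_triple, dirname)
    if attn_spec == "wmma" or ("gfx11" in target_triple and not attn_spec):
        return get_wmma_spec_path(target_triple, dirname)  # may return None
    if attn_spec and attn_spec != "None":
        return attn_spec  # explicit spec file path passed through unchanged
    return None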
@@ -773,14 +773,15 @@ def generate_images(
             ](samples[i], prompt_embeds, add_text_embeds, guidance_scale)

             vae_start = time.time()
+            np.save("latents_winter_cat.npy", latents.to_host().astype(np.float32))
             vae_out = self.runners["vae_decode"].ctx.modules.compiled_vae["main"](
                 latents
             )

             pipe_end = time.time()

             image = vae_out.to_host()
-
+            np.save("image_winter_cat.npy", image.astype(np.float32))
             numpy_images.append(image)
             print("Batch #", i + 1, "\n")
             print(
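Note: the two np.save calls are debugging taps that dump the pre-VAE latents and the decoded image as float32 for offline inspection. A minimal sketch of consuming those dumps (file names as in the diff):

import numpy as np

latents = np.load("latents_winter_cat.npy")  # input handed to the compiled VAE
image = np.load("image_winter_cat.npy")      # raw VAE output before postprocessing
print("latents:", latents.shape, "image:", image.shape)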
