diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
index 599132480..50a5fb285 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py
@@ -57,6 +57,57 @@ def run_diffusers_mmdit(
     return noise_pred.numpy()
 
 
+def run_attn_turbine(q, k, v, args):
+    attn_runner = vmfbRunner(
+        args.device,
+        args.vmfb_path,
+        None,
+    )
+    iree_inputs = [
+        ireert.asdevicearray(attn_runner.config.device, q),
+        ireert.asdevicearray(attn_runner.config.device, k),
+        ireert.asdevicearray(attn_runner.config.device, v),
+    ]
+    attn_output = attn_runner.ctx.modules.compiled_attn["run_forward"](
+        *iree_inputs
+    ).to_host()
+    return attn_output
+
+
+@torch.no_grad()
+def run_attn_torch(q, k, v, args):
+    from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTAttention
+
+    mmdit_attn = MMDiTAttention()
+    attn_output = mmdit_attn.forward(
+        torch.tensor(q, dtype=torch.float32),
+        torch.tensor(k, dtype=torch.float32),
+        torch.tensor(v, dtype=torch.float32),
+    )
+    return attn_output.numpy()
+
+
+# Recursively walk a mismatched pair of outputs, recording the index path of
+# every out-of-tolerance element so failures can be localized exactly.
+def find_errs(turbine_output, torch_output, dim=(), failed_dims=None, errs=None):
+    # Fresh lists per call; mutable default arguments would leak state across calls.
+    failed_dims = [] if failed_dims is None else failed_dims
+    errs = [] if errs is None else errs
+    if not np.allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2):
+        if turbine_output.ndim > 0:
+            orig_dim = dim
+            for idx in range(len(torch_output)):
+                dim = [*orig_dim, idx]
+                try:
+                    np.testing.assert_allclose(turbine_output[idx], torch_output[idx], rtol=4e-2, atol=4e-2)
+                except AssertionError:
+                    err = np.abs(turbine_output[idx] - torch_output[idx])
+                    failed_dims.append(dim)
+                    errs.append([err, turbine_output[idx], torch_output[idx]])
+                    failed_dims, errs = find_errs(turbine_output[idx], torch_output[idx], dim, failed_dims, errs)
+    return failed_dims, errs
+
+
 if __name__ == "__main__":
     from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
 
@@ -69,6 +120,29 @@ def run_diffusers_mmdit(
         dtype = torch.float16
     else:
        dtype = torch.float32
+
+    if args.attn_repro:
+        qkv_shape = (2, 24, 4250, 64)  # expected shape of the saved q/k/v arrays (unused below)
+        example_qkv = [
+            np.load("q.npy").astype(np.float16),
+            np.load("k.npy").astype(np.float16),
+            np.load("v.npy").astype(np.float16),
+        ]
+        turbine_output = run_attn_turbine(
+            *example_qkv,
+            args,
+        )
+        torch_output = run_attn_torch(*example_qkv, args).astype(np.float16)
+        np.save("turbine_attn_output.npy", turbine_output)
+        np.save("torch_attn_output.npy", torch_output)
+        failed_dims, errs = find_errs(turbine_output, torch_output)
+        for idx, dim in enumerate(failed_dims):
+            if len(dim) == len(torch_output.shape):
+                print("Failed dimension: ", dim, " with error: ", errs[idx][0])
+                print("Turbine output: ", errs[idx][1])
+                print("Torch output: ", errs[idx][2])
+        print(torch_output.shape)
+        exit()
 
     batch_size = args.batch_size * 2 #do classifier free guidance
     hidden_states = torch.randn(
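Reviewer note on find_errs above: it descends one array axis per recursion level and records every index prefix that fails the tolerance check; the __main__ repro path then prints only the fully-indexed paths. A minimal sketch of that behavior on toy inputs, assuming find_errs as defined in the hunk above is in scope (the shapes and the injected value here are made up for illustration):

    import numpy as np

    turbine_out = np.zeros((2, 3, 4), dtype=np.float32)
    torch_out = turbine_out.copy()
    torch_out[1, 2, 0] = 1.0  # one deliberately out-of-tolerance element

    failed_dims, errs = find_errs(turbine_out, torch_out)
    print(failed_dims)  # [[1], [1, 2], [1, 2, 0]] -- every failing index prefix
    # The repro path keeps only [1, 2, 0], since len(dim) == len(torch_out.shape)
    # holds only for complete element paths.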
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py
index 607689b3a..06c23ef1c 100644
--- a/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py
@@ -93,8 +93,9 @@ def step(self, noise_pred, t, sample, guidance_scale, i):
         sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
         return sample.type(self.dtype)
 
-
-class SharkSchedulerCPUWrapper:
+# Wraps a diffusers scheduler running natively in PyTorch on CPU.
+# This allows it to be used interchangeably with compiled schedulers in our pipeline(s).
+class TorchCPUFlowSchedulerCompat:
     @torch.no_grad()
     def __init__(
         self, scheduler, batch_size, num_inference_steps, dest_device, latents_dtype
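The rename to TorchCPUFlowSchedulerCompat is about call-site compatibility: pipeline code should be able to hold either this CPU wrapper or a compiled scheduler behind one variable. A hypothetical sketch of that seam; the helper name get_scheduler and the compiled-path loader are invented for illustration and are not part of this patch (only the constructor signature comes from the diff above):

    def get_scheduler(scheduler, batch_size, num_inference_steps, dest_device, latents_dtype, use_cpu):
        # Hypothetical helper, not in this patch.
        if use_cpu:
            # Constructor signature taken from TorchCPUFlowSchedulerCompat.__init__ above.
            return TorchCPUFlowSchedulerCompat(
                scheduler, batch_size, num_inference_steps, dest_device, latents_dtype
            )
        # A compiled scheduler must expose the same stepping interface as the
        # wrapper; how the scheduler vmfb is loaded is elided here.
        return load_compiled_scheduler(scheduler, num_inference_steps)  # hypothetical loader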