Skip to content

Commit

Permalink
Attn debugging tools
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jun 17, 2024
1 parent d7c709e commit 94ba46d
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,49 @@ def run_diffusers_mmdit(

return noise_pred.numpy()

def run_attn_turbine(q, k, v, args):
    """Run the compiled attention module on q, k, v and return the result on host.

    A ``vmfbRunner`` is built from ``args.device`` and ``args.vmfb_path``
    (its third positional argument is passed as ``None``). Each input is
    copied to the runner's device with ``ireert.asdevicearray`` before the
    ``run_forward`` entry point of the ``compiled_attn`` module is invoked.
    """
    runner = vmfbRunner(args.device, args.vmfb_path, None)
    # Move the three operands to the runner's device in q, k, v order.
    device_inputs = [
        ireert.asdevicearray(runner.config.device, operand)
        for operand in (q, k, v)
    ]
    run_forward = runner.ctx.modules.compiled_attn["run_forward"]
    # Bring the device result back to host memory before returning it.
    return run_forward(*device_inputs).to_host()

@torch.no_grad()
def run_attn_torch(q, k, v, args):
    """Run the PyTorch ``MMDiTAttention`` reference on q, k, v.

    Each input is converted to a float32 ``torch.Tensor`` before the forward
    pass, and the result is returned via ``.numpy()``. ``args`` is accepted
    but not used by this function.
    """
    from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTAttention

    reference_attn = MMDiTAttention()
    # Build the float32 tensors in q, k, v order, matching the forward signature.
    q_t, k_t, v_t = [
        torch.tensor(operand, dtype=torch.float32) for operand in (q, k, v)
    ]
    result = reference_attn.forward(q_t, k_t, v_t)
    return result.numpy()

def find_errs(
    turbine_output,
    torch_output,
    dim=None,
    failed_dims=None,
    errs=None,
    rtol=4e-2,
    atol=4e-2,
):
    """Recursively locate the sub-arrays where two outputs disagree.

    The arrays are walked along their leading axis. Every slice that fails
    ``np.testing.assert_allclose`` is recorded and then searched in turn, so
    the result holds one entry per mismatching slice at every depth, down to
    individual elements.

    Args:
        turbine_output: Array (or NumPy scalar) produced by the compiled module.
        torch_output: Reference array with the same layout.
        dim: Index path of the current slice; ``None`` means the top level.
        failed_dims: List to append failing index paths to. A fresh list is
            created when ``None``; a caller-supplied list is extended in place.
        errs: List to append ``[abs_error, turbine_slice, torch_slice]``
            entries to, parallel to ``failed_dims``. A fresh list is created
            when ``None``.
        rtol: Relative tolerance used for every comparison.
        atol: Absolute tolerance used for every comparison.

    Returns:
        The tuple ``(failed_dims, errs)``.
    """
    # Fresh containers per call: list defaults in the signature would be
    # shared between calls and leak failures from one comparison into the next.
    dim = [] if dim is None else dim
    failed_dims = [] if failed_dims is None else failed_dims
    errs = [] if errs is None else errs

    if np.allclose(turbine_output, torch_output, rtol=rtol, atol=atol):
        return (failed_dims, errs)
    if turbine_output.ndim == 0:
        # A scalar cannot be subdivided; the caller has already recorded it.
        return (failed_dims, errs)

    for idx, _ in enumerate(torch_output):
        sub_dim = [*dim, idx]
        turbine_slice = turbine_output[idx]
        torch_slice = torch_output[idx]
        try:
            np.testing.assert_allclose(
                turbine_slice, torch_slice, rtol=rtol, atol=atol
            )
        except AssertionError:
            failed_dims.append(sub_dim)
            errs.append(
                [np.abs(turbine_slice - torch_slice), turbine_slice, torch_slice]
            )
            # Descend into the failing slice; results accumulate in the
            # shared lists, so the return value is not needed here.
            find_errs(
                turbine_slice, torch_slice, sub_dim, failed_dims, errs, rtol, atol
            )
    return (failed_dims, errs)

if __name__ == "__main__":
from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
Expand All @@ -69,6 +112,29 @@ def run_diffusers_mmdit(
dtype = torch.float16
else:
dtype = torch.float32

# Attention repro path: compare the compiled attention module against the
# PyTorch reference on saved inputs, report mismatches, then stop.
if args.attn_repro:
# NOTE(review): qkv_shape is not referenced in the visible code; presumably it
# records the expected shape of the loaded arrays -- confirm or remove.
qkv_shape = (2, 24, 4250, 64)
# Inputs are read from q.npy / k.npy / v.npy (relative paths) and cast to float16.
example_qkv = [
np.load("q.npy").astype(np.float16),
np.load("k.npy").astype(np.float16),
np.load("v.npy").astype(np.float16),
]
turbine_output = run_attn_turbine(
*example_qkv,
args,
)
# The reference result is cast to float16 before it is saved and compared.
torch_output = run_attn_torch(*example_qkv, args).astype(np.float16)
# Persist both results for offline inspection.
np.save("turbine_attn_output.npy", turbine_output)
np.save("torch_attn_output.npy", torch_output)
failed_dims, errs = find_errs(turbine_output, torch_output)
# failed_dims and errs are parallel lists; only index paths as long as the
# output's rank (i.e. individual elements) are printed.
for idx, dim in enumerate(failed_dims):
if len(dim) == len(torch_output.shape):
print("Failed dimension: ", dim, " with error: ", errs[idx][0])
print("Turbine output: ", errs[idx][1])
print("Torch output: ", errs[idx][2])
print(torch_output.shape)
# Terminate the script here once the repro has run.
exit()

batch_size = args.batch_size * 2 #do classifier free guidance
hidden_states = torch.randn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ def step(self, noise_pred, t, sample, guidance_scale, i):
sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
return sample.type(self.dtype)


class SharkSchedulerCPUWrapper:
# Wraps a diffusers scheduler running on native pytorch+cpu.
# This allows us to use it interchangeably with compiled schedulers in our pipeline(s).
class TorchCPUFlowSchedulerCompat:
@torch.no_grad()
def __init__(
self, scheduler, batch_size, num_inference_steps, dest_device, latents_dtype
Expand Down

0 comments on commit 94ba46d

Please sign in to comment.