From 6ab5146d4a5ef63901326489f31f1d8e7dd36b48 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 24 Jan 2024 19:36:36 +0000 Subject: [PATCH] Cast a to v's dtype in neighborhood attention blocks --- k_diffusion/models/image_transformer_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index f7ac209..54be069 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -416,7 +416,7 @@ def forward(self, x, pos, cond): raise ModuleNotFoundError("natten is required for neighborhood attention") flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size) qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) - a = torch.softmax(qk, dim=-1) + a = torch.softmax(qk, dim=-1).to(v.dtype) x = natten.functional.natten2dav(a, v, self.kernel_size, 1) x = rearrange(x, "n nh h w e -> n h w (nh e)") x = self.dropout(x)