diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 54be0696..303c91e0 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -407,18 +407,28 @@ def forward(self, x, pos, cond): skip = x x = self.norm(x, cond) qkv = self.qkv_proj(x) - q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) - q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) - theta = self.pos_emb(pos).movedim(-2, -4) - q = apply_rotary_emb_(q, theta) - k = apply_rotary_emb_(k, theta) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") - flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size) - qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) - a = torch.softmax(qk, dim=-1).to(v.dtype) - x = natten.functional.natten2dav(a, v, self.kernel_size, 1) - x = rearrange(x, "n nh h w e -> n h w (nh e)") + if natten.has_fused_na(): + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n h w nh e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None], 1e-6) + theta = self.pos_emb(pos) + q = apply_rotary_emb_(q, theta) + k = apply_rotary_emb_(k, theta) + flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size) + x = natten.functional.na2d(q, k, v, self.kernel_size, scale=1.0) + x = rearrange(x, "n h w nh e -> n h w (nh e)") + else: + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) + theta = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, theta) + k = apply_rotary_emb_(k, theta) + flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size) + qk = natten.functional.na2d_qk(q, k, self.kernel_size) + a = torch.softmax(qk, dim=-1).to(v.dtype) + x = natten.functional.na2d_av(a, v, self.kernel_size) + x = rearrange(x, "n nh h w e -> n h w (nh e)") x = self.dropout(x) x = self.out_proj(x) return x + skip