diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py
index 303c91e0..4f3410d8 100644
--- a/k_diffusion/models/image_transformer_v2.py
+++ b/k_diffusion/models/image_transformer_v2.py
@@ -200,23 +200,32 @@ def _apply_rotary_emb_inplace(x, theta, conj):
 
 
 class ApplyRotaryEmbeddingInplace(torch.autograd.Function):
-    @staticmethod
-    def forward(x, theta, conj):
-        _apply_rotary_emb_inplace(x, theta, conj=conj)
-        return x
+    generate_vmap_rule = True
 
     @staticmethod
     def setup_context(ctx, inputs, output):
-        _, theta, conj = inputs
+        x, theta, conj = inputs
+        ctx.mark_dirty(x)
         ctx.save_for_backward(theta)
+        ctx.save_for_forward(theta)
         ctx.conj = conj
 
+    @staticmethod
+    def forward(x, theta, conj):
+        _apply_rotary_emb_inplace(x, theta, conj)
+        return x
+
     @staticmethod
     def backward(ctx, grad_output):
         theta, = ctx.saved_tensors
-        _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj)
+        grad_output = ApplyRotaryEmbeddingInplace.apply(grad_output.clone(), theta, not ctx.conj)
         return grad_output, None, None
 
+    @staticmethod
+    def jvp(ctx, grad_input, _, __):
+        theta, = ctx.saved_tensors
+        return ApplyRotaryEmbeddingInplace.apply(grad_input, theta, ctx.conj)
+
 
 def apply_rotary_emb_(x, theta):
     return ApplyRotaryEmbeddingInplace.apply(x, theta, False)
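
Below is a minimal sketch, not part of the diff, of what the rewrite appears to enable: with `forward` and `setup_context` split apart, `generate_vmap_rule = True`, and the new `jvp` rule, the in-place rotary embedding should compose with `torch.func` transforms. The shapes and the cloning pattern here are illustrative assumptions, not the repository's actual call sites.

```python
import torch

from k_diffusion.models.image_transformer_v2 import apply_rotary_emb_

# Hypothetical (batch, heads, seq, d_head) activations; theta rotates the
# first 2 * theta.shape[-1] channels, so its last dim is assumed d_head // 2.
x = torch.randn(4, 8, 16, 64)
theta = torch.randn(4, 8, 16, 32)
tangent = torch.randn_like(x)

# Forward-mode AD now routes through the new jvp() staticmethod; theta is
# available there because setup_context() calls ctx.save_for_forward().
# We clone inside the closure so the in-place op mutates an intermediate
# rather than an input of the transformed function.
primal_out, tangent_out = torch.func.jvp(
    lambda t: apply_rotary_emb_(t.clone(), theta), (x,), (tangent,)
)

# generate_vmap_rule = True lets vmap derive a batching rule automatically,
# so the op can also be mapped over a leading batch dimension.
batched_out = torch.func.vmap(
    lambda xi, ti: apply_rotary_emb_(xi.clone(), ti)
)(x, theta)
```

A note on the `backward` change: calling the raw kernel in place would mutate `grad_output`, which other autograd nodes may still need, so cloning it and re-applying the Function itself both avoids that aliasing and, presumably, keeps the backward pass itself transformable by the same rules.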