Enable vmap, jvp, double backward for apply_rotary_emb_()
crowsonkb committed Jan 7, 2025
1 parent 21d12c9 commit 8018de0
Showing 1 changed file with 15 additions and 6 deletions.
k_diffusion/models/image_transformer_v2.py
```diff
@@ -200,23 +200,32 @@ def _apply_rotary_emb_inplace(x, theta, conj):
 
 
 class ApplyRotaryEmbeddingInplace(torch.autograd.Function):
-    @staticmethod
-    def forward(x, theta, conj):
-        _apply_rotary_emb_inplace(x, theta, conj=conj)
-        return x
+    generate_vmap_rule = True
 
     @staticmethod
     def setup_context(ctx, inputs, output):
-        _, theta, conj = inputs
+        x, theta, conj = inputs
+        ctx.mark_dirty(x)
         ctx.save_for_backward(theta)
+        ctx.save_for_forward(theta)
         ctx.conj = conj
 
+    @staticmethod
+    def forward(x, theta, conj):
+        _apply_rotary_emb_inplace(x, theta, conj)
+        return x
+
     @staticmethod
     def backward(ctx, grad_output):
         theta, = ctx.saved_tensors
-        _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj)
+        grad_output = ApplyRotaryEmbeddingInplace.apply(grad_output.clone(), theta, not ctx.conj)
         return grad_output, None, None
 
+    @staticmethod
+    def jvp(ctx, grad_input, _, __):
+        theta, = ctx.saved_tensors
+        return ApplyRotaryEmbeddingInplace.apply(grad_input, theta, ctx.conj)
+
 
 def apply_rotary_emb_(x, theta):
     return ApplyRotaryEmbeddingInplace.apply(x, theta, False)
```
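
For readers outside the diff view, the new structure is: `forward` no longer touches `ctx`; `setup_context` declares the in-place mutation via `mark_dirty` and saves `theta` for both AD modes; `generate_vmap_rule = True` lets PyTorch derive a batching rule automatically; `backward` re-enters `apply()` so its own computation is differentiable (double backward); and `jvp` rotates the input tangent for forward-mode AD. Below is a minimal, self-contained sketch mirroring that structure with a plain in-place multiply in place of the rotary kernel — an illustrative analogue, not code from this repository; `MulInplace`, `x`, and `w` are hypothetical names.

```python
import torch

class MulInplace(torch.autograd.Function):
    # Let PyTorch auto-generate a batching (vmap) rule from forward.
    generate_vmap_rule = True

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, w = inputs
        ctx.mark_dirty(x)         # forward mutates x in place
        ctx.save_for_backward(w)  # for reverse-mode AD (backward)
        ctx.save_for_forward(w)   # for forward-mode AD (jvp)

    @staticmethod
    def forward(x, w):
        # In-place kernel, standing in for _apply_rotary_emb_inplace.
        return x.mul_(w)

    @staticmethod
    def backward(ctx, grad_output):
        w, = ctx.saved_tensors
        # Re-enter apply() so the backward computation is itself recorded
        # by autograd, enabling double backward; clone() keeps the in-place
        # op from mutating grad_output, which other nodes may still need.
        return MulInplace.apply(grad_output.clone(), w), None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        w, = ctx.saved_tensors
        # d(x * w)/dx applied to the tangent; in-place, like the forward,
        # matching the convention of the rotary version above.
        return MulInplace.apply(x_tangent, w)
```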
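Under those same assumptions (continuing from the `MulInplace` sketch), the three transforms the commit enables can be exercised as follows; the `clone()` calls give each transformed call private memory to mutate:

```python
# Continues from the MulInplace sketch above.
import torch

x = torch.randn(4, 8)
w = torch.randn(8)

# vmap: map the in-place op over the batch dimension.
y = torch.func.vmap(lambda xi: MulInplace.apply(xi.clone(), w))(x)

# jvp: forward-mode AD through the custom jvp rule; the output tangent
# is ones_like(x) * w, i.e. w broadcast over the batch.
primal, tangent = torch.func.jvp(
    lambda xi: MulInplace.apply(xi.clone(), w),
    (x,), (torch.ones_like(x),),
)

# Double backward: the first grad call records a graph (create_graph=True)
# and the second differentiates through it. sin() supplies the curvature,
# since the op itself is linear in x.
x2 = torch.randn(4, 8, requires_grad=True)
y2 = MulInplace.apply(torch.sin(x2), w)
(g,) = torch.autograd.grad(y2.sum(), x2, create_graph=True)  # cos(x2) * w
(gg,) = torch.autograd.grad(g.sum(), x2)                     # -sin(x2) * w
```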
