This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

added test and put nystrom attention in the attention file
moerlemans committed Sep 5, 2024
1 parent 2f90ee0 commit c8ae154
Showing 3 changed files with 199 additions and 162 deletions.
180 changes: 19 additions & 161 deletions ahcore/models/MIL/transmil.py
@@ -1,156 +1,14 @@
# this file includes the original nystrom attention and transmil model
# from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py
# and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively.
from math import ceil
from typing import Any, Optional

from typing import Any

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from torch import nn as nn


def moore_penrose_iter_pinv(x: torch.Tensor, iters: int = 6) -> torch.Tensor:
    device = x.device

    abs_x = torch.abs(x)
    col = abs_x.sum(dim=-1)
    row = abs_x.sum(dim=-2)
    z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row))

    eye = torch.eye(x.shape[-1], device=device)
    eye = rearrange(eye, "i j -> () i j")

    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz)))))

    return z


# main attention class


class NystromAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_head: int = 64,
        heads: int = 8,
        num_landmarks: int = 256,
        pinv_iterations: int = 6,
        residual: bool = True,
        residual_conv_kernel: int = 33,
        eps: float = 1e-8,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.eps = eps
        inner_dim = heads * dim_head

        self.num_landmarks = num_landmarks
        self.pinv_iterations = pinv_iterations

        self.heads = heads
        self.scale = dim_head**-0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

        self.residual = residual
        if residual:
            kernel_size = residual_conv_kernel
            padding = residual_conv_kernel // 2
            self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        b, n, _ = x.shape
        h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps
        # pad so that sequence can be evenly divided into m landmarks

        remainder = n % m
        if remainder > 0:
            padding = m - (n % m)
            x = F.pad(x, (0, 0, padding, 0), value=0)

            if mask is not None:
                mask = F.pad(mask, (padding, 0), value=False)

        # derive query, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        # set masked positions to 0 in queries, keys, values

        if mask is not None:
            mask = rearrange(mask, "b n -> b () n")
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q = q * self.scale

        # generate landmarks by sum reduction, and then calculate mean using the mask

        l_dim = ceil(n / m)
        landmark_einops_eq = "... (n l) d -> ... n d"
        q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l_dim)
        k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l_dim)

        # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean

        if mask is not None:
            mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=l_dim)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0
        else:
            divisor = torch.Tensor([l_dim]).to(q_landmarks.device)

        # masked mean (if mask exists)

        q_landmarks = q_landmarks / divisor
        k_landmarks = k_landmarks / divisor

        # similarities

        einops_eq = "... i d, ... j d -> ... i j"
        sim1 = torch.einsum(einops_eq, q, k_landmarks)
        sim2 = torch.einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = torch.einsum(einops_eq, q_landmarks, k)

        # masking

        if mask is not None:
            mask_value = -torch.finfo(q.dtype).max
            sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # eq (15) in the paper and aggregate values

        attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3))
        attn2_inv = moore_penrose_iter_pinv(attn2, iters)

        out: torch.Tensor = (attn1 @ attn2_inv) @ (attn3 @ v)

        # add depth-wise conv residual of values

        if self.residual:
            out = out + self.res_conv(v)

        # merge and combine heads

        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        out = self.to_out(out)
        out = out[:, -n:]

        if return_attn:
            attn = attn1 @ attn2_inv @ attn3
            return out, attn

        return out
from ahcore.models.layers.attention import NystromAttention


class TransLayer(nn.Module):
@@ -193,46 +51,46 @@ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:


class TransMIL(nn.Module):
    def __init__(self, in_features: int = 1024, n_classes: int = 1) -> None:
    def __init__(self, in_features: int = 1024, out_features: int = 1, hidden_dimension: int = 512) -> None:
        super(TransMIL, self).__init__()
        self.pos_layer = PPEG(dim=512)
        self._fc1 = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU())
        self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
        self.n_classes = n_classes
        self.layer1 = TransLayer(dim=512)
        self.layer2 = TransLayer(dim=512)
        self.norm = nn.LayerNorm(512)
        self._fc2 = nn.Linear(512, self.n_classes)
        self.pos_layer = PPEG(dim=hidden_dimension)
        self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU())
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension))
        self.n_classes = out_features
        self.layer1 = TransLayer(dim=hidden_dimension)
        self.layer2 = TransLayer(dim=hidden_dimension)
        self.norm = nn.LayerNorm(hidden_dimension)
        self._fc2 = nn.Linear(hidden_dimension, self.n_classes)

    def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        h = features  # [B, n, 1024]
        h = features  # [B, n, in_features]

        h = self._fc1(h)  # [B, n, 512]
        h = self._fc1(h)  # [B, n, hidden_dimension]

        # ---->pad
        H = h.shape[1]
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, N, 512]
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, N, hidden_dimension]

        # ---->cls_token
        B = h.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1).to(h.device)
        h = torch.cat((cls_tokens, h), dim=1)

        # ---->Translayer x1
        h = self.layer1(h)  # [B, N, 512]
        h = self.layer1(h)  # [B, N, hidden_dimension]

        # ---->PPEG
        h = self.pos_layer(h, _H, _W)  # [B, N, 512]
        h = self.pos_layer(h, _H, _W)  # [B, N, hidden_dimension]

        # ---->Translayer x2
        h = self.layer2(h)  # [B, N, 512]
        h = self.layer2(h)  # [B, N, hidden_dimension]

        # ---->cls_token
        h = self.norm(h)[:, 0]

        # ---->predict
        logits: torch.Tensor = self._fc2(h)  # [B, n_classes]
        logits: torch.Tensor = self._fc2(h)  # [B, out_features]

        return logits
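
For reference, a minimal smoke-test sketch of the refactored interface (not part of the commit; the batch size, bag size, and hyperparameter values below are illustrative assumptions, and the import path follows the file location above):

import torch

from ahcore.models.MIL.transmil import TransMIL

# Two bags of 100 patch embeddings with 1024 features each.
model = TransMIL(in_features=1024, out_features=1, hidden_dimension=512)
features = torch.randn(2, 100, 1024)  # [B, n, in_features]
logits = model(features)  # [B, out_features]
assert logits.shape == (2, 1)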
147 changes: 147 additions & 0 deletions ahcore/models/layers/attention.py
@@ -1,4 +1,9 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from torch import nn


@@ -56,3 +61,145 @@ def forward(
            return scaled_attention.squeeze(1), attention_weights

        return scaled_attention.squeeze(1)


def moore_penrose_iter_pinv(x: torch.Tensor, iters: int = 6) -> torch.Tensor:
    device = x.device

    abs_x = torch.abs(x)
    col = abs_x.sum(dim=-1)
    row = abs_x.sum(dim=-2)
    z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row))

    eye = torch.eye(x.shape[-1], device=device)
    eye = rearrange(eye, "i j -> () i j")

    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz)))))

    return z
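
As a quick sanity check (a sketch, not part of the commit), the iterative estimate can be compared against torch.linalg.pinv on a batch of row-stochastic matrices, which mirrors how it is applied to attn2 further down; accuracy depends on the number of iterations:

import torch

from ahcore.models.layers.attention import moore_penrose_iter_pinv

# A batched, softmax-normalised matrix, similar in spirit to attn2 below.
x = torch.randn(1, 8, 16, 16).softmax(dim=-1)
approx = moore_penrose_iter_pinv(x, iters=6)
exact = torch.linalg.pinv(x)
print((approx - exact).abs().max())  # shrinks as iters grows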


# main attention class


class NystromAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_head: int = 64,
        heads: int = 8,
        num_landmarks: int = 256,
        pinv_iterations: int = 6,
        residual: bool = True,
        residual_conv_kernel: int = 33,
        eps: float = 1e-8,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.eps = eps
        inner_dim = heads * dim_head

        self.num_landmarks = num_landmarks
        self.pinv_iterations = pinv_iterations

        self.heads = heads
        self.scale = dim_head**-0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

        self.residual = residual
        if residual:
            kernel_size = residual_conv_kernel
            padding = residual_conv_kernel // 2
            self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        b, n, _ = x.shape
        h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps
        # pad so that sequence can be evenly divided into m landmarks

        remainder = n % m
        if remainder > 0:
            padding = m - (n % m)
            x = F.pad(x, (0, 0, padding, 0), value=0)

            if mask is not None:
                mask = F.pad(mask, (padding, 0), value=False)

        # derive query, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        # set masked positions to 0 in queries, keys, values

        if mask is not None:
            mask = rearrange(mask, "b n -> b () n")
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q = q * self.scale

        # generate landmarks by sum reduction, and then calculate mean using the mask

        l_dim = math.ceil(n / m)
        landmark_einops_eq = "... (n l) d -> ... n d"
        q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l_dim)
        k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l_dim)

        # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean

        if mask is not None:
            mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=l_dim)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0
        else:
            divisor = torch.Tensor([l_dim]).to(q_landmarks.device)

        # masked mean (if mask exists)

        q_landmarks = q_landmarks / divisor
        k_landmarks = k_landmarks / divisor

        # similarities

        einops_eq = "... i d, ... j d -> ... i j"
        sim1 = torch.einsum(einops_eq, q, k_landmarks)
        sim2 = torch.einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = torch.einsum(einops_eq, q_landmarks, k)

        # masking

        if mask is not None:
            mask_value = -torch.finfo(q.dtype).max
            sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # eq (15) in the paper and aggregate values

        attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3))
        attn2_inv = moore_penrose_iter_pinv(attn2, iters)

        out: torch.Tensor = (attn1 @ attn2_inv) @ (attn3 @ v)

        # add depth-wise conv residual of values

        if self.residual:
            out = out + self.res_conv(v)

        # merge and combine heads

        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        out = self.to_out(out)
        out = out[:, -n:]

        if return_attn:
            attn = attn1 @ attn2_inv @ attn3
            return out, attn

        return out
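
A minimal standalone usage sketch of the relocated module (not part of the commit; the sequence length, embedding size, and landmark count are arbitrary); with return_attn=True the approximated full attention matrix is returned alongside the output:

import torch

from ahcore.models.layers.attention import NystromAttention

attn = NystromAttention(dim=512, dim_head=64, heads=8, num_landmarks=64, pinv_iterations=6)
x = torch.randn(1, 256, 512)
out, attn_matrix = attn(x, return_attn=True)
print(out.shape)  # torch.Size([1, 256, 512])
print(attn_matrix.shape)  # torch.Size([1, 8, 256, 256])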
