diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py index 7870905..c8a3a5c 100644 --- a/ahcore/models/MIL/transmil.py +++ b/ahcore/models/MIL/transmil.py @@ -1,156 +1,14 @@ # this file includes the original nystrom attention and transmil model # from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py # and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively. -from math import ceil -from typing import Any, Optional + +from typing import Any import numpy as np import torch -import torch.nn.functional as F -from einops import rearrange, reduce from torch import nn as nn - -def moore_penrose_iter_pinv(x: torch.Tensor, iters: int = 6) -> torch.Tensor: - device = x.device - - abs_x = torch.abs(x) - col = abs_x.sum(dim=-1) - row = abs_x.sum(dim=-2) - z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row)) - - eye = torch.eye(x.shape[-1], device=device) - eye = rearrange(eye, "i j -> () i j") - - for _ in range(iters): - xz = x @ z - z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz))))) - - return z - - -# main attention class - - -class NystromAttention(nn.Module): - def __init__( - self, - dim: int, - dim_head: int = 64, - heads: int = 8, - num_landmarks: int = 256, - pinv_iterations: int = 6, - residual: bool = True, - residual_conv_kernel: int = 33, - eps: float = 1e-8, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.eps = eps - inner_dim = heads * dim_head - - self.num_landmarks = num_landmarks - self.pinv_iterations = pinv_iterations - - self.heads = heads - self.scale = dim_head**-0.5 - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) - - self.residual = residual - if residual: - kernel_size = residual_conv_kernel - padding = residual_conv_kernel // 2 - self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False) - - def forward( - self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - b, n, _ = x.shape - h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps - # pad so that sequence can be evenly divided into m landmarks - - remainder = n % m - if remainder > 0: - padding = m - (n % m) - x = F.pad(x, (0, 0, padding, 0), value=0) - - if mask is not None: - mask = F.pad(mask, (padding, 0), value=False) - - # derive query, keys, values - - q, k, v = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - # set masked positions to 0 in queries, keys, values - - if mask is not None: - mask = rearrange(mask, "b n -> b () n") - q, k, v = map(lambda t: t * mask[..., None], (q, k, v)) - - q = q * self.scale - - # generate landmarks by sum reduction, and then calculate mean using the mask - - l_dim = ceil(n / m) - landmark_einops_eq = "... (n l) d -> ... n d" - q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l_dim) - k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l_dim) - - # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean - - if mask is not None: - mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=l_dim) - divisor = mask_landmarks_sum[..., None] + eps - mask_landmarks = mask_landmarks_sum > 0 - else: - divisor = torch.Tensor([l_dim]).to(q_landmarks.device) - - # masked mean (if mask exists) - - q_landmarks = q_landmarks / divisor - k_landmarks = k_landmarks / divisor - - # similarities - - einops_eq = "... i d, ... j d -> ... i j" - sim1 = torch.einsum(einops_eq, q, k_landmarks) - sim2 = torch.einsum(einops_eq, q_landmarks, k_landmarks) - sim3 = torch.einsum(einops_eq, q_landmarks, k) - - # masking - - if mask is not None: - mask_value = -torch.finfo(q.dtype).max - sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value) - sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value) - sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value) - - # eq (15) in the paper and aggregate values - - attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) - attn2_inv = moore_penrose_iter_pinv(attn2, iters) - - out: torch.Tensor = (attn1 @ attn2_inv) @ (attn3 @ v) - - # add depth-wise conv residual of values - - if self.residual: - out = out + self.res_conv(v) - - # merge and combine heads - - out = rearrange(out, "b h n d -> b n (h d)", h=h) - out = self.to_out(out) - out = out[:, -n:] - - if return_attn: - attn = attn1 @ attn2_inv @ attn3 - return out, attn - - return out +from ahcore.models.layers.attention import NystromAttention class TransLayer(nn.Module): @@ -193,27 +51,27 @@ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: class TransMIL(nn.Module): - def __init__(self, in_features: int = 1024, n_classes: int = 1) -> None: + def __init__(self, in_features: int = 1024, out_features: int = 1, hidden_dimension: int = 512) -> None: super(TransMIL, self).__init__() - self.pos_layer = PPEG(dim=512) - self._fc1 = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU()) - self.cls_token = nn.Parameter(torch.randn(1, 1, 512)) - self.n_classes = n_classes - self.layer1 = TransLayer(dim=512) - self.layer2 = TransLayer(dim=512) - self.norm = nn.LayerNorm(512) - self._fc2 = nn.Linear(512, self.n_classes) + self.pos_layer = PPEG(dim=hidden_dimension) + self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU()) + self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension)) + self.n_classes = out_features + self.layer1 = TransLayer(dim=hidden_dimension) + self.layer2 = TransLayer(dim=hidden_dimension) + self.norm = nn.LayerNorm(hidden_dimension) + self._fc2 = nn.Linear(hidden_dimension, self.n_classes) def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor: - h = features # [B, n, 1024] + h = features # [B, n, in_features] - h = self._fc1(h) # [B, n, 512] + h = self._fc1(h) # [B, n, hidden_dimension] # ---->pad H = h.shape[1] _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H))) add_length = _H * _W - H - h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N, 512] + h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N, hidden_dimension] # ---->cls_token B = h.shape[0] @@ -221,18 +79,18 @@ def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor: h = torch.cat((cls_tokens, h), dim=1) # ---->Translayer x1 - h = self.layer1(h) # [B, N, 512] + h = self.layer1(h) # [B, N, hidden_dimension] # ---->PPEG - h = self.pos_layer(h, _H, _W) # [B, N, 512] + h = self.pos_layer(h, _H, _W) # [B, N, hidden_dimension] # ---->Translayer x2 - h = self.layer2(h) # [B, N, 512] + h = self.layer2(h) # [B, N, hidden_dimension] # ---->cls_token h = self.norm(h)[:, 0] # ---->predict - logits: torch.Tensor = self._fc2(h) # [B, n_classes] + logits: torch.Tensor = self._fc2(h) # [B, out_features] return logits diff --git a/ahcore/models/layers/attention.py b/ahcore/models/layers/attention.py index a29ee1c..369c574 100644 --- a/ahcore/models/layers/attention.py +++ b/ahcore/models/layers/attention.py @@ -1,4 +1,9 @@ +import math +from typing import Optional + import torch +import torch.nn.functional as F +from einops import rearrange, reduce from torch import nn @@ -56,3 +61,145 @@ def forward( return scaled_attention.squeeze(1), attention_weights return scaled_attention.squeeze(1) + + +def moore_penrose_iter_pinv(x: torch.Tensor, iters: int = 6) -> torch.Tensor: + device = x.device + + abs_x = torch.abs(x) + col = abs_x.sum(dim=-1) + row = abs_x.sum(dim=-2) + z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row)) + + eye = torch.eye(x.shape[-1], device=device) + eye = rearrange(eye, "i j -> () i j") + + for _ in range(iters): + xz = x @ z + z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz))))) + + return z + + +# main attention class + + +class NystromAttention(nn.Module): + def __init__( + self, + dim: int, + dim_head: int = 64, + heads: int = 8, + num_landmarks: int = 256, + pinv_iterations: int = 6, + residual: bool = True, + residual_conv_kernel: int = 33, + eps: float = 1e-8, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.eps = eps + inner_dim = heads * dim_head + + self.num_landmarks = num_landmarks + self.pinv_iterations = pinv_iterations + + self.heads = heads + self.scale = dim_head**-0.5 + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + + self.residual = residual + if residual: + kernel_size = residual_conv_kernel + padding = residual_conv_kernel // 2 + self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + b, n, _ = x.shape + h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps + # pad so that sequence can be evenly divided into m landmarks + + remainder = n % m + if remainder > 0: + padding = m - (n % m) + x = F.pad(x, (0, 0, padding, 0), value=0) + + if mask is not None: + mask = F.pad(mask, (padding, 0), value=False) + + # derive query, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + # set masked positions to 0 in queries, keys, values + + if mask is not None: + mask = rearrange(mask, "b n -> b () n") + q, k, v = map(lambda t: t * mask[..., None], (q, k, v)) + + q = q * self.scale + + # generate landmarks by sum reduction, and then calculate mean using the mask + + l_dim = math.ceil(n / m) + landmark_einops_eq = "... (n l) d -> ... n d" + q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l_dim) + k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l_dim) + + # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean + + if mask is not None: + mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=l_dim) + divisor = mask_landmarks_sum[..., None] + eps + mask_landmarks = mask_landmarks_sum > 0 + else: + divisor = torch.Tensor([l_dim]).to(q_landmarks.device) + + # masked mean (if mask exists) + + q_landmarks = q_landmarks / divisor + k_landmarks = k_landmarks / divisor + + # similarities + + einops_eq = "... i d, ... j d -> ... i j" + sim1 = torch.einsum(einops_eq, q, k_landmarks) + sim2 = torch.einsum(einops_eq, q_landmarks, k_landmarks) + sim3 = torch.einsum(einops_eq, q_landmarks, k) + + # masking + + if mask is not None: + mask_value = -torch.finfo(q.dtype).max + sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value) + sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value) + sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value) + + # eq (15) in the paper and aggregate values + + attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) + attn2_inv = moore_penrose_iter_pinv(attn2, iters) + + out: torch.Tensor = (attn1 @ attn2_inv) @ (attn3 @ v) + + # add depth-wise conv residual of values + + if self.residual: + out = out + self.res_conv(v) + + # merge and combine heads + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + out = self.to_out(out) + out = out[:, -n:] + + if return_attn: + attn = attn1 @ attn2_inv @ attn3 + return out, attn + + return out diff --git a/tests/test_models/test_models.py b/tests/test_models/test_models.py index edf8f21..9f3b487 100644 --- a/tests/test_models/test_models.py +++ b/tests/test_models/test_models.py @@ -1,6 +1,8 @@ import pytest import torch +from ahcore.models.layers.attention import GatedAttention, NystromAttention +from ahcore.models.layers.MLP import MLP from ahcore.models.MIL.ABmil import ABMIL from ahcore.models.MIL.transmil import TransMIL @@ -15,8 +17,38 @@ def test_ABmil_shape(input_data: torch.Tensor) -> None: output = model(input_data) assert output.shape == (16, 1) + output, attentions = model(input_data, return_attention_weights=True) + assert output.shape == (16, 1) + assert attentions.shape == (16, 1000, 1) + def test_TransMIL_shape(input_data: torch.Tensor) -> None: - model = TransMIL(in_features=768, n_classes=2) + model = TransMIL(in_features=768, out_features=2) output = model(input_data) assert output.shape == (16, 2) + + +def test_MLP_shape(input_data: torch.Tensor) -> None: + model = MLP(in_features=768, out_features=2, hidden=[128], dropout=[0.1]) + output = model(input_data) + assert output.shape == (16, 1000, 2) + + +def test_MLP_hidden_dropout() -> None: + with pytest.raises(ValueError): + MLP(in_features=768, out_features=2, hidden=None, dropout=[0.1]) + + +def test_attention_shape(input_data: torch.Tensor) -> None: + model = GatedAttention(dim=768) + output = model(input_data) + assert output.shape == (16, 768) + + +def test_nystrom_att_with_mask(input_data: torch.Tensor) -> None: + model = NystromAttention( + dim=768, dim_head=768 // 8, heads=8, num_landmarks=1, pinv_iterations=6, residual=True, dropout=0.1 + ) + output, attn = model(input_data, mask=torch.ones_like(input_data, dtype=torch.bool)[:, :, 0], return_attn=True) + assert output.shape == (16, 1000, 768) + assert attn.shape == (16, 8, 1000, 1000)