diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py new file mode 100644 index 0000000..056f4f1 --- /dev/null +++ b/ahcore/models/MIL/ABmil.py @@ -0,0 +1,133 @@ +from typing import List, Optional + +import torch +from torch import nn + +from ahcore.models.layers.attention import GatedAttention +from ahcore.models.layers.MLP import MLP + + +class ABMIL(nn.Module): + """ + Attention-based MIL (Multiple Instance Learning) classification model (See [1]_). + This model is adapted from + https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py. + It uses an attention mechanism to aggregate features from multiple instances (tiles) into a single prediction. + + Methods + ------- + forward(features: torch.Tensor, return_attention_weights: bool = False) + -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] + Forward pass of the ABMIL model. + + References + ---------- + .. [1] Maximilian Ilse, Jakub Tomczak, and Max Welling. Attention-based + deep multiple instance learning. In Jennifer Dy and Andreas Krause, + editors, Proceedings of the 35th International Conference on Machine + Learning, volume 80 of Proceedings of Machine Learning Research, + pages 2127–2136. PMLR, 10–15 Jul 2018. + + """ + + def __init__( + self, + in_features: int, + num_classes: int = 1, + attention_dimension: int = 128, + temperature: float = 1.0, + embed_mlp_hidden: Optional[List[int]] = None, + embed_mlp_dropout: Optional[List[float]] = None, + embed_mlp_activation: Optional[torch.nn.Module] = nn.ReLU(), + embed_mlp_bias: bool = True, + classifier_hidden: Optional[List[int]] = [128, 64], + classifier_dropout: Optional[List[float]] = None, + classifier_activation: Optional[torch.nn.Module] = nn.ReLU(), + classifier_bias: bool = False, + ) -> None: + """ + Initializes the ABMIL model with embedding and classification layers. + + Parameters + ---------- + in_features : int + Number of input features for each tile. + out_features : int, optional + Number of output features (typically 1 for binary classification), by default 1. + attention_dimension : int, optional + Dimensionality of the attention mechanism, by default 128. + temperature : float, optional + Temperature parameter for scaling the attention scores, by default 1.0. + embed_mlp_hidden : Optional[List[int]], optional + List of hidden layer sizes for the embedding MLP, by default None. + embed_mlp_dropout : Optional[List[float]], optional + List of dropout rates for the embedding MLP, by default None. + embed_mlp_activation : Optional[torch.nn.Module], optional + Activation function for the embedding MLP, by default nn.ReLU(). + embed_mlp_bias : bool, optional + Whether to include bias in the embedding MLP layers, by default True. + classifier_hidden : Optional[List[int]], optional + List of hidden layer sizes for the classifier MLP, by default [128, 64]. + classifier_dropout : Optional[List[float]], optional + List of dropout rates for the classifier MLP, by default None. + classifier_activation : Optional[torch.nn.Module], optional + Activation function for the classifier MLP, by default nn.ReLU(). + classifier_bias : bool, optional + Whether to include bias in the classifier MLP layers, by default False. + + """ + super(ABMIL, self).__init__() + + self.embed_mlp = MLP( + in_features=in_features, + hidden=embed_mlp_hidden, + bias=embed_mlp_bias, + out_features=attention_dimension, + dropout=embed_mlp_dropout, + activation=embed_mlp_activation, + ) + + self.attention_layer = GatedAttention(dim=attention_dimension, temperature=temperature) + + self.classifier = MLP( + in_features=attention_dimension, + out_features=num_classes, + bias=classifier_bias, + hidden=classifier_hidden, + dropout=classifier_dropout, + activation=classifier_activation, + ) + + def forward( + self, + features: torch.Tensor, + return_attention_weights: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the ABMIL model. + + Parameters + ---------- + features : torch.Tensor + Input tensor of shape (batch_size, n_tiles, in_features) representing the features of tiles. + return_attention : bool, optional + If True, also returns the attention weights, by default False. + + Returns + ------- + torch.Tensor + Logits representing the model's output. + torch.Tensor, optional + Attention weights, returned if return_attention is True. + + """ + tiles_emb = self.embed_mlp(features) # BxN_tilesxN_features --> BxN_tilesx128 + scaled_tiles_emb, attention_weights = self.attention_layer( + tiles_emb, return_attention_weights=True + ) # BxN_tilesx128 --> Bx128 + logits: torch.Tensor = self.classifier(scaled_tiles_emb) # Bx128 --> Bx1 + + if return_attention_weights: + return logits, attention_weights + + return logits diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py new file mode 100644 index 0000000..27cef7e --- /dev/null +++ b/ahcore/models/MIL/transmil.py @@ -0,0 +1,96 @@ +# this file includes the original nystrom attention and transmil model +# from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py +# and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively. + +from typing import Any + +import numpy as np +import torch +from torch import nn as nn + +from ahcore.models.layers.attention import NystromAttention + + +class TransLayer(nn.Module): + def __init__(self, norm_layer: type = nn.LayerNorm, dim: int = 512) -> None: + super().__init__() + self.norm = norm_layer(dim) + self.attn = NystromAttention( + dim=dim, + dim_head=dim // 8, + heads=8, + num_landmarks=dim // 2, # number of landmarks + pinv_iterations=6, + # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper + residual=True, + # whether to do an extra residual with the value or not. supposedly faster convergence if turned on + dropout=0.1, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm(x)) + + return x + + +class PPEG(nn.Module): + def __init__(self, dim: int = 512) -> None: + super(PPEG, self).__init__() + self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim) + self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim) + self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim) + + def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: + B, _, C = x.shape + cls_token, feat_token = x[:, 0], x[:, 1:] + cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) + x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat) + x = x.flatten(2).transpose(1, 2) + x = torch.cat((cls_token.unsqueeze(1), x), dim=1) + return x + + +class TransMIL(nn.Module): + def __init__(self, in_features: int = 1024, num_classes: int = 1, hidden_dimension: int = 512) -> None: + super(TransMIL, self).__init__() + self.pos_layer = PPEG(dim=hidden_dimension) + self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU()) + self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension)) + self.n_classes = num_classes + self.layer1 = TransLayer(dim=hidden_dimension) + self.layer2 = TransLayer(dim=hidden_dimension) + self.norm = nn.LayerNorm(hidden_dimension) + self._fc2 = nn.Linear(hidden_dimension, self.n_classes) + + def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor: + h = features # [B, n, in_features] + + h = self._fc1(h) # [B, n, hidden_dimension] + + # ---->pad + H = h.shape[1] + _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H))) + add_length = _H * _W - H + h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N, hidden_dimension] + + # ---->cls_token + B = h.shape[0] + cls_tokens = self.cls_token.expand(B, -1, -1).to(h.device) + h = torch.cat((cls_tokens, h), dim=1) + + # ---->Translayer x1 + h = self.layer1(h) # [B, N, hidden_dimension] + + # ---->PPEG + h = self.pos_layer(h, _H, _W) # [B, N, hidden_dimension] + + # ---->Translayer x2 + h = self.layer2(h) # [B, N, hidden_dimension] + + # ---->cls_token + h = self.norm(h)[:, 0] + + # ---->predict + logits: torch.Tensor = self._fc2(h) # [B, out_features] + + return logits diff --git a/ahcore/models/base_jit_model.py b/ahcore/models/base_jit_model.py index f21bb9b..502fb48 100644 --- a/ahcore/models/base_jit_model.py +++ b/ahcore/models/base_jit_model.py @@ -1,8 +1,8 @@ from pathlib import Path from typing import Any +from torch import nn from torch.jit import ScriptModule, load -from torch.nn import Module class BaseAhcoreJitModel(ScriptModule): @@ -46,7 +46,7 @@ def from_jit_path(cls, jit_path: Path, output_mode: str) -> Any: model = load(jit_path) # type: ignore return cls(model) - def extend_model(self, modules: dict[str, Module]) -> None: + def extend_model(self, modules: dict[str, nn.Module]) -> None: """ Add modules to a jit compiled model. diff --git a/ahcore/models/layers/MLP.py b/ahcore/models/layers/MLP.py new file mode 100644 index 0000000..3bffe41 --- /dev/null +++ b/ahcore/models/layers/MLP.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from torch import nn + +"""Most of this stuff is adapted from utils from https://github.com/owkin/HistoSSLscaling/tree/main""" + + +class MLP(nn.Sequential): + """MLP Module. + + Parameters + ---------- + in_features: int + Features (model input) dimension. + out_features: int = 1 + Prediction (model output) dimension. + hidden: Optional[List[int]] = None + Dimension of hidden layer(s). + dropout: Optional[List[float]] = None + Dropout rate(s). + activation: Optional[torch.nn.Module] = torch.nn.Sigmoid + MLP activation. + bias: bool = True + Add bias to MLP hidden layers. + + Raises + ------ + ValueError + If ``hidden`` and ``dropout`` do not share the same length. + """ + + def __init__( + self, + in_features: int, + out_features: int, + hidden: Optional[List[int]] = None, + dropout: Optional[List[float]] = None, + activation: Optional[nn.Module] = nn.ReLU(), + bias: bool = True, + ): + if dropout is not None: + if hidden is not None: + assert len(hidden) == len(dropout), "hidden and dropout must have the same length" + else: + raise ValueError("hidden must have a value and have the same length as dropout if dropout is given.") + + d_model = in_features + layers: list[nn.Module] = [] + + if hidden is not None: + for i, h in enumerate(hidden): + seq: list[nn.Module] = [nn.Linear(d_model, h, bias=bias)] + d_model = h + + if activation is not None: + seq.append(activation) + + if dropout is not None: + seq.append(nn.Dropout(dropout[i])) + + layers.append(nn.Sequential(*seq)) + + layers.append(nn.Linear(d_model, out_features)) + + super(MLP, self).__init__(*layers) diff --git a/ahcore/models/layers/attention.py b/ahcore/models/layers/attention.py new file mode 100644 index 0000000..369c574 --- /dev/null +++ b/ahcore/models/layers/attention.py @@ -0,0 +1,205 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce +from torch import nn + + +class GatedAttention(nn.Module): + """Gated Attention, as defined in https://arxiv.org/abs/1802.04712. + Permutation invariant Layer on dim 1. + Parameters + ---------- + d_model: int = 128 + temperature: float = 1.0 + Attention Softmax temperature + """ + + def __init__( + self, + dim: int = 128, + temperature: float = 1.0, + ): + super(GatedAttention, self).__init__() + + self.V = nn.Linear(dim, dim) + self.U = nn.Linear(dim, dim) + self.w = nn.Linear(dim, 1) + + self.temperature = temperature + + def forward( + self, + features: torch.Tensor, + return_attention_weights: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + Parameters + ---------- + features: torch.Tensor + (B, SEQ_LEN, IN_FEATURES) + return_attention_weights: bool = False + + Returns + ------- + scaled_attention, attention_weights: Tuple[torch.Tensor, torch.Tensor] + (B, IN_FEATURES), (B, N_TILES, 1) + """ + h_v = torch.tanh(self.U(features)) + + u_v = torch.sigmoid(self.V(features)) + + attention_logits = self.w(h_v * u_v) / self.temperature + + attention_weights = torch.softmax(attention_logits, 1) + + scaled_attention = torch.matmul(attention_weights.transpose(1, 2), features) + + if return_attention_weights: + return scaled_attention.squeeze(1), attention_weights + + return scaled_attention.squeeze(1) + + +def moore_penrose_iter_pinv(x: torch.Tensor, iters: int = 6) -> torch.Tensor: + device = x.device + + abs_x = torch.abs(x) + col = abs_x.sum(dim=-1) + row = abs_x.sum(dim=-2) + z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row)) + + eye = torch.eye(x.shape[-1], device=device) + eye = rearrange(eye, "i j -> () i j") + + for _ in range(iters): + xz = x @ z + z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz))))) + + return z + + +# main attention class + + +class NystromAttention(nn.Module): + def __init__( + self, + dim: int, + dim_head: int = 64, + heads: int = 8, + num_landmarks: int = 256, + pinv_iterations: int = 6, + residual: bool = True, + residual_conv_kernel: int = 33, + eps: float = 1e-8, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.eps = eps + inner_dim = heads * dim_head + + self.num_landmarks = num_landmarks + self.pinv_iterations = pinv_iterations + + self.heads = heads + self.scale = dim_head**-0.5 + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + + self.residual = residual + if residual: + kernel_size = residual_conv_kernel + padding = residual_conv_kernel // 2 + self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + b, n, _ = x.shape + h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps + # pad so that sequence can be evenly divided into m landmarks + + remainder = n % m + if remainder > 0: + padding = m - (n % m) + x = F.pad(x, (0, 0, padding, 0), value=0) + + if mask is not None: + mask = F.pad(mask, (padding, 0), value=False) + + # derive query, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + # set masked positions to 0 in queries, keys, values + + if mask is not None: + mask = rearrange(mask, "b n -> b () n") + q, k, v = map(lambda t: t * mask[..., None], (q, k, v)) + + q = q * self.scale + + # generate landmarks by sum reduction, and then calculate mean using the mask + + l_dim = math.ceil(n / m) + landmark_einops_eq = "... (n l) d -> ... n d" + q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l_dim) + k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l_dim) + + # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean + + if mask is not None: + mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=l_dim) + divisor = mask_landmarks_sum[..., None] + eps + mask_landmarks = mask_landmarks_sum > 0 + else: + divisor = torch.Tensor([l_dim]).to(q_landmarks.device) + + # masked mean (if mask exists) + + q_landmarks = q_landmarks / divisor + k_landmarks = k_landmarks / divisor + + # similarities + + einops_eq = "... i d, ... j d -> ... i j" + sim1 = torch.einsum(einops_eq, q, k_landmarks) + sim2 = torch.einsum(einops_eq, q_landmarks, k_landmarks) + sim3 = torch.einsum(einops_eq, q_landmarks, k) + + # masking + + if mask is not None: + mask_value = -torch.finfo(q.dtype).max + sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value) + sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value) + sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value) + + # eq (15) in the paper and aggregate values + + attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) + attn2_inv = moore_penrose_iter_pinv(attn2, iters) + + out: torch.Tensor = (attn1 @ attn2_inv) @ (attn3 @ v) + + # add depth-wise conv residual of values + + if self.residual: + out = out + self.res_conv(v) + + # merge and combine heads + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + out = self.to_out(out) + out = out[:, -n:] + + if return_attn: + attn = attn1 @ attn2_inv @ attn3 + return out, attn + + return out diff --git a/tests/test_models/test_models.py b/tests/test_models/test_models.py new file mode 100644 index 0000000..139207b --- /dev/null +++ b/tests/test_models/test_models.py @@ -0,0 +1,54 @@ +import pytest +import torch + +from ahcore.models.layers.attention import GatedAttention, NystromAttention +from ahcore.models.layers.MLP import MLP +from ahcore.models.MIL.ABmil import ABMIL +from ahcore.models.MIL.transmil import TransMIL + + +@pytest.fixture +def input_data(B: int = 16, N_tiles: int = 1000, feature_dim: int = 768) -> torch.Tensor: + return torch.randn(B, N_tiles, feature_dim) + + +def test_ABmil_shape(input_data: torch.Tensor) -> None: + model = ABMIL(in_features=768) + output = model(input_data) + assert output.shape == (16, 1) + + output, attentions = model(input_data, return_attention_weights=True) + assert output.shape == (16, 1) + assert attentions.shape == (16, 1000, 1) + + +def test_TransMIL_shape(input_data: torch.Tensor) -> None: + model = TransMIL(in_features=768, num_classes=2) + output = model(input_data) + assert output.shape == (16, 2) + + +def test_MLP_shape(input_data: torch.Tensor) -> None: + model = MLP(in_features=768, out_features=2, hidden=[128], dropout=[0.1]) + output = model(input_data) + assert output.shape == (16, 1000, 2) + + +def test_MLP_hidden_dropout() -> None: + with pytest.raises(ValueError): + MLP(in_features=768, out_features=2, hidden=None, dropout=[0.1]) + + +def test_attention_shape(input_data: torch.Tensor) -> None: + model = GatedAttention(dim=768) + output = model(input_data) + assert output.shape == (16, 768) + + +def test_nystrom_att_with_mask(input_data: torch.Tensor) -> None: + model = NystromAttention( + dim=768, dim_head=768 // 8, heads=8, num_landmarks=1, pinv_iterations=6, residual=True, dropout=0.1 + ) + output, attn = model(input_data, mask=torch.ones_like(input_data, dtype=torch.bool)[:, :, 0], return_attn=True) + assert output.shape == (16, 1000, 768) + assert attn.shape == (16, 8, 1000, 1000)