From dbedb6111d427e10485ee50fe52ea39c1f78dfe8 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 8 Aug 2024 16:36:25 +0200 Subject: [PATCH 01/14] Added ABmil and transmil implementations from other project --- ahcore/models/MIL/ABmil.py | 229 +++++++++++++++++++++++++++++++ ahcore/models/MIL/transmil.py | 246 ++++++++++++++++++++++++++++++++++ 2 files changed, 475 insertions(+) create mode 100644 ahcore/models/MIL/ABmil.py create mode 100644 ahcore/models/MIL/transmil.py diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py new file mode 100644 index 0000000..8cbe1f0 --- /dev/null +++ b/ahcore/models/MIL/ABmil.py @@ -0,0 +1,229 @@ +from src.utils.model_utils import MLP, MaskedMLP, GatedAttention + +from typing import List, Optional + +import torch +from torch import nn +import torch.nn.init as init + +class ABMIL(nn.Module): + """Attention-based MIL classification model (See [1]_). + Adapted from https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py + + Example: + >>> module = ABMIL(in_features=128, out_features=1) + >>> logits, attention_scores = module(slide, mask=mask) + >>> attention_scores = module.score_model(slide, mask=mask) + + Parameters + ---------- + in_features: int + Features (model input) dimension. + out_features: int = 1 + Prediction (model output) dimension. + d_model_attention: int = 128 + Dimension of attention scores. + temperature: float = 1.0 + GatedAttention softmax temperature. + tiles_mlp_hidden: Optional[List[int]] = None + Dimension of hidden layers in first MLP. + mlp_hidden: Optional[List[int]] = None + Dimension of hidden layers in last MLP. + mlp_dropout: Optional[List[float]] = None, + Dropout rate for last MLP. + mlp_activation: Optional[torch.nn.Module] = torch.nn.Sigmoid + Activation for last MLP. + bias: bool = True + Add bias to the first MLP. + metadata_cols: int = 3 + Number of metadata columns (for example, magnification, patch start + coordinates etc.) at the start of input data. Default of 3 assumes + that the first 3 columns of input data are, respectively: + 1) Deep zoom level, corresponding to a given magnification + 2) input patch starting x value + 3) input patch starting y value + + References + ---------- + .. [1] Maximilian Ilse, Jakub Tomczak, and Max Welling. Attention-based + deep multiple instance learning. In Jennifer Dy and Andreas Krause, + editors, Proceedings of the 35th International Conference on Machine + Learning, volume 80 of Proceedings of Machine Learning Research, + pages 2127–2136. PMLR, 10–15 Jul 2018. + + """ + + def __init__( + self, + in_features: int, + out_features: int = 1, + number_of_tiles: int = 1000, + d_model_attention: int = 128, + temperature: float = 1.0, + masked_mlp_hidden: Optional[List[int]] = None, + masked_mlp_dropout: Optional[List[float]] = None, + masked_mlp_activation: Optional[torch.nn.Module] = nn.Sigmoid(), + mlp_hidden: Optional[List[int]] = [128, 64], + mlp_dropout: Optional[List[float]] = None, + mlp_activation: Optional[torch.nn.Module] = nn.Sigmoid(), + bias: bool = True, + use_positional_encoding: bool = False, + ) -> None: + super(ABMIL, self).__init__() + + if mlp_dropout is not None: + if mlp_hidden is not None: + assert len(mlp_hidden) == len( + mlp_dropout + ), "mlp_hidden and mlp_dropout must have the same length" + else: + raise ValueError( + "mlp_hidden must have a value and have the same length" + "as mlp_dropout if mlp_dropout is given." 
+ ) + + self.embed_mlp = MLP( + in_features=in_features, + hidden=masked_mlp_hidden, + bias=bias, + out_features=d_model_attention, + dropout=masked_mlp_dropout, + activation=masked_mlp_activation, + ) + + self.attention_layer = GatedAttention( + d_model=d_model_attention, temperature=temperature + ) + + mlp_in_features = d_model_attention + + self.mlp = MLP( + in_features=mlp_in_features, + out_features=out_features, + hidden=mlp_hidden, + dropout=mlp_dropout, + activation=mlp_activation, + ) + + self.use_positional_encoding = use_positional_encoding + + if self.use_positional_encoding: + # TODO this should also add some interpolation + self.positional_encoding = nn.Parameter(torch.zeros(1, number_of_tiles, in_features)) + init.trunc_normal_(self.positional_encoding, mean=0.0, std=0.02, a=-2.0, b=2.0) + self.positional_encoding.requires_grad = True + + def interpolate_positional_encoding(self, coordinates: torch.Tensor): + """ + Perform bilinear interpolation using the given coordinates on the positional encoding. + The positional encoding is considered as a flattened array representing a (h, w) grid. + + Args: + coordinates (torch.Tensor): The normalized coordinates tensor of shape (batch_size, 2), + where each row is (x, y) in normalized coordinates [0, 1]. + + Returns: + torch.Tensor: The interpolated features from the positional encoding. + """ + # Scale coordinates to the range of the positional encoding indices + max_idx = int(torch.sqrt(torch.tensor([self.positional_encoding.shape[1]]))) - 1 + scaled_coordinates = max_idx * coordinates + + # Separate scaled coordinates into x and y components + x = scaled_coordinates[..., 0] + y = scaled_coordinates[..., 1] + + + # Get integer parts of coordinates + x0 = torch.floor(x).int() + x1 = x0 + 1 + y0 = torch.floor(y).int() + y1 = y0 + 1 + + # Clamp indices to ensure they remain within valid range + x0 = torch.clamp(x0, 0, max_idx) + x1 = torch.clamp(x1, 0, max_idx) + y0 = torch.clamp(y0, 0, max_idx) + y1 = torch.clamp(y1, 0, max_idx) + + # Calculate linear indices + idx_q11 = y0 * max_idx + x0 + idx_q12 = y1 * max_idx + x0 + idx_q21 = y0 * max_idx + x1 + idx_q22 = y1 * max_idx + x1 + + # Fetch the corner points + q11 = self.positional_encoding[0, idx_q11, :] + q12 = self.positional_encoding[0, idx_q12, :] + q21 = self.positional_encoding[0, idx_q21, :] + q22 = self.positional_encoding[0, idx_q22, :] + + # Compute fractional part for interpolation + x_frac = x - x0.float() + y_frac = y - y0.float() + + # Bilinear interpolation + interpolated_positional_encoding = (q11 * (1 - x_frac).unsqueeze(2) * (1 - y_frac).unsqueeze(2) + + q12 * (1 - x_frac).unsqueeze(2) * y_frac.unsqueeze(2) + + q21 * x_frac.unsqueeze(2) * (1 - y_frac).unsqueeze(2) + + q22 * x_frac.unsqueeze(2) * y_frac.unsqueeze(2)) + + return interpolated_positional_encoding + + def get_attention( + self, x: torch.Tensor, mask: Optional[torch.BoolTensor] = None, coordinates: torch.Tensor = None, + ) -> torch.Tensor: + """Get attention logits. + + Parameters + ---------- + x: torch.Tensor + (B, N_TILES, FEATURES) + mask: Optional[torch.BoolTensor] + (B, N_TILES, 1), True for values that were padded. 
+ + Returns + ------- + attention_logits: torch.Tensor + (B, N_TILES, 1) + """ + if self.use_positional_encoding: + positional_encoding = self.interpolate_positional_encoding(coordinates) + x = x + positional_encoding + + tiles_emb = self.tiles_emb(x, mask) + attention_weights = self.attention_layer.attention(tiles_emb, mask) + return attention_weights + + def forward( + self, features: torch.Tensor, mask: Optional[torch.BoolTensor] = None, coordinates: torch.Tensor = None, return_attention: bool = False, + ) -> torch.Tensor: + """ + Parameters + ---------- + coordinates + features: torch.Tensor + (B, N_TILES, D+3) + mask: Optional[torch.BoolTensor] + (B, N_TILES, 1), True for values that were padded. + + Returns + ------- + logits, attention_weights: Tuple[torch.Tensor, torch.Tensor] + (B, OUT_FEATURES), (B, N_TILES) + """ + if coordinates is None and self.positional_encoding: + raise ValueError(f"Coordinates of NoneType are not accepted if positional_encoding is used") + + if self.use_positional_encoding: + positional_encoding = self.interpolate_positional_encoding(coordinates) + features = features + positional_encoding + + tiles_emb = self.embed_mlp(features) # BxN_tilesxN_features --> BxN_tilesx128 + scaled_tiles_emb, attention_weights = self.attention_layer(tiles_emb, mask) # BxN_tilesx128 --> Bx128 + logits = self.mlp(scaled_tiles_emb) # Bx128 --> Bx1 + + if return_attention: + return logits, attention_weights + + return logits \ No newline at end of file diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py new file mode 100644 index 0000000..716b3bc --- /dev/null +++ b/ahcore/models/MIL/transmil.py @@ -0,0 +1,246 @@ +# this file includes the original nystrom attention and transmil model from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively. + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from math import ceil +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, reduce + +# helper functions + +def exists(val): + return val is not None + +def moore_penrose_iter_pinv(x, iters = 6): + device = x.device + + abs_x = torch.abs(x) + col = abs_x.sum(dim = -1) + row = abs_x.sum(dim = -2) + z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row)) + + I = torch.eye(x.shape[-1], device = device) + I = rearrange(I, 'i j -> () i j') + + for _ in range(iters): + xz = x @ z + z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz))))) + + return z + +# main attention class + +class NystromAttention(nn.Module): + def __init__( + self, + dim, + dim_head = 64, + heads = 8, + num_landmarks = 256, + pinv_iterations = 6, + residual = True, + residual_conv_kernel = 33, + eps = 1e-8, + dropout = 0. 
+ ): + super().__init__() + self.eps = eps + inner_dim = heads * dim_head + + self.num_landmarks = num_landmarks + self.pinv_iterations = pinv_iterations + + self.heads = heads + self.scale = dim_head ** -0.5 + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + self.residual = residual + if residual: + kernel_size = residual_conv_kernel + padding = residual_conv_kernel // 2 + self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False) + + def forward(self, x, mask = None, return_attn = False): + b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps + + # pad so that sequence can be evenly divided into m landmarks + + remainder = n % m + if remainder > 0: + padding = m - (n % m) + x = F.pad(x, (0, 0, padding, 0), value = 0) + + if exists(mask): + mask = F.pad(mask, (padding, 0), value = False) + + # derive query, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # set masked positions to 0 in queries, keys, values + + if exists(mask): + mask = rearrange(mask, 'b n -> b () n') + q, k, v = map(lambda t: t * mask[..., None], (q, k, v)) + + q = q * self.scale + + # generate landmarks by sum reduction, and then calculate mean using the mask + + l = ceil(n / m) + landmark_einops_eq = '... (n l) d -> ... n d' + q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l) + k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l) + + # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean + + divisor = l + if exists(mask): + mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l) + divisor = mask_landmarks_sum[..., None] + eps + mask_landmarks = mask_landmarks_sum > 0 + + # masked mean (if mask exists) + + q_landmarks = q_landmarks / divisor + k_landmarks = k_landmarks / divisor + + # similarities + + einops_eq = '... i d, ... j d -> ... i j' + sim1 = einsum(einops_eq, q, k_landmarks) + sim2 = einsum(einops_eq, q_landmarks, k_landmarks) + sim3 = einsum(einops_eq, q_landmarks, k) + + # masking + + if exists(mask): + mask_value = -torch.finfo(q.dtype).max + sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value) + sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value) + sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value) + + # eq (15) in the paper and aggregate values + + attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3)) + attn2_inv = moore_penrose_iter_pinv(attn2, iters) + + out = (attn1 @ attn2_inv) @ (attn3 @ v) + + # add depth-wise conv residual of values + + if self.residual: + out = out + self.res_conv(v) + + # merge and combine heads + + out = rearrange(out, 'b h n d -> b n (h d)', h = h) + out = self.to_out(out) + out = out[:, -n:] + + if return_attn: + attn = attn1 @ attn2_inv @ attn3 + return out, attn + + return out + + +class TransLayer(nn.Module): + + def __init__(self, norm_layer=nn.LayerNorm, dim=512): + super().__init__() + self.norm = norm_layer(dim) + self.attn = NystromAttention( + dim=dim, + dim_head=dim // 8, + heads=8, + num_landmarks=dim // 2, # number of landmarks + pinv_iterations=6, + # number of moore-penrose iterations for approximating pinverse. 
6 was recommended by the paper + residual=True, + # whether to do an extra residual with the value or not. supposedly faster convergence if turned on + dropout=0.1 + ) + + def forward(self, x): + x = x + self.attn(self.norm(x)) + + return x + + +class PPEG(nn.Module): + def __init__(self, dim=512): + super(PPEG, self).__init__() + self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim) + self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim) + self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim) + + def forward(self, x, H, W): + B, _, C = x.shape + cls_token, feat_token = x[:, 0], x[:, 1:] + cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) + x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat) + x = x.flatten(2).transpose(1, 2) + x = torch.cat((cls_token.unsqueeze(1), x), dim=1) + return x + + +class TransMIL(nn.Module): + def __init__(self, n_classes): + super(TransMIL, self).__init__() + self.pos_layer = PPEG(dim=512) + self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) + self.cls_token = nn.Parameter(torch.randn(1, 1, 512)) + self.n_classes = n_classes + self.layer1 = TransLayer(dim=512) + self.layer2 = TransLayer(dim=512) + self.norm = nn.LayerNorm(512) + self._fc2 = nn.Linear(512, self.n_classes) + + def forward(self, features, **kwargs): + h = features # [B, n, 1024] + + h = self._fc1(h) # [B, n, 512] + + # ---->pad + H = h.shape[1] + _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H))) + add_length = _H * _W - H + h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N, 512] + + # ---->cls_token + B = h.shape[0] + cls_tokens = self.cls_token.expand(B, -1, -1).cuda() + h = torch.cat((cls_tokens, h), dim=1) + + # ---->Translayer x1 + h = self.layer1(h) # [B, N, 512] + + # ---->PPEG + h = self.pos_layer(h, _H, _W) # [B, N, 512] + + # ---->Translayer x2 + h = self.layer2(h) # [B, N, 512] + + # ---->cls_token + h = self.norm(h)[:, 0] + + # ---->predict + logits = self._fc2(h) # [B, n_classes] + Y_hat = torch.argmax(logits, dim=1) + Y_prob = F.softmax(logits, dim=1) + results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat} + return logits \ No newline at end of file From c30e19505132f36d6c7cf352ff47aaeb758f0c5c Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 8 Aug 2024 16:44:15 +0200 Subject: [PATCH 02/14] added utils necessary to use models --- ahcore/models/MIL/ABmil.py | 3 +- ahcore/models/layers/MLP.py | 376 ++++++++++++++++++++++++++++++ ahcore/models/layers/attention.py | 170 ++++++++++++++ 3 files changed, 548 insertions(+), 1 deletion(-) create mode 100644 ahcore/models/layers/MLP.py create mode 100644 ahcore/models/layers/attention.py diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index 8cbe1f0..7c7b732 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -1,4 +1,5 @@ -from src.utils.model_utils import MLP, MaskedMLP, GatedAttention +from ahcore.models.layers.MLP import MLP, MaskedMLP, GatedAttention +from ahcore.models.layers.attention import GatedAttention from typing import List, Optional diff --git a/ahcore/models/layers/MLP.py b/ahcore/models/layers/MLP.py new file mode 100644 index 0000000..dfa310d --- /dev/null +++ b/ahcore/models/layers/MLP.py @@ -0,0 +1,376 @@ +from typing import Optional, List, Union, Tuple + +import torch +from torch import nn + +"""Most of this stuff is adapted from utils from https://github.com/owkin/HistoSSLscaling/tree/main""" + + +class MLP(nn.Sequential): + """MLP Module. 
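+
+    Example (a minimal sketch; the layer sizes below are illustrative):
+        >>> mlp = MLP(in_features=768, out_features=128, hidden=[256], dropout=[0.1])
+        >>> embeddings = mlp(torch.randn(4, 1000, 768))  # (4, 1000, 128)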
+ + Parameters + ---------- + in_features: int + Features (model input) dimension. + out_features: int = 1 + Prediction (model output) dimension. + hidden: Optional[List[int]] = None + Dimension of hidden layer(s). + dropout: Optional[List[float]] = None + Dropout rate(s). + activation: Optional[torch.nn.Module] = torch.nn.Sigmoid + MLP activation. + bias: bool = True + Add bias to MLP hidden layers. + + Raises + ------ + ValueError + If ``hidden`` and ``dropout`` do not share the same length. + """ + + def __init__( + self, + in_features: int, + out_features: int, + hidden: Optional[List[int]] = None, + dropout: Optional[List[float]] = None, + activation: Optional[nn.Module] = nn.Sigmoid(), + bias: bool = True, + ): + if dropout is not None: + if hidden is not None: + assert len(hidden) == len( + dropout + ), "hidden and dropout must have the same length" + else: + raise ValueError( + "hidden must have a value and have the same length as dropout if dropout is given." + ) + + d_model = in_features + layers = [] + + if hidden is not None: + for i, h in enumerate(hidden): + seq = [nn.Linear(d_model, h, bias=bias)] + d_model = h + + if activation is not None: + seq.append(activation) + + if dropout is not None: + seq.append(nn.Dropout(dropout[i])) + + layers.append(nn.Sequential(*seq)) + + layers.append(nn.Linear(d_model, out_features)) + + super(MLP, self).__init__(*layers) + +class MaskedLinear(nn.Linear): + """ + Linear layer to be applied tile wise. + This layer can be used in combination with a mask + to prevent padding tiles from influencing the values of a subsequent + activation. + Example: + >>> module = Linear(in_features=128, out_features=1) # With Linear + >>> out = module(slide) + >>> wrong_value = torch.sigmoid(out) # Value is influenced by padding + >>> module = MaskedLinear(in_features=128, out_features=1, mask_value='-inf') # With MaskedLinear + >>> out = module(slide, mask) # Padding now has the '-inf' value + >>> correct_value = torch.sigmoid(out) # Value is not influenced by padding as sigmoid('-inf') = 0 + Parameters + ---------- + in_features: int + size of each input sample + out_features: int + size of each output sample + mask_value: Union[str, int] + value to give to the mask + bias: bool = True + If set to ``False``, the layer will not learn an additive bias. + """ + + def __init__( + self, + in_features: int, + out_features: int, + mask_value: Union[str, float], + bias: bool = True, + ): + super(MaskedLinear, self).__init__( + in_features=in_features, out_features=out_features, bias=bias + ) + self.mask_value = mask_value + + def forward( + self, x: torch.Tensor, mask: Optional[torch.BoolTensor] = None + ): # pylint: disable=arguments-renamed + """Forward pass. + + Parameters + ---------- + x: torch.Tensor + Input tensor, shape (B, SEQ_LEN, IN_FEATURES). + mask: Optional[torch.BoolTensor] = None + True for values that were padded, shape (B, SEQ_LEN, 1), + + Returns + ------- + x: torch.Tensor + (B, SEQ_LEN, OUT_FEATURES) + """ + x = super(MaskedLinear, self).forward(x) + if mask is not None: + x = x.masked_fill(mask, float(self.mask_value)) + return x + + def extra_repr(self): + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"mask_value={self.mask_value}, bias={self.bias is not None}" + ) + + +class MaskedMLP(nn.Module): + """MLP to be applied to tiles to compute scores. + This module can be used in combination of a mask + to prevent padding from influencing the scores values. 
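+    Example (a minimal sketch; tensor shapes are illustrative):
+        >>> module = MaskedMLP(in_features=128, out_features=1, hidden=[64])
+        >>> scores = module(x, mask)  # x: (B, N_TILES, 128), mask: (B, N_TILES, 1)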
+ Parameters + ---------- + in_features: int + size of each input sample + out_features: int + size of each output sample + hidden: Optional[List[int]] = None + Number of hidden layers and their respective number of features. + bias: bool = True + If set to ``False``, the layer will not learn an additive bias. + activation: torch.nn.Module = torch.nn.Sigmoid() + MLP activation function + dropout: Optional[torch.nn.Module] = None + Optional dropout module. Will be interlaced with the linear layers. + """ + + def __init__( + self, + in_features: int, + out_features: int = 1, + hidden: Optional[List[int]] = None, + bias: bool = True, + activation: nn.Module = nn.Sigmoid(), + dropout: Optional[nn.Module] = None, + ): + super(MaskedMLP, self).__init__() + + if dropout is not None: + assert len(dropout) == len(hidden), "Length of dropout is not correct" + + self.hidden_layers = nn.ModuleList() + if hidden is not None: + for i, h in enumerate(hidden): + self.hidden_layers.append( + MaskedLinear(in_features, h, bias=bias, mask_value="-inf") + ) + self.hidden_layers.append(activation) + if dropout: + self.hidden_layers.append(nn.Dropout(dropout[i])) + in_features = h + + self.hidden_layers.append( + nn.Linear(in_features, out_features, bias=bias) + ) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.BoolTensor] = None + ): + """Forward pass. + + Parameters + ---------- + x: torch.Tensor + (B, N_TILES, IN_FEATURES) + mask: Optional[torch.BoolTensor] = None + (B, N_TILES), True for values that were padded. + + Returns + ------- + x: torch.Tensor + (B, N_TILES, OUT_FEATURES) + """ + for layer in self.hidden_layers: + if isinstance(layer, MaskedLinear): + x = layer(x, mask) + else: + x_before = x.clone().detach() + x = layer(x) + + if torch.any(x.masked_fill(mask, 0).isnan()): + raise RuntimeError(f"Found NaN values in x outside the mask") + + return x + +class SelfAttention(nn.Module): + """Multi-Head Self-Attention. + + Implementation adapted from https://github.com/rwightman/pytorch-image-models. + + Parameters + ---------- + in_features : int + Number of input features. + + num_heads : int = 8 + Number of attention heads. Should be an integer greater or equal to 1. + + qkv_bias : bool = False + Whether to add a bias to the linear projection for query, key and value. + + attn_dropout : float = 0.0 + Dropout rate (applied before the multiplication with the values). + + proj_dropout : float = 0.0 + Dropout rate (applied after the multiplication with the values). + """ + + def __init__( + self, + in_features: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_dropout: float = 0.0, + proj_dropout: float = 0.0, + ): + super().__init__() + self.in_features = in_features + self.num_heads = num_heads + self.qkv_bias = qkv_bias + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + self.__build() + + def __build(self): + """Build the `SelfAttention` module.""" + head_dim = self.in_features // self.num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear( + self.in_features, self.in_features * 3, bias=self.qkv_bias + ) + self.attn_drop = nn.Dropout(self.attn_dropout) + self.proj = nn.Linear(self.in_features, self.in_features) + self.proj_drop = nn.Dropout(self.proj_dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Parameters + ---------- + x : torch.Tensor + Input tensor, shape (B, seq_len, in_features). + + Returns + ------- + out : torch.Tensor + Output tensor, shape (B, seq_len, in_features). 
+ """ + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GatedAttention(nn.Module): + """Gated Attention, as defined in https://arxiv.org/abs/1802.04712. + Permutation invariant Layer on dim 1. + Parameters + ---------- + d_model: int = 128 + temperature: float = 1.0 + Attention Softmax temperature + """ + + def __init__( + self, + d_model: int = 128, + temperature: float = 1.0, + ): + super(GatedAttention, self).__init__() + + self.V = nn.Linear(d_model, d_model) + self.U = nn.Linear(d_model, d_model) + self.w = MaskedLinear(d_model, 1, "-inf") + + self.temperature = temperature + + def attention( + self, + features: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: + """Gets attention logits. + Parameters + ---------- + v: torch.Tensor + (B, SEQ_LEN, IN_FEATURES) + mask: Optional[torch.BoolTensor] = None + (B, SEQ_LEN, 1), True for values that were padded. + Returns + ------- + attention_logits: torch.Tensor + (B, N_TILES, 1) + """ + h_v = torch.tanh(self.U(features)) + + u_v = torch.sigmoid(self.V(features)) + + attention_logits = self.w(h_v * u_v, mask=mask) / self.temperature + + attention_weights = torch.softmax(attention_logits, 1) + + return attention_weights + + def forward( + self, features: torch.Tensor, mask: Optional[torch.BoolTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + Parameters + ---------- + v: torch.Tensor + (B, SEQ_LEN, IN_FEATURES) + mask: Optional[torch.BoolTensor] = None + (B, SEQ_LEN, 1), True for values that were padded. + Returns + ------- + scaled_attention, attention_weights: Tuple[torch.Tensor, torch.Tensor] + (B, IN_FEATURES), (B, N_TILES, 1) + """ + h_v = torch.tanh(self.U(features)) + + u_v = torch.sigmoid(self.V(features)) + + attention_logits = self.w(h_v * u_v, mask=mask) / self.temperature + + attention_weights = torch.softmax(attention_logits, 1) + # if not torch.any(attention_weights[mask]==0.0): + # raise RuntimeError(f"Masked indices got non-zero weight") + + # features = features.masked_fill(mask, float(0.0)) + scaled_attention = torch.matmul(attention_weights.transpose(1, 2), features) + + return scaled_attention.squeeze(1), attention_weights \ No newline at end of file diff --git a/ahcore/models/layers/attention.py b/ahcore/models/layers/attention.py new file mode 100644 index 0000000..ec7d4aa --- /dev/null +++ b/ahcore/models/layers/attention.py @@ -0,0 +1,170 @@ +from typing import Optional, List, Union, Tuple + +import torch +from torch import nn + +from ahcore.models.layers.MLP import MaskedLinear + +"""Most of this stuff is adapted from utils from https://github.com/owkin/HistoSSLscaling/tree/main""" + +class SelfAttention(nn.Module): + """Multi-Head Self-Attention. + + Implementation adapted from https://github.com/rwightman/pytorch-image-models. + + Parameters + ---------- + in_features : int + Number of input features. + + num_heads : int = 8 + Number of attention heads. Should be an integer greater or equal to 1. + + qkv_bias : bool = False + Whether to add a bias to the linear projection for query, key and value. + + attn_dropout : float = 0.0 + Dropout rate (applied before the multiplication with the values). 
+ + proj_dropout : float = 0.0 + Dropout rate (applied after the multiplication with the values). + """ + + def __init__( + self, + in_features: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_dropout: float = 0.0, + proj_dropout: float = 0.0, + ): + super().__init__() + self.in_features = in_features + self.num_heads = num_heads + self.qkv_bias = qkv_bias + self.attn_dropout = attn_dropout + self.proj_dropout = proj_dropout + + self.__build() + + def __build(self): + """Build the `SelfAttention` module.""" + head_dim = self.in_features // self.num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear( + self.in_features, self.in_features * 3, bias=self.qkv_bias + ) + self.attn_drop = nn.Dropout(self.attn_dropout) + self.proj = nn.Linear(self.in_features, self.in_features) + self.proj_drop = nn.Dropout(self.proj_dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Parameters + ---------- + x : torch.Tensor + Input tensor, shape (B, seq_len, in_features). + + Returns + ------- + out : torch.Tensor + Output tensor, shape (B, seq_len, in_features). + """ + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GatedAttention(nn.Module): + """Gated Attention, as defined in https://arxiv.org/abs/1802.04712. + Permutation invariant Layer on dim 1. + Parameters + ---------- + d_model: int = 128 + temperature: float = 1.0 + Attention Softmax temperature + """ + + def __init__( + self, + d_model: int = 128, + temperature: float = 1.0, + ): + super(GatedAttention, self).__init__() + + self.V = nn.Linear(d_model, d_model) + self.U = nn.Linear(d_model, d_model) + self.w = MaskedLinear(d_model, 1, "-inf") + + self.temperature = temperature + + def attention( + self, + features: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: + """Gets attention logits. + Parameters + ---------- + v: torch.Tensor + (B, SEQ_LEN, IN_FEATURES) + mask: Optional[torch.BoolTensor] = None + (B, SEQ_LEN, 1), True for values that were padded. + Returns + ------- + attention_logits: torch.Tensor + (B, N_TILES, 1) + """ + h_v = torch.tanh(self.U(features)) + + u_v = torch.sigmoid(self.V(features)) + + attention_logits = self.w(h_v * u_v, mask=mask) / self.temperature + + attention_weights = torch.softmax(attention_logits, 1) + + return attention_weights + + def forward( + self, features: torch.Tensor, mask: Optional[torch.BoolTensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + Parameters + ---------- + v: torch.Tensor + (B, SEQ_LEN, IN_FEATURES) + mask: Optional[torch.BoolTensor] = None + (B, SEQ_LEN, 1), True for values that were padded. 
+ Returns + ------- + scaled_attention, attention_weights: Tuple[torch.Tensor, torch.Tensor] + (B, IN_FEATURES), (B, N_TILES, 1) + """ + h_v = torch.tanh(self.U(features)) + + u_v = torch.sigmoid(self.V(features)) + + attention_logits = self.w(h_v * u_v, mask=mask) / self.temperature + + attention_weights = torch.softmax(attention_logits, 1) + # if not torch.any(attention_weights[mask]==0.0): + # raise RuntimeError(f"Masked indices got non-zero weight") + + # features = features.masked_fill(mask, float(0.0)) + scaled_attention = torch.matmul(attention_weights.transpose(1, 2), features) + + return scaled_attention.squeeze(1), attention_weights \ No newline at end of file From a20bb153812556cbd8e67e13c74ceac3d2065a75 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Fri, 9 Aug 2024 18:03:07 +0200 Subject: [PATCH 03/14] added huggingface models and transforms and minor model changes --- ahcore/models/MIL/ABmil.py | 10 ++------ ahcore/models/base_jit_model.py | 36 +++++++++++++++++++++++++++-- ahcore/transforms/pre_transforms.py | 19 +++++++++++++-- 3 files changed, 53 insertions(+), 12 deletions(-) diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index 7c7b732..5dff15b 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -1,4 +1,4 @@ -from ahcore.models.layers.MLP import MLP, MaskedMLP, GatedAttention +from ahcore.models.layers.MLP import MLP, MaskedMLP from ahcore.models.layers.attention import GatedAttention from typing import List, Optional @@ -7,6 +7,7 @@ from torch import nn import torch.nn.init as init +# todo: fix docstring class ABMIL(nn.Module): """Attention-based MIL classification model (See [1]_). Adapted from https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py @@ -36,13 +37,6 @@ class ABMIL(nn.Module): Activation for last MLP. bias: bool = True Add bias to the first MLP. - metadata_cols: int = 3 - Number of metadata columns (for example, magnification, patch start - coordinates etc.) at the start of input data. Default of 3 assumes - that the first 3 columns of input data are, respectively: - 1) Deep zoom level, corresponding to a given magnification - 2) input patch starting x value - 3) input patch starting y value References ---------- diff --git a/ahcore/models/base_jit_model.py b/ahcore/models/base_jit_model.py index f21bb9b..69fe1f1 100644 --- a/ahcore/models/base_jit_model.py +++ b/ahcore/models/base_jit_model.py @@ -2,7 +2,39 @@ from typing import Any from torch.jit import ScriptModule, load -from torch.nn import Module +from torch import nn + +from transformers.modeling_utils import PreTrainedModel + +class BaseHuggingfaceModel(nn.Module): + + def __init__(self, model: PreTrainedModel, pretrained_model_name_or_path: str, **kwargs) -> None: + super().__init__() + + self.model: model = model.from_pretrained(pretrained_model_name_or_path, **kwargs) + + def forward(self, x): + model_input = x if type(x) is dict else {"pixel_values": x} # todo check if huggingface models sometimes other things??? 
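+        # A bare tensor is wrapped into the keyword-argument dict that Hugging Face
+        # vision models expect ("pixel_values"); an input that is already a dict of
+        # model keyword arguments is passed through unchanged.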
+ model_output = self.model(**model_input) + return model_output.last_hidden_states + + def get_attentions(self, x): + model_input = {"pixel_values": x} + model_output = self.model(**model_input) + return model_output.attentions + + def get_raw_output(self, x): + model_input = {"pixel_values": x} + model_output = self.model(**model_input) + return model_output + + def get_output_at_keys(self, x, keys): + if isinstance(keys, str): + keys = [keys] + + model_input = {"pixel_values": x} + model_output = self.model(**model_input) + return {model_output[key] for key in keys} if len(keys)>1 else model_output[keys[0]] class BaseAhcoreJitModel(ScriptModule): @@ -46,7 +78,7 @@ def from_jit_path(cls, jit_path: Path, output_mode: str) -> Any: model = load(jit_path) # type: ignore return cls(model) - def extend_model(self, modules: dict[str, Module]) -> None: + def extend_model(self, modules: dict[str, nn.Module]) -> None: """ Add modules to a jit compiled model. diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 2430778..a14e9c7 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -19,6 +19,8 @@ from ahcore.utils.io import get_logger from ahcore.utils.types import DlupDatasetSample +from transformers import AutoImageProcessor + PreTransformCallable = Callable[[Any], Any] logger = get_logger(__name__) @@ -117,6 +119,17 @@ def __repr__(self) -> str: return f"PreTransformTaskFactory(transforms={self._transforms})" +class ApplyHuggingfaceTransforms: + + def __init__(self, pretrained_model_name_or_path: str, **kwargs): + self._processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs) + + def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: + # Apply the huggingface transforms here + sample["image"]: np.ndarray = self._processor(sample["image"])["pixel_values"] + + return sample + class LabelToClassIndex: """ Maps label values to class indices according to the index_map specified in the data description. 
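Taken together, the additions in this patch wire Hugging Face preprocessing into the existing pipeline: ApplyHuggingfaceTransforms above runs the checkpoint's AutoImageProcessor as a pre-transform, and BaseHuggingfaceModel then consumes the resulting image tensor as "pixel_values". A minimal sketch of the underlying transformers flow these wrappers build on, assuming a ViT-style checkpoint (the checkpoint name and tile size are illustrative, not part of this patch):

    import numpy as np
    from transformers import AutoImageProcessor, ViTModel

    checkpoint = "google/vit-base-patch16-224"
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    backbone = ViTModel.from_pretrained(checkpoint)

    tile = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)  # stand-in for sample["image"]
    inputs = processor(tile, return_tensors="pt")  # BatchFeature with "pixel_values" of shape (1, 3, 224, 224)
    outputs = backbone(**inputs)                   # outputs.last_hidden_state: (1, 197, 768)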
@@ -216,12 +229,14 @@ class ImageToTensor: """ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: - tile: pyvips.Image = sample["image"] + tile: pyvips.Image | np.ndarray = sample["image"] # Flatten the image to remove the alpha channel, using white as the background color tile_ = tile.flatten(background=[255, 255, 255]) # Convert VIPS image to a numpy array then to a torch tensor - np_image = tile_.numpy() + if type(tile_) == pyvips.Image: + np_image = tile_.numpy() + sample["image"] = torch.from_numpy(np_image).permute(2, 0, 1).float() if sample["image"].sum() == 0: From e2e10ec7e89d25ddf5b8663c3aac23b3a9868060 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 13 Aug 2024 14:21:05 +0200 Subject: [PATCH 04/14] fix loading paths for HF models similar to jitmodels --- ahcore/lit_module.py | 3 ++- ahcore/utils/io.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index 3001501..a84c43a 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -13,6 +13,7 @@ import torch.optim.optimizer from pytorch_lightning.trainer.states import TrainerFn from torch import nn +import transformers from ahcore.exceptions import ConfigurationError from ahcore.metrics import MetricFactory, WSIMetricFactory @@ -58,7 +59,7 @@ def __init__( "loss", ], ) # TODO: we should send the hyperparams to the logger elsewhere - if isinstance(model, BaseAhcoreJitModel): + if isinstance(model, BaseAhcoreJitModel) or isinstance(model, transformers.modeling_utils.PretrainedModel): self._model = model elif isinstance(model, functools.partial): try: diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index bc7b2dd..ae65402 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -28,6 +28,7 @@ from omegaconf.errors import InterpolationKeyError from pytorch_lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only +import transformers from ahcore.models.base_jit_model import BaseAhcoreJitModel @@ -238,7 +239,7 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule: The model loaded from the checkpoint file. """ _model = getattr(model, "_model") - if isinstance(_model, BaseAhcoreJitModel): + if isinstance(_model, BaseAhcoreJitModel) or isinstance(_model, transformers.modeling_utils.PretrainedModel): return model else: # Load checkpoint weights From 8327135cb9fb9f47394982bc06b30b0b91907f67 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 13 Aug 2024 17:52:23 +0200 Subject: [PATCH 05/14] simplified abmil and layers --- ahcore/models/MIL/ABmil.py | 282 ++++++++++--------------- ahcore/models/MIL/transmil.py | 75 +++---- ahcore/models/base_jit_model.py | 7 +- ahcore/models/layers/MLP.py | 316 +--------------------------- ahcore/models/layers/attention.py | 65 ++---- ahcore/transforms/pre_transforms.py | 1 + 6 files changed, 179 insertions(+), 567 deletions(-) diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index 5dff15b..a4900e7 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -1,50 +1,52 @@ -from ahcore.models.layers.MLP import MLP, MaskedMLP +from ahcore.models.layers.MLP import MLP from ahcore.models.layers.attention import GatedAttention from typing import List, Optional import torch from torch import nn -import torch.nn.init as init -# todo: fix docstring -class ABMIL(nn.Module): - """Attention-based MIL classification model (See [1]_). 
- Adapted from https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py - Example: - >>> module = ABMIL(in_features=128, out_features=1) - >>> logits, attention_scores = module(slide, mask=mask) - >>> attention_scores = module.score_model(slide, mask=mask) +class ABMIL(nn.Module): + """ + Attention-based MIL (Multiple Instance Learning) classification model (See [1]_). + This model is adapted from https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py. + It uses an attention mechanism to aggregate features from multiple instances (tiles) into a single prediction. Parameters ---------- - in_features: int - Features (model input) dimension. - out_features: int = 1 - Prediction (model output) dimension. - d_model_attention: int = 128 - Dimension of attention scores. - temperature: float = 1.0 - GatedAttention softmax temperature. - tiles_mlp_hidden: Optional[List[int]] = None - Dimension of hidden layers in first MLP. - mlp_hidden: Optional[List[int]] = None - Dimension of hidden layers in last MLP. - mlp_dropout: Optional[List[float]] = None, - Dropout rate for last MLP. - mlp_activation: Optional[torch.nn.Module] = torch.nn.Sigmoid - Activation for last MLP. - bias: bool = True - Add bias to the first MLP. + in_features : int + Number of input features for each tile. + out_features : int, optional + Number of output features (typically 1 for binary classification), by default 1. + attention_dimension : int, optional + Dimensionality of the attention mechanism, by default 128. + temperature : float, optional + Temperature parameter for scaling the attention scores, by default 1.0. + embed_mlp_hidden : Optional[List[int]], optional + List of hidden layer sizes for the embedding MLP, by default None. + embed_mlp_dropout : Optional[List[float]], optional + List of dropout rates for the embedding MLP, by default None. + embed_mlp_activation : Optional[torch.nn.Module], optional + Activation function for the embedding MLP, by default nn.ReLU(). + embed_mlp_bias : bool, optional + Whether to include bias in the embedding MLP layers, by default True. + classifier_hidden : Optional[List[int]], optional + List of hidden layer sizes for the classifier MLP, by default [128, 64]. + classifier_dropout : Optional[List[float]], optional + List of dropout rates for the classifier MLP, by default None. + classifier_activation : Optional[torch.nn.Module], optional + Activation function for the classifier MLP, by default nn.ReLU(). + classifier_bias : bool, optional + Whether to include bias in the classifier MLP layers, by default False. References ---------- .. [1] Maximilian Ilse, Jakub Tomczak, and Max Welling. Attention-based - deep multiple instance learning. In Jennifer Dy and Andreas Krause, - editors, Proceedings of the 35th International Conference on Machine - Learning, volume 80 of Proceedings of Machine Learning Research, - pages 2127–2136. PMLR, 10–15 Jul 2018. + deep multiple instance learning. In Jennifer Dy and Andreas Krause, + editors, Proceedings of the 35th International Conference on Machine + Learning, volume 80 of Proceedings of Machine Learning Research, + pages 2127–2136. PMLR, 10–15 Jul 2018. 
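+
+    Example
+    -------
+    A minimal usage sketch; the batch size, tile count and feature dimension
+    below are illustrative.
+
+    >>> model = ABMIL(in_features=768, out_features=1)
+    >>> logits = model(torch.randn(2, 500, 768))  # (2, 1)
+    >>> logits, attention = model(torch.randn(2, 500, 768), return_attention_weights=True)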
""" @@ -52,173 +54,119 @@ def __init__( self, in_features: int, out_features: int = 1, - number_of_tiles: int = 1000, - d_model_attention: int = 128, + attention_dimension: int = 128, temperature: float = 1.0, - masked_mlp_hidden: Optional[List[int]] = None, - masked_mlp_dropout: Optional[List[float]] = None, - masked_mlp_activation: Optional[torch.nn.Module] = nn.Sigmoid(), - mlp_hidden: Optional[List[int]] = [128, 64], - mlp_dropout: Optional[List[float]] = None, - mlp_activation: Optional[torch.nn.Module] = nn.Sigmoid(), - bias: bool = True, - use_positional_encoding: bool = False, + embed_mlp_hidden: Optional[List[int]] = None, + embed_mlp_dropout: Optional[List[float]] = None, + embed_mlp_activation: Optional[torch.nn.Module] = nn.ReLU(), + embed_mlp_bias: bool = True, + classifier_hidden: Optional[List[int]] = [128, 64], + classifier_dropout: Optional[List[float]] = None, + classifier_activation: Optional[torch.nn.Module] = nn.ReLU(), + classifier_bias: bool = False, ) -> None: - super(ABMIL, self).__init__() + """ + Initializes the ABMIL model with embedding and classification layers. - if mlp_dropout is not None: - if mlp_hidden is not None: - assert len(mlp_hidden) == len( - mlp_dropout - ), "mlp_hidden and mlp_dropout must have the same length" - else: - raise ValueError( - "mlp_hidden must have a value and have the same length" - "as mlp_dropout if mlp_dropout is given." - ) + Parameters + ---------- + in_features : int + Number of input features for each tile. + out_features : int, optional + Number of output features (typically 1 for binary classification), by default 1. + attention_dimension : int, optional + Dimensionality of the attention mechanism, by default 128. + temperature : float, optional + Temperature parameter for scaling the attention scores, by default 1.0. + embed_mlp_hidden : Optional[List[int]], optional + List of hidden layer sizes for the embedding MLP, by default None. + embed_mlp_dropout : Optional[List[float]], optional + List of dropout rates for the embedding MLP, by default None. + embed_mlp_activation : Optional[torch.nn.Module], optional + Activation function for the embedding MLP, by default nn.ReLU(). + embed_mlp_bias : bool, optional + Whether to include bias in the embedding MLP layers, by default True. + classifier_hidden : Optional[List[int]], optional + List of hidden layer sizes for the classifier MLP, by default [128, 64]. + classifier_dropout : Optional[List[float]], optional + List of dropout rates for the classifier MLP, by default None. + classifier_activation : Optional[torch.nn.Module], optional + Activation function for the classifier MLP, by default nn.ReLU(). + classifier_bias : bool, optional + Whether to include bias in the classifier MLP layers, by default False. 
+ + """ + super(ABMIL, self).__init__() self.embed_mlp = MLP( in_features=in_features, - hidden=masked_mlp_hidden, - bias=bias, - out_features=d_model_attention, - dropout=masked_mlp_dropout, - activation=masked_mlp_activation, - ) - - self.attention_layer = GatedAttention( - d_model=d_model_attention, temperature=temperature + hidden=embed_mlp_hidden, + bias=embed_mlp_bias, + out_features=attention_dimension, + dropout=embed_mlp_dropout, + activation=embed_mlp_activation, ) - mlp_in_features = d_model_attention + self.attention_layer = GatedAttention(dim=attention_dimension, temperature=temperature) - self.mlp = MLP( - in_features=mlp_in_features, + self.classifier = MLP( + in_features=attention_dimension, out_features=out_features, - hidden=mlp_hidden, - dropout=mlp_dropout, - activation=mlp_activation, + bias=classifier_bias, + hidden=classifier_hidden, + dropout=classifier_dropout, + activation=classifier_activation, ) - self.use_positional_encoding = use_positional_encoding - - if self.use_positional_encoding: - # TODO this should also add some interpolation - self.positional_encoding = nn.Parameter(torch.zeros(1, number_of_tiles, in_features)) - init.trunc_normal_(self.positional_encoding, mean=0.0, std=0.02, a=-2.0, b=2.0) - self.positional_encoding.requires_grad = True - - def interpolate_positional_encoding(self, coordinates: torch.Tensor): + def get_attention(self, x: torch.Tensor) -> torch.Tensor: """ - Perform bilinear interpolation using the given coordinates on the positional encoding. - The positional encoding is considered as a flattened array representing a (h, w) grid. - - Args: - coordinates (torch.Tensor): The normalized coordinates tensor of shape (batch_size, 2), - where each row is (x, y) in normalized coordinates [0, 1]. - - Returns: - torch.Tensor: The interpolated features from the positional encoding. - """ - # Scale coordinates to the range of the positional encoding indices - max_idx = int(torch.sqrt(torch.tensor([self.positional_encoding.shape[1]]))) - 1 - scaled_coordinates = max_idx * coordinates - - # Separate scaled coordinates into x and y components - x = scaled_coordinates[..., 0] - y = scaled_coordinates[..., 1] - - - # Get integer parts of coordinates - x0 = torch.floor(x).int() - x1 = x0 + 1 - y0 = torch.floor(y).int() - y1 = y0 + 1 - - # Clamp indices to ensure they remain within valid range - x0 = torch.clamp(x0, 0, max_idx) - x1 = torch.clamp(x1, 0, max_idx) - y0 = torch.clamp(y0, 0, max_idx) - y1 = torch.clamp(y1, 0, max_idx) - - # Calculate linear indices - idx_q11 = y0 * max_idx + x0 - idx_q12 = y1 * max_idx + x0 - idx_q21 = y0 * max_idx + x1 - idx_q22 = y1 * max_idx + x1 - - # Fetch the corner points - q11 = self.positional_encoding[0, idx_q11, :] - q12 = self.positional_encoding[0, idx_q12, :] - q21 = self.positional_encoding[0, idx_q21, :] - q22 = self.positional_encoding[0, idx_q22, :] - - # Compute fractional part for interpolation - x_frac = x - x0.float() - y_frac = y - y0.float() - - # Bilinear interpolation - interpolated_positional_encoding = (q11 * (1 - x_frac).unsqueeze(2) * (1 - y_frac).unsqueeze(2) + - q12 * (1 - x_frac).unsqueeze(2) * y_frac.unsqueeze(2) + - q21 * x_frac.unsqueeze(2) * (1 - y_frac).unsqueeze(2) + - q22 * x_frac.unsqueeze(2) * y_frac.unsqueeze(2)) - - return interpolated_positional_encoding - - def get_attention( - self, x: torch.Tensor, mask: Optional[torch.BoolTensor] = None, coordinates: torch.Tensor = None, - ) -> torch.Tensor: - """Get attention logits. + Computes the attention weights for the input features. 
Parameters ---------- - x: torch.Tensor - (B, N_TILES, FEATURES) - mask: Optional[torch.BoolTensor] - (B, N_TILES, 1), True for values that were padded. + x : torch.Tensor + Input tensor of shape (batch_size, n_tiles, in_features) representing the features of tiles. Returns ------- - attention_logits: torch.Tensor - (B, N_TILES, 1) - """ - if self.use_positional_encoding: - positional_encoding = self.interpolate_positional_encoding(coordinates) - x = x + positional_encoding + torch.Tensor + Attention weights for each tile. - tiles_emb = self.tiles_emb(x, mask) - attention_weights = self.attention_layer.attention(tiles_emb, mask) + """ + tiles_emb = self.embed_mlp(x) + attention_weights = self.attention_layer.attention(tiles_emb) return attention_weights def forward( - self, features: torch.Tensor, mask: Optional[torch.BoolTensor] = None, coordinates: torch.Tensor = None, return_attention: bool = False, + self, + features: torch.Tensor, + return_attention_weights: bool = False, ) -> torch.Tensor: """ + Forward pass of the ABMIL model. + Parameters ---------- - coordinates - features: torch.Tensor - (B, N_TILES, D+3) - mask: Optional[torch.BoolTensor] - (B, N_TILES, 1), True for values that were padded. + features : torch.Tensor + Input tensor of shape (batch_size, n_tiles, in_features) representing the features of tiles. + return_attention : bool, optional + If True, also returns the attention weights, by default False. Returns ------- - logits, attention_weights: Tuple[torch.Tensor, torch.Tensor] - (B, OUT_FEATURES), (B, N_TILES) - """ - if coordinates is None and self.positional_encoding: - raise ValueError(f"Coordinates of NoneType are not accepted if positional_encoding is used") - - if self.use_positional_encoding: - positional_encoding = self.interpolate_positional_encoding(coordinates) - features = features + positional_encoding + torch.Tensor + Logits representing the model's output. + torch.Tensor, optional + Attention weights, returned if return_attention is True. + """ tiles_emb = self.embed_mlp(features) # BxN_tilesxN_features --> BxN_tilesx128 - scaled_tiles_emb, attention_weights = self.attention_layer(tiles_emb, mask) # BxN_tilesx128 --> Bx128 - logits = self.mlp(scaled_tiles_emb) # Bx128 --> Bx1 + scaled_tiles_emb, attention_weights = self.attention_layer( + tiles_emb, return_attention_weights=True + ) # BxN_tilesx128 --> Bx128 + logits = self.classifier(scaled_tiles_emb) # Bx128 --> Bx1 - if return_attention: + if return_attention_weights: return logits, attention_weights - return logits \ No newline at end of file + return logits diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py index 716b3bc..3026cab 100644 --- a/ahcore/models/MIL/transmil.py +++ b/ahcore/models/MIL/transmil.py @@ -15,19 +15,21 @@ # helper functions + def exists(val): return val is not None -def moore_penrose_iter_pinv(x, iters = 6): + +def moore_penrose_iter_pinv(x, iters=6): device = x.device abs_x = torch.abs(x) - col = abs_x.sum(dim = -1) - row = abs_x.sum(dim = -2) - z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row)) + col = abs_x.sum(dim=-1) + row = abs_x.sum(dim=-2) + z = rearrange(x, "... i j -> ... 
j i") / (torch.max(col) * torch.max(row)) - I = torch.eye(x.shape[-1], device = device) - I = rearrange(I, 'i j -> () i j') + I = torch.eye(x.shape[-1], device=device) + I = rearrange(I, "i j -> () i j") for _ in range(iters): xz = x @ z @@ -35,20 +37,22 @@ def moore_penrose_iter_pinv(x, iters = 6): return z + # main attention class + class NystromAttention(nn.Module): def __init__( self, dim, - dim_head = 64, - heads = 8, - num_landmarks = 256, - pinv_iterations = 6, - residual = True, - residual_conv_kernel = 33, - eps = 1e-8, - dropout = 0. + dim_head=64, + heads=8, + num_landmarks=256, + pinv_iterations=6, + residual=True, + residual_conv_kernel=33, + eps=1e-8, + dropout=0.0, ): super().__init__() self.eps = eps @@ -58,21 +62,18 @@ def __init__( self.pinv_iterations = pinv_iterations self.heads = heads - self.scale = dim_head ** -0.5 - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.scale = dim_head**-0.5 + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), - nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) self.residual = residual if residual: kernel_size = residual_conv_kernel padding = residual_conv_kernel // 2 - self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False) + self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False) - def forward(self, x, mask = None, return_attn = False): + def forward(self, x, mask=None, return_attn=False): b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps # pad so that sequence can be evenly divided into m landmarks @@ -80,20 +81,20 @@ def forward(self, x, mask = None, return_attn = False): remainder = n % m if remainder > 0: padding = m - (n % m) - x = F.pad(x, (0, 0, padding, 0), value = 0) + x = F.pad(x, (0, 0, padding, 0), value=0) if exists(mask): - mask = F.pad(mask, (padding, 0), value = False) + mask = F.pad(mask, (padding, 0), value=False) # derive query, keys, values - q, k, v = self.to_qkv(x).chunk(3, dim = -1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) # set masked positions to 0 in queries, keys, values if exists(mask): - mask = rearrange(mask, 'b n -> b () n') + mask = rearrange(mask, "b n -> b () n") q, k, v = map(lambda t: t * mask[..., None], (q, k, v)) q = q * self.scale @@ -101,15 +102,15 @@ def forward(self, x, mask = None, return_attn = False): # generate landmarks by sum reduction, and then calculate mean using the mask l = ceil(n / m) - landmark_einops_eq = '... (n l) d -> ... n d' - q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l) - k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l) + landmark_einops_eq = "... (n l) d -> ... n d" + q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l) + k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l) # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean divisor = l if exists(mask): - mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l) + mask_landmarks_sum = reduce(mask, "... (n l) -> ... 
n", "sum", l=l) divisor = mask_landmarks_sum[..., None] + eps mask_landmarks = mask_landmarks_sum > 0 @@ -120,7 +121,7 @@ def forward(self, x, mask = None, return_attn = False): # similarities - einops_eq = '... i d, ... j d -> ... i j' + einops_eq = "... i d, ... j d -> ... i j" sim1 = einsum(einops_eq, q, k_landmarks) sim2 = einsum(einops_eq, q_landmarks, k_landmarks) sim3 = einsum(einops_eq, q_landmarks, k) @@ -135,7 +136,7 @@ def forward(self, x, mask = None, return_attn = False): # eq (15) in the paper and aggregate values - attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3)) + attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) attn2_inv = moore_penrose_iter_pinv(attn2, iters) out = (attn1 @ attn2_inv) @ (attn3 @ v) @@ -147,7 +148,7 @@ def forward(self, x, mask = None, return_attn = False): # merge and combine heads - out = rearrange(out, 'b h n d -> b n (h d)', h = h) + out = rearrange(out, "b h n d -> b n (h d)", h=h) out = self.to_out(out) out = out[:, -n:] @@ -172,7 +173,7 @@ def __init__(self, norm_layer=nn.LayerNorm, dim=512): # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper residual=True, # whether to do an extra residual with the value or not. supposedly faster convergence if turned on - dropout=0.1 + dropout=0.1, ) def forward(self, x): @@ -242,5 +243,5 @@ def forward(self, features, **kwargs): logits = self._fc2(h) # [B, n_classes] Y_hat = torch.argmax(logits, dim=1) Y_prob = F.softmax(logits, dim=1) - results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat} - return logits \ No newline at end of file + results_dict = {"logits": logits, "Y_prob": Y_prob, "Y_hat": Y_hat} + return logits diff --git a/ahcore/models/base_jit_model.py b/ahcore/models/base_jit_model.py index 69fe1f1..67071d2 100644 --- a/ahcore/models/base_jit_model.py +++ b/ahcore/models/base_jit_model.py @@ -6,6 +6,7 @@ from transformers.modeling_utils import PreTrainedModel + class BaseHuggingfaceModel(nn.Module): def __init__(self, model: PreTrainedModel, pretrained_model_name_or_path: str, **kwargs) -> None: @@ -14,7 +15,9 @@ def __init__(self, model: PreTrainedModel, pretrained_model_name_or_path: str, * self.model: model = model.from_pretrained(pretrained_model_name_or_path, **kwargs) def forward(self, x): - model_input = x if type(x) is dict else {"pixel_values": x} # todo check if huggingface models sometimes other things??? + model_input = ( + x if type(x) is dict else {"pixel_values": x} + ) # todo check if huggingface models sometimes other things??? 
model_output = self.model(**model_input) return model_output.last_hidden_states @@ -34,7 +37,7 @@ def get_output_at_keys(self, x, keys): model_input = {"pixel_values": x} model_output = self.model(**model_input) - return {model_output[key] for key in keys} if len(keys)>1 else model_output[keys[0]] + return {model_output[key] for key in keys} if len(keys) > 1 else model_output[keys[0]] class BaseAhcoreJitModel(ScriptModule): diff --git a/ahcore/models/layers/MLP.py b/ahcore/models/layers/MLP.py index dfa310d..7a75211 100644 --- a/ahcore/models/layers/MLP.py +++ b/ahcore/models/layers/MLP.py @@ -36,18 +36,14 @@ def __init__( out_features: int, hidden: Optional[List[int]] = None, dropout: Optional[List[float]] = None, - activation: Optional[nn.Module] = nn.Sigmoid(), + activation: Optional[nn.Module] = nn.ReLU(), bias: bool = True, ): if dropout is not None: if hidden is not None: - assert len(hidden) == len( - dropout - ), "hidden and dropout must have the same length" + assert len(hidden) == len(dropout), "hidden and dropout must have the same length" else: - raise ValueError( - "hidden must have a value and have the same length as dropout if dropout is given." - ) + raise ValueError("hidden must have a value and have the same length as dropout if dropout is given.") d_model = in_features layers = [] @@ -68,309 +64,3 @@ def __init__( layers.append(nn.Linear(d_model, out_features)) super(MLP, self).__init__(*layers) - -class MaskedLinear(nn.Linear): - """ - Linear layer to be applied tile wise. - This layer can be used in combination with a mask - to prevent padding tiles from influencing the values of a subsequent - activation. - Example: - >>> module = Linear(in_features=128, out_features=1) # With Linear - >>> out = module(slide) - >>> wrong_value = torch.sigmoid(out) # Value is influenced by padding - >>> module = MaskedLinear(in_features=128, out_features=1, mask_value='-inf') # With MaskedLinear - >>> out = module(slide, mask) # Padding now has the '-inf' value - >>> correct_value = torch.sigmoid(out) # Value is not influenced by padding as sigmoid('-inf') = 0 - Parameters - ---------- - in_features: int - size of each input sample - out_features: int - size of each output sample - mask_value: Union[str, int] - value to give to the mask - bias: bool = True - If set to ``False``, the layer will not learn an additive bias. - """ - - def __init__( - self, - in_features: int, - out_features: int, - mask_value: Union[str, float], - bias: bool = True, - ): - super(MaskedLinear, self).__init__( - in_features=in_features, out_features=out_features, bias=bias - ) - self.mask_value = mask_value - - def forward( - self, x: torch.Tensor, mask: Optional[torch.BoolTensor] = None - ): # pylint: disable=arguments-renamed - """Forward pass. - - Parameters - ---------- - x: torch.Tensor - Input tensor, shape (B, SEQ_LEN, IN_FEATURES). - mask: Optional[torch.BoolTensor] = None - True for values that were padded, shape (B, SEQ_LEN, 1), - - Returns - ------- - x: torch.Tensor - (B, SEQ_LEN, OUT_FEATURES) - """ - x = super(MaskedLinear, self).forward(x) - if mask is not None: - x = x.masked_fill(mask, float(self.mask_value)) - return x - - def extra_repr(self): - return ( - f"in_features={self.in_features}, out_features={self.out_features}, " - f"mask_value={self.mask_value}, bias={self.bias is not None}" - ) - - -class MaskedMLP(nn.Module): - """MLP to be applied to tiles to compute scores. - This module can be used in combination of a mask - to prevent padding from influencing the scores values. 
- Parameters - ---------- - in_features: int - size of each input sample - out_features: int - size of each output sample - hidden: Optional[List[int]] = None - Number of hidden layers and their respective number of features. - bias: bool = True - If set to ``False``, the layer will not learn an additive bias. - activation: torch.nn.Module = torch.nn.Sigmoid() - MLP activation function - dropout: Optional[torch.nn.Module] = None - Optional dropout module. Will be interlaced with the linear layers. - """ - - def __init__( - self, - in_features: int, - out_features: int = 1, - hidden: Optional[List[int]] = None, - bias: bool = True, - activation: nn.Module = nn.Sigmoid(), - dropout: Optional[nn.Module] = None, - ): - super(MaskedMLP, self).__init__() - - if dropout is not None: - assert len(dropout) == len(hidden), "Length of dropout is not correct" - - self.hidden_layers = nn.ModuleList() - if hidden is not None: - for i, h in enumerate(hidden): - self.hidden_layers.append( - MaskedLinear(in_features, h, bias=bias, mask_value="-inf") - ) - self.hidden_layers.append(activation) - if dropout: - self.hidden_layers.append(nn.Dropout(dropout[i])) - in_features = h - - self.hidden_layers.append( - nn.Linear(in_features, out_features, bias=bias) - ) - - def forward( - self, x: torch.Tensor, mask: Optional[torch.BoolTensor] = None - ): - """Forward pass. - - Parameters - ---------- - x: torch.Tensor - (B, N_TILES, IN_FEATURES) - mask: Optional[torch.BoolTensor] = None - (B, N_TILES), True for values that were padded. - - Returns - ------- - x: torch.Tensor - (B, N_TILES, OUT_FEATURES) - """ - for layer in self.hidden_layers: - if isinstance(layer, MaskedLinear): - x = layer(x, mask) - else: - x_before = x.clone().detach() - x = layer(x) - - if torch.any(x.masked_fill(mask, 0).isnan()): - raise RuntimeError(f"Found NaN values in x outside the mask") - - return x - -class SelfAttention(nn.Module): - """Multi-Head Self-Attention. - - Implementation adapted from https://github.com/rwightman/pytorch-image-models. - - Parameters - ---------- - in_features : int - Number of input features. - - num_heads : int = 8 - Number of attention heads. Should be an integer greater or equal to 1. - - qkv_bias : bool = False - Whether to add a bias to the linear projection for query, key and value. - - attn_dropout : float = 0.0 - Dropout rate (applied before the multiplication with the values). - - proj_dropout : float = 0.0 - Dropout rate (applied after the multiplication with the values). - """ - - def __init__( - self, - in_features: int, - num_heads: int = 8, - qkv_bias: bool = False, - attn_dropout: float = 0.0, - proj_dropout: float = 0.0, - ): - super().__init__() - self.in_features = in_features - self.num_heads = num_heads - self.qkv_bias = qkv_bias - self.attn_dropout = attn_dropout - self.proj_dropout = proj_dropout - - self.__build() - - def __build(self): - """Build the `SelfAttention` module.""" - head_dim = self.in_features // self.num_heads - self.scale = head_dim**-0.5 - self.qkv = nn.Linear( - self.in_features, self.in_features * 3, bias=self.qkv_bias - ) - self.attn_drop = nn.Dropout(self.attn_dropout) - self.proj = nn.Linear(self.in_features, self.in_features) - self.proj_drop = nn.Dropout(self.proj_dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Parameters - ---------- - x : torch.Tensor - Input tensor, shape (B, seq_len, in_features). - - Returns - ------- - out : torch.Tensor - Output tensor, shape (B, seq_len, in_features). 
- """ - B, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B, N, 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) - q, k, v = qkv[0], qkv[1], qkv[2] - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class GatedAttention(nn.Module): - """Gated Attention, as defined in https://arxiv.org/abs/1802.04712. - Permutation invariant Layer on dim 1. - Parameters - ---------- - d_model: int = 128 - temperature: float = 1.0 - Attention Softmax temperature - """ - - def __init__( - self, - d_model: int = 128, - temperature: float = 1.0, - ): - super(GatedAttention, self).__init__() - - self.V = nn.Linear(d_model, d_model) - self.U = nn.Linear(d_model, d_model) - self.w = MaskedLinear(d_model, 1, "-inf") - - self.temperature = temperature - - def attention( - self, - features: torch.Tensor, - mask: Optional[torch.BoolTensor] = None, - ) -> torch.Tensor: - """Gets attention logits. - Parameters - ---------- - v: torch.Tensor - (B, SEQ_LEN, IN_FEATURES) - mask: Optional[torch.BoolTensor] = None - (B, SEQ_LEN, 1), True for values that were padded. - Returns - ------- - attention_logits: torch.Tensor - (B, N_TILES, 1) - """ - h_v = torch.tanh(self.U(features)) - - u_v = torch.sigmoid(self.V(features)) - - attention_logits = self.w(h_v * u_v, mask=mask) / self.temperature - - attention_weights = torch.softmax(attention_logits, 1) - - return attention_weights - - def forward( - self, features: torch.Tensor, mask: Optional[torch.BoolTensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward pass. - Parameters - ---------- - v: torch.Tensor - (B, SEQ_LEN, IN_FEATURES) - mask: Optional[torch.BoolTensor] = None - (B, SEQ_LEN, 1), True for values that were padded. - Returns - ------- - scaled_attention, attention_weights: Tuple[torch.Tensor, torch.Tensor] - (B, IN_FEATURES), (B, N_TILES, 1) - """ - h_v = torch.tanh(self.U(features)) - - u_v = torch.sigmoid(self.V(features)) - - attention_logits = self.w(h_v * u_v, mask=mask) / self.temperature - - attention_weights = torch.softmax(attention_logits, 1) - # if not torch.any(attention_weights[mask]==0.0): - # raise RuntimeError(f"Masked indices got non-zero weight") - - # features = features.masked_fill(mask, float(0.0)) - scaled_attention = torch.matmul(attention_weights.transpose(1, 2), features) - - return scaled_attention.squeeze(1), attention_weights \ No newline at end of file diff --git a/ahcore/models/layers/attention.py b/ahcore/models/layers/attention.py index ec7d4aa..8f7c3ea 100644 --- a/ahcore/models/layers/attention.py +++ b/ahcore/models/layers/attention.py @@ -3,10 +3,10 @@ import torch from torch import nn -from ahcore.models.layers.MLP import MaskedLinear """Most of this stuff is adapted from utils from https://github.com/owkin/HistoSSLscaling/tree/main""" + class SelfAttention(nn.Module): """Multi-Head Self-Attention. 
@@ -51,9 +51,7 @@ def __build(self): """Build the `SelfAttention` module.""" head_dim = self.in_features // self.num_heads self.scale = head_dim**-0.5 - self.qkv = nn.Linear( - self.in_features, self.in_features * 3, bias=self.qkv_bias - ) + self.qkv = nn.Linear(self.in_features, self.in_features * 3, bias=self.qkv_bias) self.attn_drop = nn.Dropout(self.attn_dropout) self.proj = nn.Linear(self.in_features, self.in_features) self.proj_drop = nn.Dropout(self.proj_dropout) @@ -72,11 +70,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Output tensor, shape (B, seq_len, in_features). """ B, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B, N, 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale @@ -101,54 +95,29 @@ class GatedAttention(nn.Module): def __init__( self, - d_model: int = 128, + dim: int = 128, temperature: float = 1.0, ): super(GatedAttention, self).__init__() - self.V = nn.Linear(d_model, d_model) - self.U = nn.Linear(d_model, d_model) - self.w = MaskedLinear(d_model, 1, "-inf") + self.V = nn.Linear(dim, dim) + self.U = nn.Linear(dim, dim) + self.w = nn.Linear(dim, 1) self.temperature = temperature - def attention( + def forward( self, features: torch.Tensor, - mask: Optional[torch.BoolTensor] = None, - ) -> torch.Tensor: - """Gets attention logits. - Parameters - ---------- - v: torch.Tensor - (B, SEQ_LEN, IN_FEATURES) - mask: Optional[torch.BoolTensor] = None - (B, SEQ_LEN, 1), True for values that were padded. - Returns - ------- - attention_logits: torch.Tensor - (B, N_TILES, 1) - """ - h_v = torch.tanh(self.U(features)) - - u_v = torch.sigmoid(self.V(features)) - - attention_logits = self.w(h_v * u_v, mask=mask) / self.temperature - - attention_weights = torch.softmax(attention_logits, 1) - - return attention_weights - - def forward( - self, features: torch.Tensor, mask: Optional[torch.BoolTensor] = None + return_attention_weights=False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass. Parameters ---------- - v: torch.Tensor + features: torch.Tensor (B, SEQ_LEN, IN_FEATURES) - mask: Optional[torch.BoolTensor] = None - (B, SEQ_LEN, 1), True for values that were padded. 
+ return_attention_weights: bool = False + Returns ------- scaled_attention, attention_weights: Tuple[torch.Tensor, torch.Tensor] @@ -158,13 +127,13 @@ def forward( u_v = torch.sigmoid(self.V(features)) - attention_logits = self.w(h_v * u_v, mask=mask) / self.temperature + attention_logits = self.w(h_v * u_v) / self.temperature attention_weights = torch.softmax(attention_logits, 1) - # if not torch.any(attention_weights[mask]==0.0): - # raise RuntimeError(f"Masked indices got non-zero weight") - # features = features.masked_fill(mask, float(0.0)) scaled_attention = torch.matmul(attention_weights.transpose(1, 2), features) - return scaled_attention.squeeze(1), attention_weights \ No newline at end of file + if return_attention_weights: + return scaled_attention.squeeze(1), attention_weights + + return scaled_attention.squeeze(1) diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index a14e9c7..21a2577 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -130,6 +130,7 @@ def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: return sample + class LabelToClassIndex: """ Maps label values to class indices according to the index_map specified in the data description. From 22b2ddb60add18e851abd6a8741865f7c93b98e8 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 21 Aug 2024 17:43:36 +0200 Subject: [PATCH 06/14] fixes some mypy stuff --- ahcore/data/dataset.py | 6 ++---- ahcore/lit_module.py | 2 +- ahcore/models/MIL/ABmil.py | 2 +- ahcore/models/MIL/transmil.py | 6 +++--- ahcore/models/base_jit_model.py | 13 +++++-------- ahcore/models/layers/MLP.py | 4 ++-- ahcore/models/layers/attention.py | 2 +- ahcore/utils/io.py | 2 +- setup.py | 1 + 9 files changed, 17 insertions(+), 21 deletions(-) diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index f4c7d13..ea2d7a9 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -88,12 +88,10 @@ def __len__(self) -> int: return self.cumulative_sizes[-1] @overload - def __getitem__(self, index: int) -> DlupDatasetSample: - ... + def __getitem__(self, index: int) -> DlupDatasetSample: ... @overload - def __getitem__(self, index: slice) -> list[DlupDatasetSample]: - ... + def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ... def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]: """Returns the sample at the given index.""" diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index a84c43a..fdb3291 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -59,7 +59,7 @@ def __init__( "loss", ], ) # TODO: we should send the hyperparams to the logger elsewhere - if isinstance(model, BaseAhcoreJitModel) or isinstance(model, transformers.modeling_utils.PretrainedModel): + if isinstance(model, BaseAhcoreJitModel) or isinstance(model, transformers.modeling_utils.PreTrainedModel): self._model = model elif isinstance(model, functools.partial): try: diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index a4900e7..53533c1 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -141,7 +141,7 @@ def forward( self, features: torch.Tensor, return_attention_weights: bool = False, - ) -> torch.Tensor: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Forward pass of the ABMIL model. 
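Note for reviewers: a minimal usage sketch of the ABMIL head as it stands after this hunk. It relies only on the forward signature shown above and on the shapes exercised by the tests added later in this series, and it assumes the ahcore package from this branch is importable; batch size, tile count and feature dimension are illustrative.

    import torch
    from ahcore.models.MIL.ABmil import ABMIL

    # A bag of 500 tiles with 768-dimensional features each (sizes are illustrative).
    features = torch.randn(4, 500, 768)

    model = ABMIL(in_features=768)  # defaults to a single output logit per slide
    logits = model(features)        # shape (4, 1)

    # With the flag enabled, the forward pass also returns per-tile attention weights.
    logits, attention_weights = model(features, return_attention_weights=True)
    # attention_weights has shape (4, 500, 1)
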
diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py index 3026cab..6a7150b 100644 --- a/ahcore/models/MIL/transmil.py +++ b/ahcore/models/MIL/transmil.py @@ -73,9 +73,9 @@ def __init__( padding = residual_conv_kernel // 2 self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False) - def forward(self, x, mask=None, return_attn=False): - b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps - + def forward(self, x: torch.Tensor, mask=None, return_attn=False): + b, n, _ = x.shape + h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps # pad so that sequence can be evenly divided into m landmarks remainder = n % m diff --git a/ahcore/models/base_jit_model.py b/ahcore/models/base_jit_model.py index 67071d2..29e9f0c 100644 --- a/ahcore/models/base_jit_model.py +++ b/ahcore/models/base_jit_model.py @@ -1,9 +1,11 @@ from pathlib import Path from typing import Any +import torch from torch.jit import ScriptModule, load from torch import nn + from transformers.modeling_utils import PreTrainedModel @@ -14,24 +16,19 @@ def __init__(self, model: PreTrainedModel, pretrained_model_name_or_path: str, * self.model: model = model.from_pretrained(pretrained_model_name_or_path, **kwargs) - def forward(self, x): + def forward(self, x: dict | torch.Tensor) -> torch.Tensor: model_input = ( x if type(x) is dict else {"pixel_values": x} ) # todo check if huggingface models sometimes other things??? model_output = self.model(**model_input) return model_output.last_hidden_states - def get_attentions(self, x): - model_input = {"pixel_values": x} - model_output = self.model(**model_input) - return model_output.attentions - - def get_raw_output(self, x): + def get_raw_output(self, x: torch.Tensor) -> dict: model_input = {"pixel_values": x} model_output = self.model(**model_input) return model_output - def get_output_at_keys(self, x, keys): + def get_output_at_keys(self, x: torch.Tensor, keys: str | list[str]) -> dict[str, torch.Tensor]: if isinstance(keys, str): keys = [keys] diff --git a/ahcore/models/layers/MLP.py b/ahcore/models/layers/MLP.py index 7a75211..f255602 100644 --- a/ahcore/models/layers/MLP.py +++ b/ahcore/models/layers/MLP.py @@ -46,11 +46,11 @@ def __init__( raise ValueError("hidden must have a value and have the same length as dropout if dropout is given.") d_model = in_features - layers = [] + layers: list[nn.Module] = [] if hidden is not None: for i, h in enumerate(hidden): - seq = [nn.Linear(d_model, h, bias=bias)] + seq: list[nn.Module] = [nn.Linear(d_model, h, bias=bias)] d_model = h if activation is not None: diff --git a/ahcore/models/layers/attention.py b/ahcore/models/layers/attention.py index 8f7c3ea..9bf2f30 100644 --- a/ahcore/models/layers/attention.py +++ b/ahcore/models/layers/attention.py @@ -110,7 +110,7 @@ def forward( self, features: torch.Tensor, return_attention_weights=False, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: """Forward pass. Parameters ---------- diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index ae65402..03e7881 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -239,7 +239,7 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule: The model loaded from the checkpoint file. 
""" _model = getattr(model, "_model") - if isinstance(_model, BaseAhcoreJitModel) or isinstance(_model, transformers.modeling_utils.PretrainedModel): + if isinstance(_model, BaseAhcoreJitModel) or isinstance(_model, transformers.modeling_utils.PreTrainedModel): return model else: # Load checkpoint weights diff --git a/setup.py b/setup.py index bc097fd..32daad0 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ "zarr==2.17.2", "sqlalchemy>=2.0.21", "imageio>=2.34.0", + "transformers>=4.44.1", ] From 93c9aea68ca7292c14de47af5389fc0f5924b8fd Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 22 Aug 2024 16:40:48 +0200 Subject: [PATCH 07/14] minor bugfixes on input/output tests of the huggingface stuff --- ahcore/models/base_jit_model.py | 2 +- ahcore/transforms/pre_transforms.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ahcore/models/base_jit_model.py b/ahcore/models/base_jit_model.py index 29e9f0c..b10d745 100644 --- a/ahcore/models/base_jit_model.py +++ b/ahcore/models/base_jit_model.py @@ -21,7 +21,7 @@ def forward(self, x: dict | torch.Tensor) -> torch.Tensor: x if type(x) is dict else {"pixel_values": x} ) # todo check if huggingface models sometimes other things??? model_output = self.model(**model_input) - return model_output.last_hidden_states + return model_output.last_hidden_state def get_raw_output(self, x: torch.Tensor) -> dict: model_input = {"pixel_values": x} diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 21a2577..3b9de88 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -126,7 +126,14 @@ def __init__(self, pretrained_model_name_or_path: str, **kwargs): def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: # Apply the huggingface transforms here - sample["image"]: np.ndarray = self._processor(sample["image"])["pixel_values"] + if isinstance(sample["image"], pyvips.Image): + img = np.array(sample["image"]).numpy() + if isinstance(sample["image"], np.ndarray): + img = sample["image"] + else: + raise ValueError(f"The image must be a pyvips.Image or a numpy array, got {type(sample['image'])}") + + sample["image"]: np.ndarray = self._processor(img)["pixel_values"][0] return sample From 497b0be65df4b057eddca42e1a73c84e7735be0a Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Mon, 2 Sep 2024 12:49:07 +0200 Subject: [PATCH 08/14] remove huggingface transformers (to be added in different pr) and clean up docs --- ahcore/lit_module.py | 3 +- ahcore/models/MIL/ABmil.py | 32 +++--------- ahcore/models/base_jit_model.py | 31 ----------- ahcore/models/layers/attention.py | 80 ----------------------------- ahcore/transforms/pre_transforms.py | 18 ------- ahcore/utils/io.py | 3 +- setup.py | 1 - 7 files changed, 8 insertions(+), 160 deletions(-) diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py index fdb3291..3001501 100644 --- a/ahcore/lit_module.py +++ b/ahcore/lit_module.py @@ -13,7 +13,6 @@ import torch.optim.optimizer from pytorch_lightning.trainer.states import TrainerFn from torch import nn -import transformers from ahcore.exceptions import ConfigurationError from ahcore.metrics import MetricFactory, WSIMetricFactory @@ -59,7 +58,7 @@ def __init__( "loss", ], ) # TODO: we should send the hyperparams to the logger elsewhere - if isinstance(model, BaseAhcoreJitModel) or isinstance(model, transformers.modeling_utils.PreTrainedModel): + if isinstance(model, BaseAhcoreJitModel): self._model = model elif isinstance(model, 
functools.partial): try: diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index 53533c1..7976b43 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -13,32 +13,12 @@ class ABMIL(nn.Module): This model is adapted from https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py. It uses an attention mechanism to aggregate features from multiple instances (tiles) into a single prediction. - Parameters - ---------- - in_features : int - Number of input features for each tile. - out_features : int, optional - Number of output features (typically 1 for binary classification), by default 1. - attention_dimension : int, optional - Dimensionality of the attention mechanism, by default 128. - temperature : float, optional - Temperature parameter for scaling the attention scores, by default 1.0. - embed_mlp_hidden : Optional[List[int]], optional - List of hidden layer sizes for the embedding MLP, by default None. - embed_mlp_dropout : Optional[List[float]], optional - List of dropout rates for the embedding MLP, by default None. - embed_mlp_activation : Optional[torch.nn.Module], optional - Activation function for the embedding MLP, by default nn.ReLU(). - embed_mlp_bias : bool, optional - Whether to include bias in the embedding MLP layers, by default True. - classifier_hidden : Optional[List[int]], optional - List of hidden layer sizes for the classifier MLP, by default [128, 64]. - classifier_dropout : Optional[List[float]], optional - List of dropout rates for the classifier MLP, by default None. - classifier_activation : Optional[torch.nn.Module], optional - Activation function for the classifier MLP, by default nn.ReLU(). - classifier_bias : bool, optional - Whether to include bias in the classifier MLP layers, by default False. + Methods + ------- + get_attention(x: torch.Tensor) -> torch.Tensor + Computes the attention weights for the input features. + forward(features: torch.Tensor, return_attention_weights: bool = False) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] + Forward pass of the ABMIL model. References ---------- diff --git a/ahcore/models/base_jit_model.py b/ahcore/models/base_jit_model.py index b10d745..94c0f29 100644 --- a/ahcore/models/base_jit_model.py +++ b/ahcore/models/base_jit_model.py @@ -6,37 +6,6 @@ from torch import nn -from transformers.modeling_utils import PreTrainedModel - - -class BaseHuggingfaceModel(nn.Module): - - def __init__(self, model: PreTrainedModel, pretrained_model_name_or_path: str, **kwargs) -> None: - super().__init__() - - self.model: model = model.from_pretrained(pretrained_model_name_or_path, **kwargs) - - def forward(self, x: dict | torch.Tensor) -> torch.Tensor: - model_input = ( - x if type(x) is dict else {"pixel_values": x} - ) # todo check if huggingface models sometimes other things??? - model_output = self.model(**model_input) - return model_output.last_hidden_state - - def get_raw_output(self, x: torch.Tensor) -> dict: - model_input = {"pixel_values": x} - model_output = self.model(**model_input) - return model_output - - def get_output_at_keys(self, x: torch.Tensor, keys: str | list[str]) -> dict[str, torch.Tensor]: - if isinstance(keys, str): - keys = [keys] - - model_input = {"pixel_values": x} - model_output = self.model(**model_input) - return {model_output[key] for key in keys} if len(keys) > 1 else model_output[keys[0]] - - class BaseAhcoreJitModel(ScriptModule): """ Base class for the jit compiled models in Ahcore. 
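Since the ABmil.py hunk in this patch drops the detailed parameter documentation, a short configuration sketch may be useful. It only uses constructor arguments that appear elsewhere in this series (the removed docstring and the later rename patch); the hidden sizes, dropout rates and activations are illustrative choices, and the output dimension is left at its default of one logit.

    import torch
    from torch import nn
    from ahcore.models.MIL.ABmil import ABMIL

    # Illustrative configuration of the embedding MLP, gated attention and classifier MLP.
    model = ABMIL(
        in_features=768,
        attention_dimension=128,
        temperature=1.0,
        embed_mlp_hidden=[256],
        embed_mlp_dropout=[0.1],
        embed_mlp_activation=nn.ReLU(),
        classifier_hidden=[128, 64],
        classifier_dropout=[0.2, 0.2],
        classifier_activation=nn.ReLU(),
    )
    logits = model(torch.randn(2, 300, 768))  # shape (2, 1)
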
diff --git a/ahcore/models/layers/attention.py b/ahcore/models/layers/attention.py index 9bf2f30..0a09fe1 100644 --- a/ahcore/models/layers/attention.py +++ b/ahcore/models/layers/attention.py @@ -3,86 +3,6 @@ import torch from torch import nn - -"""Most of this stuff is adapted from utils from https://github.com/owkin/HistoSSLscaling/tree/main""" - - -class SelfAttention(nn.Module): - """Multi-Head Self-Attention. - - Implementation adapted from https://github.com/rwightman/pytorch-image-models. - - Parameters - ---------- - in_features : int - Number of input features. - - num_heads : int = 8 - Number of attention heads. Should be an integer greater or equal to 1. - - qkv_bias : bool = False - Whether to add a bias to the linear projection for query, key and value. - - attn_dropout : float = 0.0 - Dropout rate (applied before the multiplication with the values). - - proj_dropout : float = 0.0 - Dropout rate (applied after the multiplication with the values). - """ - - def __init__( - self, - in_features: int, - num_heads: int = 8, - qkv_bias: bool = False, - attn_dropout: float = 0.0, - proj_dropout: float = 0.0, - ): - super().__init__() - self.in_features = in_features - self.num_heads = num_heads - self.qkv_bias = qkv_bias - self.attn_dropout = attn_dropout - self.proj_dropout = proj_dropout - - self.__build() - - def __build(self): - """Build the `SelfAttention` module.""" - head_dim = self.in_features // self.num_heads - self.scale = head_dim**-0.5 - self.qkv = nn.Linear(self.in_features, self.in_features * 3, bias=self.qkv_bias) - self.attn_drop = nn.Dropout(self.attn_dropout) - self.proj = nn.Linear(self.in_features, self.in_features) - self.proj_drop = nn.Dropout(self.proj_dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Parameters - ---------- - x : torch.Tensor - Input tensor, shape (B, seq_len, in_features). - - Returns - ------- - out : torch.Tensor - Output tensor, shape (B, seq_len, in_features). - """ - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - class GatedAttention(nn.Module): """Gated Attention, as defined in https://arxiv.org/abs/1802.04712. Permutation invariant Layer on dim 1. 
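With SelfAttention removed, GatedAttention is the only layer left in attention.py at this point. A minimal pooling sketch, assuming the `dim` / `return_attention_weights` interface introduced earlier in this series; tensor sizes are illustrative.

    import torch
    from ahcore.models.layers.attention import GatedAttention

    features = torch.randn(2, 100, 128)  # (batch, n_tiles, dim), sizes are illustrative
    attention = GatedAttention(dim=128, temperature=1.0)

    pooled = attention(features)  # (2, 128): attention-weighted average over the tile axis
    pooled, weights = attention(features, return_attention_weights=True)
    # weights has shape (2, 100, 1)
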
diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 3b9de88..1ca034d 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -119,24 +119,6 @@ def __repr__(self) -> str: return f"PreTransformTaskFactory(transforms={self._transforms})" -class ApplyHuggingfaceTransforms: - - def __init__(self, pretrained_model_name_or_path: str, **kwargs): - self._processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs) - - def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: - # Apply the huggingface transforms here - if isinstance(sample["image"], pyvips.Image): - img = np.array(sample["image"]).numpy() - if isinstance(sample["image"], np.ndarray): - img = sample["image"] - else: - raise ValueError(f"The image must be a pyvips.Image or a numpy array, got {type(sample['image'])}") - - sample["image"]: np.ndarray = self._processor(img)["pixel_values"][0] - - return sample - class LabelToClassIndex: """ diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py index 03e7881..bc7b2dd 100644 --- a/ahcore/utils/io.py +++ b/ahcore/utils/io.py @@ -28,7 +28,6 @@ from omegaconf.errors import InterpolationKeyError from pytorch_lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only -import transformers from ahcore.models.base_jit_model import BaseAhcoreJitModel @@ -239,7 +238,7 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule: The model loaded from the checkpoint file. """ _model = getattr(model, "_model") - if isinstance(_model, BaseAhcoreJitModel) or isinstance(_model, transformers.modeling_utils.PreTrainedModel): + if isinstance(_model, BaseAhcoreJitModel): return model else: # Load checkpoint weights diff --git a/setup.py b/setup.py index 32daad0..bc097fd 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ "zarr==2.17.2", "sqlalchemy>=2.0.21", "imageio>=2.34.0", - "transformers>=4.44.1", ] From 84c645ac264301d58df0da5b30aa7a041bdbc00e Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 4 Sep 2024 13:05:10 +0200 Subject: [PATCH 09/14] minor changes --- ahcore/models/MIL/ABmil.py | 2 +- ahcore/transforms/pre_transforms.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index 7976b43..8dd241e 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -114,7 +114,7 @@ def get_attention(self, x: torch.Tensor) -> torch.Tensor: """ tiles_emb = self.embed_mlp(x) - attention_weights = self.attention_layer.attention(tiles_emb) + _, attention_weights = self.attention_layer(tiles_emb, return_attention_weights=True) return attention_weights def forward( diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 1ca034d..ae1bdc7 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -19,8 +19,6 @@ from ahcore.utils.io import get_logger from ahcore.utils.types import DlupDatasetSample -from transformers import AutoImageProcessor - PreTransformCallable = Callable[[Any], Any] logger = get_logger(__name__) From 931dc7bfd762b37ed68cde505cc7197692adf1e7 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 4 Sep 2024 14:28:19 +0200 Subject: [PATCH 10/14] fix pre-commit --- ahcore/models/MIL/ABmil.py | 35 +++------- ahcore/models/MIL/transmil.py | 101 +++++++++++++--------------- ahcore/models/layers/attention.py | 7 +- ahcore/transforms/pre_transforms.py | 7 +- 4 
files changed, 59 insertions(+), 91 deletions(-) diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index 8dd241e..d1523d2 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -1,23 +1,23 @@ -from ahcore.models.layers.MLP import MLP -from ahcore.models.layers.attention import GatedAttention - from typing import List, Optional import torch from torch import nn +from ahcore.models.layers.attention import GatedAttention +from ahcore.models.layers.MLP import MLP + class ABMIL(nn.Module): """ Attention-based MIL (Multiple Instance Learning) classification model (See [1]_). - This model is adapted from https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py. + This model is adapted from + https://github.com/owkin/HistoSSLscaling/blob/main/rl_benchmarks/models/slide_models/abmil.py. It uses an attention mechanism to aggregate features from multiple instances (tiles) into a single prediction. Methods ------- - get_attention(x: torch.Tensor) -> torch.Tensor - Computes the attention weights for the input features. - forward(features: torch.Tensor, return_attention_weights: bool = False) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] + forward(features: torch.Tensor, return_attention_weights: bool = False) + -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] Forward pass of the ABMIL model. References @@ -98,25 +98,6 @@ def __init__( activation=classifier_activation, ) - def get_attention(self, x: torch.Tensor) -> torch.Tensor: - """ - Computes the attention weights for the input features. - - Parameters - ---------- - x : torch.Tensor - Input tensor of shape (batch_size, n_tiles, in_features) representing the features of tiles. - - Returns - ------- - torch.Tensor - Attention weights for each tile. - - """ - tiles_emb = self.embed_mlp(x) - _, attention_weights = self.attention_layer(tiles_emb, return_attention_weights=True) - return attention_weights - def forward( self, features: torch.Tensor, @@ -144,7 +125,7 @@ def forward( scaled_tiles_emb, attention_weights = self.attention_layer( tiles_emb, return_attention_weights=True ) # BxN_tilesx128 --> Bx128 - logits = self.classifier(scaled_tiles_emb) # Bx128 --> Bx1 + logits: torch.Tensor = self.classifier(scaled_tiles_emb) # Bx128 --> Bx1 if return_attention_weights: return logits, attention_weights diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py index 6a7150b..b481c1f 100644 --- a/ahcore/models/MIL/transmil.py +++ b/ahcore/models/MIL/transmil.py @@ -1,26 +1,17 @@ -# this file includes the original nystrom attention and transmil model from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively. - +# this file includes the original nystrom attention and transmil model +# from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py +# and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively. 
+from math import ceil +from typing import Any, Optional -import torch -import torch.nn as nn -import torch.nn.functional as F import numpy as np - -from math import ceil import torch -from torch import nn, einsum import torch.nn.functional as F - from einops import rearrange, reduce +from torch import nn as nn -# helper functions - - -def exists(val): - return val is not None - -def moore_penrose_iter_pinv(x, iters=6): +def moore_penrose_iter_pinv(x: torch.Tensor, iters: int = 6) -> torch.Tensor: device = x.device abs_x = torch.abs(x) @@ -28,12 +19,12 @@ def moore_penrose_iter_pinv(x, iters=6): row = abs_x.sum(dim=-2) z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row)) - I = torch.eye(x.shape[-1], device=device) - I = rearrange(I, "i j -> () i j") + eye = torch.eye(x.shape[-1], device=device) + eye = rearrange(eye, "i j -> () i j") for _ in range(iters): xz = x @ z - z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz))))) + z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz))))) return z @@ -44,16 +35,16 @@ def moore_penrose_iter_pinv(x, iters=6): class NystromAttention(nn.Module): def __init__( self, - dim, - dim_head=64, - heads=8, - num_landmarks=256, - pinv_iterations=6, - residual=True, - residual_conv_kernel=33, - eps=1e-8, - dropout=0.0, - ): + dim: int, + dim_head: int = 64, + heads: int = 8, + num_landmarks: int = 256, + pinv_iterations: int = 6, + residual: bool = True, + residual_conv_kernel: int = 33, + eps: float = 1e-8, + dropout: float = 0.0, + ) -> None: super().__init__() self.eps = eps inner_dim = heads * dim_head @@ -73,8 +64,10 @@ def __init__( padding = residual_conv_kernel // 2 self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False) - def forward(self, x: torch.Tensor, mask=None, return_attn=False): - b, n, _ = x.shape + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + b, n, _ = x.shape h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps # pad so that sequence can be evenly divided into m landmarks @@ -83,7 +76,7 @@ def forward(self, x: torch.Tensor, mask=None, return_attn=False): padding = m - (n % m) x = F.pad(x, (0, 0, padding, 0), value=0) - if exists(mask): + if mask is not None: mask = F.pad(mask, (padding, 0), value=False) # derive query, keys, values @@ -93,7 +86,7 @@ def forward(self, x: torch.Tensor, mask=None, return_attn=False): # set masked positions to 0 in queries, keys, values - if exists(mask): + if mask is not None: mask = rearrange(mask, "b n -> b () n") q, k, v = map(lambda t: t * mask[..., None], (q, k, v)) @@ -101,18 +94,19 @@ def forward(self, x: torch.Tensor, mask=None, return_attn=False): # generate landmarks by sum reduction, and then calculate mean using the mask - l = ceil(n / m) + l_dim = ceil(n / m) landmark_einops_eq = "... (n l) d -> ... n d" - q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l) - k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l) + q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l_dim) + k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l_dim) # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean - divisor = l - if exists(mask): - mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=l) + if mask is not None: + mask_landmarks_sum = reduce(mask, "... (n l) -> ... 
n", "sum", l=l_dim) divisor = mask_landmarks_sum[..., None] + eps mask_landmarks = mask_landmarks_sum > 0 + else: + divisor = torch.Tensor([l_dim]).to(q_landmarks.device) # masked mean (if mask exists) @@ -122,13 +116,13 @@ def forward(self, x: torch.Tensor, mask=None, return_attn=False): # similarities einops_eq = "... i d, ... j d -> ... i j" - sim1 = einsum(einops_eq, q, k_landmarks) - sim2 = einsum(einops_eq, q_landmarks, k_landmarks) - sim3 = einsum(einops_eq, q_landmarks, k) + sim1 = torch.einsum(einops_eq, q, k_landmarks) + sim2 = torch.einsum(einops_eq, q_landmarks, k_landmarks) + sim3 = torch.einsum(einops_eq, q_landmarks, k) # masking - if exists(mask): + if mask is not None: mask_value = -torch.finfo(q.dtype).max sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value) sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value) @@ -139,7 +133,7 @@ def forward(self, x: torch.Tensor, mask=None, return_attn=False): attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) attn2_inv = moore_penrose_iter_pinv(attn2, iters) - out = (attn1 @ attn2_inv) @ (attn3 @ v) + out: torch.Tensor = (attn1 @ attn2_inv) @ (attn3 @ v) # add depth-wise conv residual of values @@ -160,8 +154,7 @@ def forward(self, x: torch.Tensor, mask=None, return_attn=False): class TransLayer(nn.Module): - - def __init__(self, norm_layer=nn.LayerNorm, dim=512): + def __init__(self, norm_layer: type = nn.LayerNorm, dim: int = 512) -> None: super().__init__() self.norm = norm_layer(dim) self.attn = NystromAttention( @@ -176,20 +169,20 @@ def __init__(self, norm_layer=nn.LayerNorm, dim=512): dropout=0.1, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attn(self.norm(x)) return x class PPEG(nn.Module): - def __init__(self, dim=512): + def __init__(self, dim: int = 512) -> None: super(PPEG, self).__init__() self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim) self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim) self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim) - def forward(self, x, H, W): + def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: B, _, C = x.shape cls_token, feat_token = x[:, 0], x[:, 1:] cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) @@ -200,7 +193,7 @@ def forward(self, x, H, W): class TransMIL(nn.Module): - def __init__(self, n_classes): + def __init__(self, n_classes: int) -> None: super(TransMIL, self).__init__() self.pos_layer = PPEG(dim=512) self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) @@ -211,7 +204,7 @@ def __init__(self, n_classes): self.norm = nn.LayerNorm(512) self._fc2 = nn.Linear(512, self.n_classes) - def forward(self, features, **kwargs): + def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor: h = features # [B, n, 1024] h = self._fc1(h) # [B, n, 512] @@ -240,8 +233,6 @@ def forward(self, features, **kwargs): h = self.norm(h)[:, 0] # ---->predict - logits = self._fc2(h) # [B, n_classes] - Y_hat = torch.argmax(logits, dim=1) - Y_prob = F.softmax(logits, dim=1) - results_dict = {"logits": logits, "Y_prob": Y_prob, "Y_hat": Y_hat} + logits: torch.Tensor = self._fc2(h) # [B, n_classes] + return logits diff --git a/ahcore/models/layers/attention.py b/ahcore/models/layers/attention.py index 0a09fe1..a29ee1c 100644 --- a/ahcore/models/layers/attention.py +++ b/ahcore/models/layers/attention.py @@ -1,8 +1,7 @@ -from typing import Optional, List, Union, Tuple - import torch from torch import nn + class 
GatedAttention(nn.Module): """Gated Attention, as defined in https://arxiv.org/abs/1802.04712. Permutation invariant Layer on dim 1. @@ -29,8 +28,8 @@ def __init__( def forward( self, features: torch.Tensor, - return_attention_weights=False, - ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: + return_attention_weights: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Forward pass. Parameters ---------- diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index ae1bdc7..2430778 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -117,7 +117,6 @@ def __repr__(self) -> str: return f"PreTransformTaskFactory(transforms={self._transforms})" - class LabelToClassIndex: """ Maps label values to class indices according to the index_map specified in the data description. @@ -217,14 +216,12 @@ class ImageToTensor: """ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: - tile: pyvips.Image | np.ndarray = sample["image"] + tile: pyvips.Image = sample["image"] # Flatten the image to remove the alpha channel, using white as the background color tile_ = tile.flatten(background=[255, 255, 255]) # Convert VIPS image to a numpy array then to a torch tensor - if type(tile_) == pyvips.Image: - np_image = tile_.numpy() - + np_image = tile_.numpy() sample["image"] = torch.from_numpy(np_image).permute(2, 0, 1).float() if sample["image"].sum() == 0: From e3ad333d3db43cf7b43b1d299534e096aab28c20 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Wed, 4 Sep 2024 17:07:58 +0200 Subject: [PATCH 11/14] added tests --- ahcore/models/MIL/transmil.py | 6 +++--- tests/test_models/test_models.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 tests/test_models/test_models.py diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py index b481c1f..7870905 100644 --- a/ahcore/models/MIL/transmil.py +++ b/ahcore/models/MIL/transmil.py @@ -193,10 +193,10 @@ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: class TransMIL(nn.Module): - def __init__(self, n_classes: int) -> None: + def __init__(self, in_features: int = 1024, n_classes: int = 1) -> None: super(TransMIL, self).__init__() self.pos_layer = PPEG(dim=512) - self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()) + self._fc1 = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU()) self.cls_token = nn.Parameter(torch.randn(1, 1, 512)) self.n_classes = n_classes self.layer1 = TransLayer(dim=512) @@ -217,7 +217,7 @@ def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor: # ---->cls_token B = h.shape[0] - cls_tokens = self.cls_token.expand(B, -1, -1).cuda() + cls_tokens = self.cls_token.expand(B, -1, -1).to(h.device) h = torch.cat((cls_tokens, h), dim=1) # ---->Translayer x1 diff --git a/tests/test_models/test_models.py b/tests/test_models/test_models.py new file mode 100644 index 0000000..edf8f21 --- /dev/null +++ b/tests/test_models/test_models.py @@ -0,0 +1,22 @@ +import pytest +import torch + +from ahcore.models.MIL.ABmil import ABMIL +from ahcore.models.MIL.transmil import TransMIL + + +@pytest.fixture +def input_data(B: int = 16, N_tiles: int = 1000, feature_dim: int = 768) -> torch.Tensor: + return torch.randn(B, N_tiles, feature_dim) + + +def test_ABmil_shape(input_data: torch.Tensor) -> None: + model = ABMIL(in_features=768) + output = model(input_data) + assert output.shape == (16, 1) + + +def 
test_TransMIL_shape(input_data: torch.Tensor) -> None: + model = TransMIL(in_features=768, n_classes=2) + output = model(input_data) + assert output.shape == (16, 2) From c8ae15414a47526c720b7b3546ee1f00dbe9daa4 Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 5 Sep 2024 14:31:26 +0200 Subject: [PATCH 12/14] added test and put nystrom attention in the attention file --- ahcore/models/MIL/transmil.py | 180 ++++-------------------------- ahcore/models/layers/attention.py | 147 ++++++++++++++++++++++++ tests/test_models/test_models.py | 34 +++++- 3 files changed, 199 insertions(+), 162 deletions(-) diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py index 7870905..c8a3a5c 100644 --- a/ahcore/models/MIL/transmil.py +++ b/ahcore/models/MIL/transmil.py @@ -1,156 +1,14 @@ # this file includes the original nystrom attention and transmil model # from https://github.com/lucidrains/nystrom-attention/blob/main/nystrom_attention/nystrom_attention.py # and https://github.com/szc19990412/TransMIL/blob/main/models/TransMIL.py, respectively. -from math import ceil -from typing import Any, Optional + +from typing import Any import numpy as np import torch -import torch.nn.functional as F -from einops import rearrange, reduce from torch import nn as nn - -def moore_penrose_iter_pinv(x: torch.Tensor, iters: int = 6) -> torch.Tensor: - device = x.device - - abs_x = torch.abs(x) - col = abs_x.sum(dim=-1) - row = abs_x.sum(dim=-2) - z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row)) - - eye = torch.eye(x.shape[-1], device=device) - eye = rearrange(eye, "i j -> () i j") - - for _ in range(iters): - xz = x @ z - z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz))))) - - return z - - -# main attention class - - -class NystromAttention(nn.Module): - def __init__( - self, - dim: int, - dim_head: int = 64, - heads: int = 8, - num_landmarks: int = 256, - pinv_iterations: int = 6, - residual: bool = True, - residual_conv_kernel: int = 33, - eps: float = 1e-8, - dropout: float = 0.0, - ) -> None: - super().__init__() - self.eps = eps - inner_dim = heads * dim_head - - self.num_landmarks = num_landmarks - self.pinv_iterations = pinv_iterations - - self.heads = heads - self.scale = dim_head**-0.5 - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) - - self.residual = residual - if residual: - kernel_size = residual_conv_kernel - padding = residual_conv_kernel // 2 - self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False) - - def forward( - self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - b, n, _ = x.shape - h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps - # pad so that sequence can be evenly divided into m landmarks - - remainder = n % m - if remainder > 0: - padding = m - (n % m) - x = F.pad(x, (0, 0, padding, 0), value=0) - - if mask is not None: - mask = F.pad(mask, (padding, 0), value=False) - - # derive query, keys, values - - q, k, v = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - # set masked positions to 0 in queries, keys, values - - if mask is not None: - mask = rearrange(mask, "b n -> b () n") - q, k, v = map(lambda t: t * mask[..., None], (q, k, v)) - - q = q * self.scale - - # generate landmarks by sum 
reduction, and then calculate mean using the mask - - l_dim = ceil(n / m) - landmark_einops_eq = "... (n l) d -> ... n d" - q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l_dim) - k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l_dim) - - # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean - - if mask is not None: - mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=l_dim) - divisor = mask_landmarks_sum[..., None] + eps - mask_landmarks = mask_landmarks_sum > 0 - else: - divisor = torch.Tensor([l_dim]).to(q_landmarks.device) - - # masked mean (if mask exists) - - q_landmarks = q_landmarks / divisor - k_landmarks = k_landmarks / divisor - - # similarities - - einops_eq = "... i d, ... j d -> ... i j" - sim1 = torch.einsum(einops_eq, q, k_landmarks) - sim2 = torch.einsum(einops_eq, q_landmarks, k_landmarks) - sim3 = torch.einsum(einops_eq, q_landmarks, k) - - # masking - - if mask is not None: - mask_value = -torch.finfo(q.dtype).max - sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value) - sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value) - sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value) - - # eq (15) in the paper and aggregate values - - attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) - attn2_inv = moore_penrose_iter_pinv(attn2, iters) - - out: torch.Tensor = (attn1 @ attn2_inv) @ (attn3 @ v) - - # add depth-wise conv residual of values - - if self.residual: - out = out + self.res_conv(v) - - # merge and combine heads - - out = rearrange(out, "b h n d -> b n (h d)", h=h) - out = self.to_out(out) - out = out[:, -n:] - - if return_attn: - attn = attn1 @ attn2_inv @ attn3 - return out, attn - - return out +from ahcore.models.layers.attention import NystromAttention class TransLayer(nn.Module): @@ -193,27 +51,27 @@ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: class TransMIL(nn.Module): - def __init__(self, in_features: int = 1024, n_classes: int = 1) -> None: + def __init__(self, in_features: int = 1024, out_features: int = 1, hidden_dimension: int = 512) -> None: super(TransMIL, self).__init__() - self.pos_layer = PPEG(dim=512) - self._fc1 = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU()) - self.cls_token = nn.Parameter(torch.randn(1, 1, 512)) - self.n_classes = n_classes - self.layer1 = TransLayer(dim=512) - self.layer2 = TransLayer(dim=512) - self.norm = nn.LayerNorm(512) - self._fc2 = nn.Linear(512, self.n_classes) + self.pos_layer = PPEG(dim=hidden_dimension) + self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU()) + self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension)) + self.n_classes = out_features + self.layer1 = TransLayer(dim=hidden_dimension) + self.layer2 = TransLayer(dim=hidden_dimension) + self.norm = nn.LayerNorm(hidden_dimension) + self._fc2 = nn.Linear(hidden_dimension, self.n_classes) def forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor: - h = features # [B, n, 1024] + h = features # [B, n, in_features] - h = self._fc1(h) # [B, n, 512] + h = self._fc1(h) # [B, n, hidden_dimension] # ---->pad H = h.shape[1] _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H))) add_length = _H * _W - H - h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N, 512] + h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N, hidden_dimension] # ---->cls_token B = h.shape[0] @@ -221,18 +79,18 @@ def 
forward(self, features: torch.Tensor, **kwargs: Any) -> torch.Tensor: h = torch.cat((cls_tokens, h), dim=1) # ---->Translayer x1 - h = self.layer1(h) # [B, N, 512] + h = self.layer1(h) # [B, N, hidden_dimension] # ---->PPEG - h = self.pos_layer(h, _H, _W) # [B, N, 512] + h = self.pos_layer(h, _H, _W) # [B, N, hidden_dimension] # ---->Translayer x2 - h = self.layer2(h) # [B, N, 512] + h = self.layer2(h) # [B, N, hidden_dimension] # ---->cls_token h = self.norm(h)[:, 0] # ---->predict - logits: torch.Tensor = self._fc2(h) # [B, n_classes] + logits: torch.Tensor = self._fc2(h) # [B, out_features] return logits diff --git a/ahcore/models/layers/attention.py b/ahcore/models/layers/attention.py index a29ee1c..369c574 100644 --- a/ahcore/models/layers/attention.py +++ b/ahcore/models/layers/attention.py @@ -1,4 +1,9 @@ +import math +from typing import Optional + import torch +import torch.nn.functional as F +from einops import rearrange, reduce from torch import nn @@ -56,3 +61,145 @@ def forward( return scaled_attention.squeeze(1), attention_weights return scaled_attention.squeeze(1) + + +def moore_penrose_iter_pinv(x: torch.Tensor, iters: int = 6) -> torch.Tensor: + device = x.device + + abs_x = torch.abs(x) + col = abs_x.sum(dim=-1) + row = abs_x.sum(dim=-2) + z = rearrange(x, "... i j -> ... j i") / (torch.max(col) * torch.max(row)) + + eye = torch.eye(x.shape[-1], device=device) + eye = rearrange(eye, "i j -> () i j") + + for _ in range(iters): + xz = x @ z + z = 0.25 * z @ (13 * eye - (xz @ (15 * eye - (xz @ (7 * eye - xz))))) + + return z + + +# main attention class + + +class NystromAttention(nn.Module): + def __init__( + self, + dim: int, + dim_head: int = 64, + heads: int = 8, + num_landmarks: int = 256, + pinv_iterations: int = 6, + residual: bool = True, + residual_conv_kernel: int = 33, + eps: float = 1e-8, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.eps = eps + inner_dim = heads * dim_head + + self.num_landmarks = num_landmarks + self.pinv_iterations = pinv_iterations + + self.heads = heads + self.scale = dim_head**-0.5 + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + + self.residual = residual + if residual: + kernel_size = residual_conv_kernel + padding = residual_conv_kernel // 2 + self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding=(padding, 0), groups=heads, bias=False) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + b, n, _ = x.shape + h, m, iters, eps = self.heads, self.num_landmarks, self.pinv_iterations, self.eps + # pad so that sequence can be evenly divided into m landmarks + + remainder = n % m + if remainder > 0: + padding = m - (n % m) + x = F.pad(x, (0, 0, padding, 0), value=0) + + if mask is not None: + mask = F.pad(mask, (padding, 0), value=False) + + # derive query, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + # set masked positions to 0 in queries, keys, values + + if mask is not None: + mask = rearrange(mask, "b n -> b () n") + q, k, v = map(lambda t: t * mask[..., None], (q, k, v)) + + q = q * self.scale + + # generate landmarks by sum reduction, and then calculate mean using the mask + + l_dim = math.ceil(n / m) + landmark_einops_eq = "... (n l) d -> ... 
n d" + q_landmarks = reduce(q, landmark_einops_eq, "sum", l=l_dim) + k_landmarks = reduce(k, landmark_einops_eq, "sum", l=l_dim) + + # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean + + if mask is not None: + mask_landmarks_sum = reduce(mask, "... (n l) -> ... n", "sum", l=l_dim) + divisor = mask_landmarks_sum[..., None] + eps + mask_landmarks = mask_landmarks_sum > 0 + else: + divisor = torch.Tensor([l_dim]).to(q_landmarks.device) + + # masked mean (if mask exists) + + q_landmarks = q_landmarks / divisor + k_landmarks = k_landmarks / divisor + + # similarities + + einops_eq = "... i d, ... j d -> ... i j" + sim1 = torch.einsum(einops_eq, q, k_landmarks) + sim2 = torch.einsum(einops_eq, q_landmarks, k_landmarks) + sim3 = torch.einsum(einops_eq, q_landmarks, k) + + # masking + + if mask is not None: + mask_value = -torch.finfo(q.dtype).max + sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value) + sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value) + sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value) + + # eq (15) in the paper and aggregate values + + attn1, attn2, attn3 = map(lambda t: t.softmax(dim=-1), (sim1, sim2, sim3)) + attn2_inv = moore_penrose_iter_pinv(attn2, iters) + + out: torch.Tensor = (attn1 @ attn2_inv) @ (attn3 @ v) + + # add depth-wise conv residual of values + + if self.residual: + out = out + self.res_conv(v) + + # merge and combine heads + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + out = self.to_out(out) + out = out[:, -n:] + + if return_attn: + attn = attn1 @ attn2_inv @ attn3 + return out, attn + + return out diff --git a/tests/test_models/test_models.py b/tests/test_models/test_models.py index edf8f21..9f3b487 100644 --- a/tests/test_models/test_models.py +++ b/tests/test_models/test_models.py @@ -1,6 +1,8 @@ import pytest import torch +from ahcore.models.layers.attention import GatedAttention, NystromAttention +from ahcore.models.layers.MLP import MLP from ahcore.models.MIL.ABmil import ABMIL from ahcore.models.MIL.transmil import TransMIL @@ -15,8 +17,38 @@ def test_ABmil_shape(input_data: torch.Tensor) -> None: output = model(input_data) assert output.shape == (16, 1) + output, attentions = model(input_data, return_attention_weights=True) + assert output.shape == (16, 1) + assert attentions.shape == (16, 1000, 1) + def test_TransMIL_shape(input_data: torch.Tensor) -> None: - model = TransMIL(in_features=768, n_classes=2) + model = TransMIL(in_features=768, out_features=2) output = model(input_data) assert output.shape == (16, 2) + + +def test_MLP_shape(input_data: torch.Tensor) -> None: + model = MLP(in_features=768, out_features=2, hidden=[128], dropout=[0.1]) + output = model(input_data) + assert output.shape == (16, 1000, 2) + + +def test_MLP_hidden_dropout() -> None: + with pytest.raises(ValueError): + MLP(in_features=768, out_features=2, hidden=None, dropout=[0.1]) + + +def test_attention_shape(input_data: torch.Tensor) -> None: + model = GatedAttention(dim=768) + output = model(input_data) + assert output.shape == (16, 768) + + +def test_nystrom_att_with_mask(input_data: torch.Tensor) -> None: + model = NystromAttention( + dim=768, dim_head=768 // 8, heads=8, num_landmarks=1, pinv_iterations=6, residual=True, dropout=0.1 + ) + output, attn = model(input_data, mask=torch.ones_like(input_data, dtype=torch.bool)[:, :, 0], return_attn=True) + assert output.shape == (16, 1000, 768) + assert attn.shape == (16, 
8, 1000, 1000) From b25881d8dbbdc8505b3b384d38c08394322504ec Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 5 Sep 2024 14:43:15 +0200 Subject: [PATCH 13/14] out_features --> num_classes --- ahcore/models/MIL/ABmil.py | 4 ++-- ahcore/models/MIL/transmil.py | 4 ++-- tests/test_models/test_models.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index d1523d2..056f4f1 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -33,7 +33,7 @@ class ABMIL(nn.Module): def __init__( self, in_features: int, - out_features: int = 1, + num_classes: int = 1, attention_dimension: int = 128, temperature: float = 1.0, embed_mlp_hidden: Optional[List[int]] = None, @@ -91,7 +91,7 @@ def __init__( self.classifier = MLP( in_features=attention_dimension, - out_features=out_features, + out_features=num_classes, bias=classifier_bias, hidden=classifier_hidden, dropout=classifier_dropout, diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py index c8a3a5c..27cef7e 100644 --- a/ahcore/models/MIL/transmil.py +++ b/ahcore/models/MIL/transmil.py @@ -51,12 +51,12 @@ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: class TransMIL(nn.Module): - def __init__(self, in_features: int = 1024, out_features: int = 1, hidden_dimension: int = 512) -> None: + def __init__(self, in_features: int = 1024, num_classes: int = 1, hidden_dimension: int = 512) -> None: super(TransMIL, self).__init__() self.pos_layer = PPEG(dim=hidden_dimension) self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU()) self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension)) - self.n_classes = out_features + self.n_classes = num_classes self.layer1 = TransLayer(dim=hidden_dimension) self.layer2 = TransLayer(dim=hidden_dimension) self.norm = nn.LayerNorm(hidden_dimension) diff --git a/tests/test_models/test_models.py b/tests/test_models/test_models.py index 9f3b487..139207b 100644 --- a/tests/test_models/test_models.py +++ b/tests/test_models/test_models.py @@ -23,7 +23,7 @@ def test_ABmil_shape(input_data: torch.Tensor) -> None: def test_TransMIL_shape(input_data: torch.Tensor) -> None: - model = TransMIL(in_features=768, out_features=2) + model = TransMIL(in_features=768, num_classes=2) output = model(input_data) assert output.shape == (16, 2) From 0efb015f9b8a19614d23a888b7706e355e25badf Mon Sep 17 00:00:00 2001 From: JorenB Date: Thu, 5 Sep 2024 22:12:01 +0200 Subject: [PATCH 14/14] fix failing pre-commit checks --- ahcore/data/dataset.py | 6 ++++-- ahcore/models/base_jit_model.py | 3 +-- ahcore/models/layers/MLP.py | 3 +-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index b0b932f..3eaed15 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -88,10 +88,12 @@ def __len__(self) -> int: return self.cumulative_sizes[-1] @overload - def __getitem__(self, index: int) -> DlupDatasetSample: ... + def __getitem__(self, index: int) -> DlupDatasetSample: + ... @overload - def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ... + def __getitem__(self, index: slice) -> list[DlupDatasetSample]: + ... 
def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]: """Returns the sample at the given index.""" diff --git a/ahcore/models/base_jit_model.py b/ahcore/models/base_jit_model.py index 94c0f29..502fb48 100644 --- a/ahcore/models/base_jit_model.py +++ b/ahcore/models/base_jit_model.py @@ -1,9 +1,8 @@ from pathlib import Path from typing import Any -import torch -from torch.jit import ScriptModule, load from torch import nn +from torch.jit import ScriptModule, load class BaseAhcoreJitModel(ScriptModule): diff --git a/ahcore/models/layers/MLP.py b/ahcore/models/layers/MLP.py index f255602..3bffe41 100644 --- a/ahcore/models/layers/MLP.py +++ b/ahcore/models/layers/MLP.py @@ -1,6 +1,5 @@ -from typing import Optional, List, Union, Tuple +from typing import List, Optional -import torch from torch import nn """Most of this stuff is adapted from utils from https://github.com/owkin/HistoSSLscaling/tree/main"""
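
A minimal usage sketch of the NystromAttention layer added in ahcore/models/layers/attention.py, mirroring the configuration and shapes exercised by test_nystrom_att_with_mask in the patches above; the all-True mask (no padded tiles) and the random feature tensor are assumptions for illustration only, not part of the patches.

    import torch

    from ahcore.models.layers.attention import NystromAttention

    # Same hyperparameters as in tests/test_models/test_models.py.
    attention = NystromAttention(
        dim=768, dim_head=768 // 8, heads=8, num_landmarks=1, pinv_iterations=6, residual=True, dropout=0.1
    )

    x = torch.rand(16, 1000, 768)                  # (batch, n_tiles, features)
    mask = torch.ones(16, 1000, dtype=torch.bool)  # True marks valid (non-padded) tiles

    out, attn = attention(x, mask=mask, return_attn=True)
    # out:  (16, 1000, 768)     per-tile features after Nystrom self-attention
    # attn: (16, 8, 1000, 1000) per-head approximate attention matrix (attn1 @ attn2_inv @ attn3)

When a mask is given, masked positions are zeroed in the queries, keys and values and excluded from the landmark means, so the Nystrom approximation does not attend into padding; return_attn=True additionally returns the reconstructed attention matrix.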
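
Likewise, a short sketch of the MIL heads after the out_features -> num_classes rename, matching the shapes asserted in the updated tests; the batch of 16 bags with 1000 tiles of 768 features matches the test fixture, while constructor defaults beyond those shown are assumptions.

    import torch

    from ahcore.models.MIL.ABmil import ABMIL
    from ahcore.models.MIL.transmil import TransMIL

    features = torch.rand(16, 1000, 768)  # (batch, n_tiles, in_features)

    abmil = ABMIL(in_features=768, num_classes=1)
    logits, attention = abmil(features, return_attention_weights=True)
    # logits:    (16, 1)        one slide-level prediction per bag
    # attention: (16, 1000, 1)  per-tile attention weights

    transmil = TransMIL(in_features=768, num_classes=2, hidden_dimension=512)
    logits = transmil(features)
    # logits: (16, 2)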