From b25881d8dbbdc8505b3b384d38c08394322504ec Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Thu, 5 Sep 2024 14:43:15 +0200 Subject: [PATCH] out_features --> num_classes --- ahcore/models/MIL/ABmil.py | 4 ++-- ahcore/models/MIL/transmil.py | 4 ++-- tests/test_models/test_models.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ahcore/models/MIL/ABmil.py b/ahcore/models/MIL/ABmil.py index d1523d2..056f4f1 100644 --- a/ahcore/models/MIL/ABmil.py +++ b/ahcore/models/MIL/ABmil.py @@ -33,7 +33,7 @@ class ABMIL(nn.Module): def __init__( self, in_features: int, - out_features: int = 1, + num_classes: int = 1, attention_dimension: int = 128, temperature: float = 1.0, embed_mlp_hidden: Optional[List[int]] = None, @@ -91,7 +91,7 @@ def __init__( self.classifier = MLP( in_features=attention_dimension, - out_features=out_features, + out_features=num_classes, bias=classifier_bias, hidden=classifier_hidden, dropout=classifier_dropout, diff --git a/ahcore/models/MIL/transmil.py b/ahcore/models/MIL/transmil.py index c8a3a5c..27cef7e 100644 --- a/ahcore/models/MIL/transmil.py +++ b/ahcore/models/MIL/transmil.py @@ -51,12 +51,12 @@ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: class TransMIL(nn.Module): - def __init__(self, in_features: int = 1024, out_features: int = 1, hidden_dimension: int = 512) -> None: + def __init__(self, in_features: int = 1024, num_classes: int = 1, hidden_dimension: int = 512) -> None: super(TransMIL, self).__init__() self.pos_layer = PPEG(dim=hidden_dimension) self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU()) self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension)) - self.n_classes = out_features + self.n_classes = num_classes self.layer1 = TransLayer(dim=hidden_dimension) self.layer2 = TransLayer(dim=hidden_dimension) self.norm = nn.LayerNorm(hidden_dimension) diff --git a/tests/test_models/test_models.py b/tests/test_models/test_models.py index 9f3b487..139207b 100644 --- a/tests/test_models/test_models.py +++ b/tests/test_models/test_models.py @@ -23,7 +23,7 @@ def test_ABmil_shape(input_data: torch.Tensor) -> None: def test_TransMIL_shape(input_data: torch.Tensor) -> None: - model = TransMIL(in_features=768, out_features=2) + model = TransMIL(in_features=768, num_classes=2) output = model(input_data) assert output.shape == (16, 2)