This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

out_features --> num_classes
moerlemans committed Sep 5, 2024
1 parent c8ae154 commit b25881d
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions ahcore/models/MIL/ABmil.py
@@ -33,7 +33,7 @@ class ABMIL(nn.Module):
     def __init__(
         self,
         in_features: int,
-        out_features: int = 1,
+        num_classes: int = 1,
         attention_dimension: int = 128,
         temperature: float = 1.0,
         embed_mlp_hidden: Optional[List[int]] = None,
@@ -91,7 +91,7 @@ def __init__(
 
         self.classifier = MLP(
             in_features=attention_dimension,
-            out_features=out_features,
+            out_features=num_classes,
             bias=classifier_bias,
             hidden=classifier_hidden,
             dropout=classifier_dropout,
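For context, a minimal usage sketch of the renamed keyword on the ABMIL side. The input shape and the forward call below are assumptions for illustration, not taken from this commit:

```python
import torch

from ahcore.models.MIL.ABmil import ABMIL

# Hypothetical bag of tile embeddings: 16 bags x 100 tiles x 768 features each.
# The exact expected input shape is an assumption, not part of this diff.
input_data = torch.randn(16, 100, 768)

model = ABMIL(in_features=768, num_classes=2)  # previously: out_features=2
logits = model(input_data)  # expected to yield one logit vector per bag, e.g. shape (16, 2)
```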
4 changes: 2 additions & 2 deletions ahcore/models/MIL/transmil.py
@@ -51,12 +51,12 @@ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
 
 
 class TransMIL(nn.Module):
-    def __init__(self, in_features: int = 1024, out_features: int = 1, hidden_dimension: int = 512) -> None:
+    def __init__(self, in_features: int = 1024, num_classes: int = 1, hidden_dimension: int = 512) -> None:
         super(TransMIL, self).__init__()
         self.pos_layer = PPEG(dim=hidden_dimension)
         self._fc1 = nn.Sequential(nn.Linear(in_features, hidden_dimension), nn.ReLU())
         self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dimension))
-        self.n_classes = out_features
+        self.n_classes = num_classes
         self.layer1 = TransLayer(dim=hidden_dimension)
         self.layer2 = TransLayer(dim=hidden_dimension)
         self.norm = nn.LayerNorm(hidden_dimension)
2 changes: 1 addition & 1 deletion tests/test_models/test_models.py
@@ -23,7 +23,7 @@ def test_ABmil_shape(input_data: torch.Tensor) -> None:
 
 
 def test_TransMIL_shape(input_data: torch.Tensor) -> None:
-    model = TransMIL(in_features=768, out_features=2)
+    model = TransMIL(in_features=768, num_classes=2)
     output = model(input_data)
     assert output.shape == (16, 2)
 
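Any call sites or configs outside these files that still pass `out_features` to ABMIL or TransMIL will need the same rename. A hypothetical before/after sketch of a caller-side update:

```python
from ahcore.models.MIL.transmil import TransMIL

# Before this commit (old keyword, now removed):
# model = TransMIL(in_features=768, out_features=2)

# After this commit (renamed keyword):
model = TransMIL(in_features=768, num_classes=2)
```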
