Skip to content

Commit

Permalink
few more comments on tensor shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Jan 6, 2025
1 parent d79ba68 commit b0bff0f
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions lightly/loss/detcon_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def forward(
target_view0_large = target_view0
target_view1_large = target_view1
labels_local = torch.eye(b, device=pred_view0.device)
enlarged_b = b
else:
target_view0_large = torch.cat(dist.gather(target_view0), dim=0)
target_view1_large = torch.cat(dist.gather(target_view1), dim=0)
Expand All @@ -196,11 +197,10 @@ def forward(
### Expand Labels ###
# labels_local at this point only points towards the diagonal of the batch, i.e.
# indicates to compare between the same samples across views.
labels_local = labels_local[:, None, :, None]
# assert labels_local.size() == (b, 1, b, 1)
labels_local = labels_local[:, None, :, None] # (b, 1, b * world_size, 1)

### Calculate Similarity Matrices ###
# tensors of shape (b, m, b, m), indicating similarities between regions across
# tensors of shape (b, m, b * world_size, m), indicating similarities between regions across
# views and samples in the batch
logits_aa = (
torch.einsum("abk,uvk->abuv", pred_view0, target_view0_large)
Expand Down Expand Up @@ -228,8 +228,10 @@ def forward(
### Remove Similarities Between Corresponding Views But Different Regions ###
# labels_local initially compared all features across views, but we only want to
# compare the same regions across views.
labels_aa = labels_local * same_mask_aa # (b, 1, b, 1) * (b, m, 1, m)
labels_bb = labels_local * same_mask_bb
labels_aa = (
labels_local * same_mask_aa
) # (b, 1, b * world_size, 1) * (b, m, 1, m)
labels_bb = labels_local * same_mask_bb # yielding (b, m, b * world_size, m)
labels_ab = labels_local * same_mask_ab
labels_ba = labels_local * same_mask_ba

Expand Down

0 comments on commit b0bff0f

Please sign in to comment.