diff --git a/lightly/loss/detcon_loss.py b/lightly/loss/detcon_loss.py index 0b563230b..8b5596690 100644 --- a/lightly/loss/detcon_loss.py +++ b/lightly/loss/detcon_loss.py @@ -301,8 +301,8 @@ def _same_mask(mask0: Tensor, mask1: Tensor) -> Tensor: dimensions is for the regions/masks of the first view and the last :math:`M` dimensions is for the regions/masks of the second view. For every sample :math:`k` in the batch (separately), the tensor is effectively a 2D index - matrix where the entry :math:`(k, i, :, j)` is 1 if the masks :math:`m_i` - and :math:`m_j'` are the same and 0 otherwise. + matrix where the entry :math:`(k, i, :, j)` is 1 if the masks :math:`mask0(k, i)` + and :math:`mask1(k, j)'` are the same and 0 otherwise. """ return (mask0[:, :, None] == mask1[:, None, :]).float()[:, :, None, :]