Skip to content

Commit

Permalink
Add DetConSLoss and DetConBLoss (#1771)
Browse files Browse the repository at this point in the history
* allow access to v2 transforms availability from anywhere

* import unnecessary after exception

Co-authored-by: guarin <[email protected]>

* reformat

* add implementation of AddGridTransform

* make AddGridTransform importable

* add tests for AddGridTransform

* reformat

* enhance docstring

* fix typing issues

* fix import when transforms.v2 not available

* add additional type ignore for 3.7 compatibility

* reformat

* add transform to docs

* change header

* add explanation on data structures

* use kw args

* add assertion for mask dimension to be geq 2

* remove unnecessary fixtures

* make argument order consistent

* add DetCon SingleView and MultiView transforms

* add MultiViewTransform BaseClass for v2 transforms

* add tests for DetCon transform

* export all newly added transforms

* add DetCon single view and DetCon multi view transforms

* add torchvision transforms v2 compatible MultiViewTransforms

* make newly added transforms public

* remove unnecessary fixtures

* add tests for DetCon transform

* remove wrongfully added files

* add DetCon transform and MultiView transforms for v2 to docs

* fix docs references

* fix import issues for minimal dependencies

* fixing code format

* add test for multiviewtransformv2

* fix testing of multiview

* use singular AddGridTransforms

* consistent naming to DetConS

* adjust docstring reference numbering

* name refactoring

* start detconloss implementation

* implement detconloss

* test detconloss

* make detconloss public

* add detconloss to docs

* Update lightly/loss/detcon_loss.py

Co-authored-by: guarin <[email protected]>

* initial small fixes

* add comments and avoid 0 division

* remove labels_ext

* remove labels_ext

* revert normalization; some formatting

* complete rewrite of tests

* fix typing issues

* remove unused imports

* Update lightly/loss/detcon_loss.py

Co-authored-by: guarin <[email protected]>

* move to f-strings

* create test classes

* squeeze instead of 0 indexing

* formatting

* few more comments on tensor shapes

* additional comments on tensor shapes

---------

Co-authored-by: guarin <[email protected]>
  • Loading branch information
liopeer and guarin authored Jan 6, 2025
1 parent 356ae56 commit 0dcf780
Show file tree
Hide file tree
Showing 4 changed files with 608 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/source/lightly.loss.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ lightly.loss
.. autoclass:: lightly.loss.dcl_loss.DCLWLoss
:members:

.. autoclass:: lightly.loss.detcon_loss.DetConBLoss
:members:

.. autoclass:: lightly.loss.detcon_loss.DetConSLoss
:members:

.. autoclass:: lightly.loss.dino_loss.DINOLoss
:members:

Expand Down
1 change: 1 addition & 0 deletions lightly/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# All Rights Reserved
from lightly.loss.barlow_twins_loss import BarlowTwinsLoss
from lightly.loss.dcl_loss import DCLLoss, DCLWLoss
from lightly.loss.detcon_loss import DetConBLoss, DetConSLoss
from lightly.loss.dino_loss import DINOLoss
from lightly.loss.emp_ssl_loss import EMPSSLLoss
from lightly.loss.ibot_loss import IBOTPatchLoss
Expand Down
316 changes: 316 additions & 0 deletions lightly/loss/detcon_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import distributed as dist
from torch.nn import Module


class DetConSLoss(Module):
    """Implementation of the DetConS loss. [2]_

    The inputs are two-fold:

    - Two latent representations of the same batch under different views, as
      generated by SimCLR [3]_ and additional pooling over the regions of the
      segmentation.
    - Two integer masks that indicate the regions of the segmentation that were
      used for pooling.

    For calculating the contrastive loss, regions under the same mask in the same
    image (under a different view) are considered as positives and everything else
    as negatives. With :math:`v_m` and :math:`v_{m'}'` being the pooled feature
    maps under mask :math:`m` and :math:`m'` respectively, and additionally scaled
    to a norm of :math:`\\frac{1}{\\sqrt{\\tau}}`, the formula for the contrastive
    loss is

    .. math::
        \\mathcal{L} = \\sum_{m}\\sum_{m'} \\mathbb{1}_{m, m'} \\left[ - \\log\
            \\frac{\\exp(v_m \\cdot v_{m'}')}{\\exp(v_m \\cdot v_{m'}') +\
            \\sum_{n}\\exp (v_m \\cdot v_{m'}')} \\right]

    where :math:`\\mathbb{1}_{m, m'}` is 1 if the masks are the same and 0 otherwise.

    References:
        .. [2] DetCon https://arxiv.org/abs/2103.10957
        .. [3] SimCLR https://arxiv.org/abs/2002.05709

    Attributes:
        temperature:
            The temperature :math:`\\tau` in the contrastive loss.
        gather_distributed:
            If True, the similarity matrix is gathered across all GPUs before the
            loss is calculated. Else, the loss is calculated on each GPU separately.
    """

    def __init__(
        self, temperature: float = 0.1, gather_distributed: bool = True
    ) -> None:
        super().__init__()
        # DetConS is the special case of DetConB where prediction and target
        # branches coincide, so the loss is delegated (see forward).
        self.detconbloss = DetConBLoss(
            temperature=temperature, gather_distributed=gather_distributed
        )

    def forward(
        self, view0: Tensor, view1: Tensor, mask_view0: Tensor, mask_view1: Tensor
    ) -> Tensor:
        """Calculate the contrastive loss under the same mask in the same image.

        The tensor shapes and value ranges are given by variables :math:`B, M, D, N`,
        where :math:`B` is the batch size, :math:`M` is the sampled number of image
        masks / regions, :math:`D` is the embedding size and :math:`N` is the total
        number of masks.

        Args:
            view0: Mask-pooled output for the first view, a float tensor of shape
                :math:`(B, M, D)`.
            view1: Mask-pooled output for the second view, a float tensor of shape
                :math:`(B, M, D)`.
            mask_view0: Indices corresponding to the sampled masks for the first view,
                an integer tensor of shape :math:`(B, M)` with (possibly repeated)
                indices in the range :math:`[0, N)`.
            mask_view1: Indices corresponding to the sampled masks for the second view,
                an integer tensor of shape :math:`(B, M)` with (possibly repeated)
                indices in the range :math:`[0, N)`.

        Returns:
            A scalar float tensor containing the contrastive loss.
        """
        # Each view is passed as both prediction and target branch of DetConBLoss.
        loss: Tensor = self.detconbloss(
            view0, view1, view0, view1, mask_view0, mask_view1
        )
        return loss


class DetConBLoss(Module):
    """Implementation of the DetConB loss. [0]_

    The inputs are three-fold:

    - Two latent representations of the same batch under different views, as
      generated by BYOL's [1]_ prediction branch and additional pooling over the
      regions of the segmentation.
    - Two latent representations of the same batch under different views, as
      generated by BYOL's target branch and additional pooling over the regions
      of the segmentation.
    - Two integer masks that indicate the regions of the segmentation that were
      used for pooling.

    For calculating the contrastive loss, regions under the same mask in the same
    image (under a different view) are considered as positives and everything else
    as negatives. With :math:`v_m` and :math:`v_{m'}'` being the pooled feature
    maps under mask :math:`m` and :math:`m'` respectively, and additionally scaled
    to a norm of :math:`\\frac{1}{\\sqrt{\\tau}}`, the formula for the contrastive
    loss is

    .. math::
        \\mathcal{L} = \\sum_{m}\\sum_{m'} \\mathbb{1}_{m, m'} \\left[ - \\log \
            \\frac{\\exp(v_m \\cdot v_{m'}')}{\\exp(v_m \\cdot v_{m'}') + \
            \\sum_{n}\\exp (v_m \\cdot v_{m'}')} \\right]

    where :math:`\\mathbb{1}_{m, m'}` is 1 if the masks are the same and 0
    otherwise. Since :math:`v_m` and :math:`v_{m'}'` stem from different branches,
    the loss is symmetrized by also calculating the loss with the roles of the
    views reversed. [1]_

    References:
        .. [0] DetCon https://arxiv.org/abs/2103.10957
        .. [1] BYOL https://arxiv.org/abs/2006.07733

    Attributes:
        temperature:
            The temperature :math:`\\tau` in the contrastive loss.
        gather_distributed:
            If True, the similarity matrix is gathered across all GPUs before the
            loss is calculated. Else, the loss is calculated on each GPU separately.
    """

    def __init__(
        self, temperature: float = 0.1, gather_distributed: bool = True
    ) -> None:
        """Initializes the DetConBLoss module.

        Raises:
            ValueError: If the temperature is numerically zero, or if
                gather_distributed is True while torch.distributed is unavailable.
        """
        super().__init__()
        self.eps = 1e-8
        self.temperature = temperature
        self.gather_distributed = gather_distributed
        if abs(self.temperature) < self.eps:
            raise ValueError(f"Illegal temperature: abs({self.temperature}) < 1e-8")
        if self.gather_distributed and not dist.is_available():
            raise ValueError(
                "gather_distributed is True but torch.distributed is not available. "
                "Please set gather_distributed=False or install a torch version with "
                "distributed support."
            )

    def forward(
        self,
        pred_view0: Tensor,
        pred_view1: Tensor,
        target_view0: Tensor,
        target_view1: Tensor,
        mask_view0: Tensor,
        mask_view1: Tensor,
    ) -> Tensor:
        """Calculate the contrastive loss under the same mask in the same image.

        The tensor shapes and value ranges are given by variables :math:`B, M, D, N`,
        where :math:`B` is the batch size, :math:`M` is the sampled number of image
        masks / regions, :math:`D` is the embedding size and :math:`N` is the total
        number of masks.

        Args:
            pred_view0: Mask-pooled output of the prediction branch for the first
                view, a float tensor of shape :math:`(B, M, D)`.
            pred_view1: Mask-pooled output of the prediction branch for the second
                view, a float tensor of shape :math:`(B, M, D)`.
            target_view0: Mask-pooled output of the target branch for the first
                view, a float tensor of shape :math:`(B, M, D)`.
            target_view1: Mask-pooled output of the target branch for the second
                view, a float tensor of shape :math:`(B, M, D)`.
            mask_view0: Indices corresponding to the sampled masks for the first
                view, an integer tensor of shape :math:`(B, M)` with (possibly
                repeated) indices in the range :math:`[0, N)`.
            mask_view1: Indices corresponding to the sampled masks for the second
                view, an integer tensor of shape :math:`(B, M)` with (possibly
                repeated) indices in the range :math:`[0, N)`.

        Returns:
            A scalar float tensor containing the contrastive loss.
        """
        b, m, _ = pred_view0.size()
        infinity_proxy = 1e9

        ### Gather Across Processes ###
        # The dist.is_initialized() check prevents dist.get_world_size() from
        # raising when torch.distributed is available but no process group has
        # been initialized (e.g. single-process training with the default
        # gather_distributed=True).
        if (
            not self.gather_distributed
            or not dist.is_initialized()
            or dist.get_world_size() < 2
        ):
            target_view0_large = target_view0
            target_view1_large = target_view1
            labels_local = torch.eye(b, device=pred_view0.device)
        else:
            world_size = dist.get_world_size()
            # all_gather (not gather) is required here: every rank needs the full
            # set of target features, while dist.gather only collects on the dst
            # rank and returns None.
            gathered_view0 = [torch.empty_like(target_view0) for _ in range(world_size)]
            gathered_view1 = [torch.empty_like(target_view1) for _ in range(world_size)]
            dist.all_gather(gathered_view0, target_view0)
            dist.all_gather(gathered_view1, target_view1)
            target_view0_large = torch.cat(gathered_view0, dim=0)
            target_view1_large = torch.cat(gathered_view1, dim=0)
            replica_id = dist.get_rank()
            labels_idx = torch.arange(b, device=pred_view0.device) + replica_id * b
            # float() matches the dtype of the torch.eye branch above
            labels_local = F.one_hot(labels_idx, num_classes=b * world_size).float()

        # normalize features so the einsum dot products are cosine similarities
        pred_view0 = F.normalize(pred_view0, p=2, dim=2)
        pred_view1 = F.normalize(pred_view1, p=2, dim=2)
        target_view0_large = F.normalize(target_view0_large, p=2, dim=2)
        target_view1_large = F.normalize(target_view1_large, p=2, dim=2)

        ### Expand Labels ###
        # labels_local at this point only points towards the diagonal of the batch,
        # i.e. indicates to compare between the same samples across views.
        labels_local = labels_local[:, None, :, None]  # (b, 1, b * world_size, 1)

        ### Calculate Similarity Matrices ###
        # tensors of shape (b, m, b * world_size, m), indicating similarities
        # between regions across views and samples in the batch
        logits_aa = (
            torch.einsum("abk,uvk->abuv", pred_view0, target_view0_large)
            / self.temperature
        )
        logits_bb = (
            torch.einsum("abk,uvk->abuv", pred_view1, target_view1_large)
            / self.temperature
        )
        logits_ab = (
            torch.einsum("abk,uvk->abuv", pred_view0, target_view1_large)
            / self.temperature
        )
        logits_ba = (
            torch.einsum("abk,uvk->abuv", pred_view1, target_view0_large)
            / self.temperature
        )

        ### Find Corresponding Regions Across Views ###
        same_mask_aa = _same_mask(mask_view0, mask_view0)
        same_mask_bb = _same_mask(mask_view1, mask_view1)
        same_mask_ab = _same_mask(mask_view0, mask_view1)
        same_mask_ba = _same_mask(mask_view1, mask_view0)

        ### Remove Similarities Between Corresponding Views But Different Regions ###
        # labels_local initially compared all features across views, but we only
        # want to compare the same regions across views.
        # (b, 1, b * world_size, 1) * (b, m, 1, m) -> (b, m, b * world_size, m)
        labels_aa = labels_local * same_mask_aa
        labels_bb = labels_local * same_mask_bb
        labels_ab = labels_local * same_mask_ab
        labels_ba = labels_local * same_mask_ba

        ### Remove Logits And Labels Between The Same View ###
        logits_aa = logits_aa - infinity_proxy * labels_aa
        logits_bb = logits_bb - infinity_proxy * labels_bb
        labels_aa = 0.0 * labels_aa
        labels_bb = 0.0 * labels_bb

        ### Arrange Labels ###
        # (b, m, b * world_size * 2, m)
        labels_abaa = torch.cat([labels_ab, labels_aa], dim=2)
        labels_babb = torch.cat([labels_ba, labels_bb], dim=2)
        # (b, m, b * world_size * 2 * m)
        labels_0 = labels_abaa.view(b, m, -1)
        labels_1 = labels_babb.view(b, m, -1)

        ### Count Number of Positives For Every Region (per sample) ###
        num_positives_0 = torch.sum(labels_0, dim=-1, keepdim=True)
        num_positives_1 = torch.sum(labels_1, dim=-1, keepdim=True)

        ### Scale The Labels By The Number of Positives To Weight Loss Value ###
        # clamp with a Python scalar avoids 0-division without constructing
        # helper tensors; values are identical to torch.maximum(x, tensor(1))
        labels_0 = labels_0 / num_positives_0.clamp(min=1.0)
        labels_1 = labels_1 / num_positives_1.clamp(min=1.0)

        ### Count How Many Overlapping Regions We Have Across Views ###
        obj_area_0 = torch.sum(same_mask_aa, dim=(2, 3))
        obj_area_1 = torch.sum(same_mask_bb, dim=(2, 3))
        # make sure we don't divide by zero
        obj_area_0 = obj_area_0.clamp(min=self.eps)
        obj_area_1 = obj_area_1.clamp(min=self.eps)

        ### Calculate Weights For The Loss ###
        # last dim of num_positives is anyway 1, from the torch.sum above
        weights_0 = torch.gt(num_positives_0.squeeze(-1), 1e-3).float()
        weights_0 = weights_0 / obj_area_0
        weights_1 = torch.gt(num_positives_1.squeeze(-1), 1e-3).float()
        weights_1 = weights_1 / obj_area_1

        ### Arrange Logits ###
        logits_abaa = torch.cat([logits_ab, logits_aa], dim=2)
        logits_babb = torch.cat([logits_ba, logits_bb], dim=2)
        logits_abaa = logits_abaa.view(b, m, -1)
        logits_babb = logits_babb.view(b, m, -1)

        ### Derive Cross Entropy Loss ###
        # targets/labels are a weighted float tensor of same shape as logits,
        # which is why we can't use F.cross_entropy (expects integer targets)
        loss_a = _torch_manual_cross_entropy(labels_0, logits_abaa, weights_0)
        loss_b = _torch_manual_cross_entropy(labels_1, logits_babb, weights_1)
        return loss_a + loss_b


def _same_mask(mask0: Tensor, mask1: Tensor) -> Tensor:
"""Find equal masks/regions across views of the same image.
Args:
mask0: Indices corresponding to the sampled masks for the first view,
an integer tensor of shape :math:`(B, M)` with (possibly repeated)
indices in the range :math:`[0, N)`.
mask1: Indices corresponding to the sampled masks for the second view,
an integer tensor of shape (B, M) with (possibly repeated) indices in
the range :math:`[0, N)`.
Returns:
Tensor: A float tensor of shape :math:`(B, M, 1, M)` where the first :math:`M`
dimensions is for the regions/masks of the first view and the last :math:`M`
dimensions is for the regions/masks of the second view. For every sample
:math:`k` in the batch (separately), the tensor is effectively a 2D index
matrix where the entry :math:`(k, i, :, j)` is 1 if the masks :math:`mask0(k, i)`
and :math:`mask1(k, j)'` are the same and 0 otherwise.
"""
return (mask0[:, :, None] == mask1[:, None, :]).float()[:, :, None, :]


def _torch_manual_cross_entropy(
labels: Tensor, logits: Tensor, weight: Tensor
) -> Tensor:
ce = -weight * torch.sum(labels * F.log_softmax(logits, dim=-1), dim=-1)
return torch.mean(ce)
Loading

0 comments on commit 0dcf780

Please sign in to comment.