From 0dcf780f4caa27c20d470d5b27f75559788876c5 Mon Sep 17 00:00:00 2001 From: Lionel Peer Date: Mon, 6 Jan 2025 14:43:05 +0100 Subject: [PATCH] Add DetConSLoss and DetConBLoss (#1771) * allow access to v2 transforms availability from anywhere * import unnecessary after exception Co-authored-by: guarin <43336610+guarin@users.noreply.github.com> * reformat * add implementation of AddGridTransform * make AddGridTransform importable * add tests for AddGridTransform * reformat * enhance docstring * fix typing issues * fix import when transforms.v2 not available * add additional type ignore for 3.7 compatibility * reformat * add transform to docs * change header * add explanation on data structures * use kw args * add assertion for mask dimension to be geq 2 * remove unnecessary fixtures * make argument order consistent * add DetCon SingleView and MultiView transforms * add MultiViewTransform BaseClass for v2 transforms * add tests for DetCon transform * export all newly added transforms * add DetCon single view and DetCon multi view transforms * add torchvision transforms v2 compatible MultiViewTransforms * make newly added transforms public * remove unnecessary fixtures * add tests for DetCon transform * remove wrongfully added files * add DetCon transform and MultiView transforms for v2 to docs * fix docs references * fix import issues for minimal dependencies * fixing code format * add test for multiviewtransformv2 * fix testing of multiview * use singular AddGridTransforms * consistent naming to DetConS * adjust docstring reference numbering * name refactoring * start detconloss implementation * implement detconloss * test detconloss * make detconloss public * add detconloss to docs * Update lightly/loss/detcon_loss.py Co-authored-by: guarin <43336610+guarin@users.noreply.github.com> * initial small fixes * add comments and avoid 0 division * remove labels_ext * remove labels_ext * revert normalization; some formatting * complete rewrite of tests * fix typing issues * remove unused imports * Update lightly/loss/detcon_loss.py Co-authored-by: guarin <43336610+guarin@users.noreply.github.com> * move to f-strings * create test classes * squeeze instead of 0 indexing * formatting * few more comments on tensor shapes * additional comments on tensor shapes --------- Co-authored-by: guarin <43336610+guarin@users.noreply.github.com> --- docs/source/lightly.loss.rst | 6 + lightly/loss/__init__.py | 1 + lightly/loss/detcon_loss.py | 316 +++++++++++++++++++++++++++++++++ tests/loss/test_detcon_loss.py | 285 +++++++++++++++++++++++++++++ 4 files changed, 608 insertions(+) create mode 100644 lightly/loss/detcon_loss.py create mode 100644 tests/loss/test_detcon_loss.py diff --git a/docs/source/lightly.loss.rst b/docs/source/lightly.loss.rst index 9396a44a5..2d2ce0ff1 100644 --- a/docs/source/lightly.loss.rst +++ b/docs/source/lightly.loss.rst @@ -14,6 +14,12 @@ lightly.loss .. autoclass:: lightly.loss.dcl_loss.DCLWLoss :members: +.. autoclass:: lightly.loss.detcon_loss.DetConBLoss + :members: + +.. autoclass:: lightly.loss.detcon_loss.DetConSLoss + :members: + .. autoclass:: lightly.loss.dino_loss.DINOLoss :members: diff --git a/lightly/loss/__init__.py b/lightly/loss/__init__.py index 2eb138b32..663d2688c 100644 --- a/lightly/loss/__init__.py +++ b/lightly/loss/__init__.py @@ -4,6 +4,7 @@ # All Rights Reserved from lightly.loss.barlow_twins_loss import BarlowTwinsLoss from lightly.loss.dcl_loss import DCLLoss, DCLWLoss +from lightly.loss.detcon_loss import DetConBLoss, DetConSLoss from lightly.loss.dino_loss import DINOLoss from lightly.loss.emp_ssl_loss import EMPSSLLoss from lightly.loss.ibot_loss import IBOTPatchLoss diff --git a/lightly/loss/detcon_loss.py b/lightly/loss/detcon_loss.py new file mode 100644 index 000000000..b96019976 --- /dev/null +++ b/lightly/loss/detcon_loss.py @@ -0,0 +1,316 @@ +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import distributed as dist +from torch.nn import Module + + +class DetConSLoss(Module): + """Implementation of the DetConS loss. [2]_ + + The inputs are two-fold: + + - Two latent representations of the same batch under different views, as generated\ + by SimCLR [3]_ and additional pooling over the regions of the segmentation. + - Two integer masks that indicate the regions of the segmentation that were used\ + for pooling. + + For calculating the contrastive loss, regions under the same mask in the same image + (under a different view) are considered as positives and everything else as + negatives. With :math:`v_m` and :math:`v_{m'}'` being the pooled feature maps under + mask :math:`m` and :math:`m'` respectively, and additionally scaled to a norm of + :math:`\\frac{1}{\\sqrt{\\tau}}`, the formula for the contrastive loss is + + .. math:: + \\mathcal{L} = \\sum_{m}\\sum_{m'} \\mathbb{1}_{m, m'} \\left[ - \\log\ + \\frac{\\exp(v_m \\cdot v_{m'}')}{\\exp(v_m \\cdot v_{m'}') +\ + \\sum_{n}\\exp (v_m \\cdot v_{m'}')} \\right] + + where :math:`\\mathbb{1}_{m, m'}` is 1 if the masks are the same and 0 otherwise. + + References: + .. [2] DetCon https://arxiv.org/abs/2103.10957 + .. [3] SimCLR https://arxiv.org/abs/2002.05709 + + Attributes: + temperature: + The temperature :math:`\\tau` in the contrastive loss. + gather_distributed: + If True, the similarity matrix is gathered across all GPUs before the loss + is calculated. Else, the loss is calculated on each GPU separately. + """ + + def __init__( + self, temperature: float = 0.1, gather_distributed: bool = True + ) -> None: + super().__init__() + self.detconbloss = DetConBLoss( + temperature=temperature, gather_distributed=gather_distributed + ) + + def forward( + self, view0: Tensor, view1: Tensor, mask_view0: Tensor, mask_view1: Tensor + ) -> Tensor: + """Calculate the contrastive loss under the same mask in the same image. + + The tensor shapes and value ranges are given by variables :math:`B, M, D, N`, + where :math:`B` is the batch size, :math:`M` is the sampled number of image + masks / regions, :math:`D` is the embedding size and :math:`N` is the total + number of masks. + + Args: + view0: Mask-pooled output for the first view, a float tensor of shape + :math:`(B, M, D)`. + pred_view1: Mask-pooled output for the second view, a float tensor of shape + :math:`(B, M, D)`. + mask_view0: Indices corresponding to the sampled masks for the first view, + an integer tensor of shape :math:`(B, M)` with (possibly repeated) + indices in the range :math:`[0, N)`. + mask_view1: Indices corresponding to the sampled masks for the second view, + an integer tensor of shape (B, M) with (possibly repeated) indices in + the range :math:`[0, N)`. + + Returns: + A scalar float tensor containing the contrastive loss. + """ + loss: Tensor = self.detconbloss( + view0, view1, view0, view1, mask_view0, mask_view1 + ) + return loss + + +class DetConBLoss(Module): + """Implementation of the DetConB loss. [0]_ + + The inputs are three-fold: + + - Two latent representations of the same batch under different views, as generated\ + by BYOL's [1]_ prediction branch and additional pooling over the regions of\ + the segmentation. + - Two latent representations of the same batch under different views, as generated\ + by BYOL's target branch and additional pooling over the regions of the\ + segmentation. + - Two integer masks that indicate the regions of the segmentation that were used\ + for pooling. + + For calculating the contrastive loss, regions under the same mask in the same image + (under a different view) are considered as positives and everything else as + negatives. With :math:`v_m` and :math:`v_{m'}'` being the pooled feature maps under + mask :math:`m` and :math:`m'` respectively, and additionally scaled to a norm of + :math:`\\frac{1}{\\sqrt{\\tau}}`, the formula for the contrastive loss is + + .. math:: + \\mathcal{L} = \\sum_{m}\\sum_{m'} \\mathbb{1}_{m, m'} \\left[ - \\log \ + \\frac{\\exp(v_m \\cdot v_{m'}')}{\\exp(v_m \\cdot v_{m'}') + \\sum_{n}\\exp \ + (v_m \\cdot v_{m'}')} \\right] + + where :math:`\\mathbb{1}_{m, m'}` is 1 if the masks are the same and 0 otherwise. + Since :math:`v_m` and :math:`v_{m'}'` stem from different branches, the loss is + symmetrized by also calculating the loss with the roles of the views reversed. [1]_ + + References: + .. [0] DetCon https://arxiv.org/abs/2103.10957 + .. [1] BYOL https://arxiv.org/abs/2006.07733 + + Attributes: + temperature: + The temperature :math:`\\tau` in the contrastive loss. + gather_distributed: + If True, the similarity matrix is gathered across all GPUs before the loss + is calculated. Else, the loss is calculated on each GPU separately. + """ + + def __init__( + self, temperature: float = 0.1, gather_distributed: bool = True + ) -> None: + super().__init__() + self.eps = 1e-8 + self.temperature = temperature + self.gather_distributed = gather_distributed + if abs(self.temperature) < self.eps: + raise ValueError(f"Illegal temperature: abs({self.temperature}) < 1e-8") + if self.gather_distributed and not dist.is_available(): + raise ValueError( + "gather_distributed is True but torch.distributed is not available. " + "Please set gather_distributed=False or install a torch version with " + "distributed support." + ) + + def forward( + self, + pred_view0: Tensor, + pred_view1: Tensor, + target_view0: Tensor, + target_view1: Tensor, + mask_view0: Tensor, + mask_view1: Tensor, + ) -> Tensor: + """Calculate the contrastive loss under the same mask in the same image. + + The tensor shapes and value ranges are given by variables :math:`B, M, D, N`, + where :math:`B` is the batch size, :math:`M` is the sampled number of image + masks / regions, :math:`D` is the embedding size and :math:`N` is the total + number of masks. + + Args: + pred_view0: Mask-pooled output of the prediction branch for the first view, + a float tensor of shape :math:`(B, M, D)`. + pred_view1: Mask-pooled output of the prediction branch for the second view, + a float tensor of shape :math:`(B, M, D)`. + target_view0: Mask-pooled output of the target branch for the first view, + a float tensor of shape :math:`(B, M, D)`. + target_view1: Mask-pooled output of the target branch for the second view, + a float tensor of shape :math:`(B, M, D)`. + mask_view0: Indices corresponding to the sampled masks for the first view, + an integer tensor of shape :math:`(B, M)` with (possibly repeated) + indices in the range :math:`[0, N)`. + mask_view1: Indices corresponding to the sampled masks for the second view, + an integer tensor of shape (B, M) with (possibly repeated) indices in + the range :math:`[0, N)`. + + Returns: + A scalar float tensor containing the contrastive loss. + """ + b, m, d = pred_view0.size() + infinity_proxy = 1e9 + + # gather distributed + if not self.gather_distributed or dist.get_world_size() < 2: + target_view0_large = target_view0 + target_view1_large = target_view1 + labels_local = torch.eye(b, device=pred_view0.device) + enlarged_b = b + else: + target_view0_large = torch.cat(dist.gather(target_view0), dim=0) + target_view1_large = torch.cat(dist.gather(target_view1), dim=0) + replica_id = dist.get_rank() + labels_idx = torch.arange(b, device=pred_view0.device) + replica_id * b + enlarged_b = b * dist.get_world_size() + labels_local = F.one_hot(labels_idx, num_classes=enlarged_b) + + # normalize + pred_view0 = F.normalize(pred_view0, p=2, dim=2) + pred_view1 = F.normalize(pred_view1, p=2, dim=2) + target_view0_large = F.normalize(target_view0_large, p=2, dim=2) + target_view1_large = F.normalize(target_view1_large, p=2, dim=2) + + ### Expand Labels ### + # labels_local at this point only points towards the diagonal of the batch, i.e. + # indicates to compare between the same samples across views. + labels_local = labels_local[:, None, :, None] # (b, 1, b * world_size, 1) + + ### Calculate Similarity Matrices ### + # tensors of shape (b, m, b * world_size, m), indicating similarities between regions across + # views and samples in the batch + logits_aa = ( + torch.einsum("abk,uvk->abuv", pred_view0, target_view0_large) + / self.temperature + ) + logits_bb = ( + torch.einsum("abk,uvk->abuv", pred_view1, target_view1_large) + / self.temperature + ) + logits_ab = ( + torch.einsum("abk,uvk->abuv", pred_view0, target_view1_large) + / self.temperature + ) + logits_ba = ( + torch.einsum("abk,uvk->abuv", pred_view1, target_view0_large) + / self.temperature + ) + + ### Find Corresponding Regions Across Views ### + same_mask_aa = _same_mask(mask_view0, mask_view0) + same_mask_bb = _same_mask(mask_view1, mask_view1) + same_mask_ab = _same_mask(mask_view0, mask_view1) + same_mask_ba = _same_mask(mask_view1, mask_view0) + + ### Remove Similarities Between Corresponding Views But Different Regions ### + # labels_local initially compared all features across views, but we only want to + # compare the same regions across views. + # (b, 1, b * world_size, 1) * (b, m, 1, m) -> (b, m, b * world_size, m) + labels_aa = labels_local * same_mask_aa + labels_bb = labels_local * same_mask_bb + labels_ab = labels_local * same_mask_ab + labels_ba = labels_local * same_mask_ba + + ### Remove Logits And Lables Between The Same View ### + logits_aa = logits_aa - infinity_proxy * labels_aa + logits_bb = logits_bb - infinity_proxy * labels_bb + labels_aa = 0.0 * labels_aa + labels_bb = 0.0 * labels_bb + + ### Arrange Labels ### + # (b, m, b * world_size * 2, m) + labels_abaa = torch.cat([labels_ab, labels_aa], dim=2) + labels_babb = torch.cat([labels_ba, labels_bb], dim=2) + # (b, m, b * world_size * 2 * m) + labels_0 = labels_abaa.view(b, m, -1) + labels_1 = labels_babb.view(b, m, -1) + + ### Count Number of Positives For Every Region (per sample) ### + num_positives_0 = torch.sum(labels_0, dim=-1, keepdim=True) + num_positives_1 = torch.sum(labels_1, dim=-1, keepdim=True) + + ### Scale The Labels By The Number of Positives To Weight Loss Value ### + labels_0 = labels_0 / torch.maximum(num_positives_0, torch.tensor(1)) + labels_1 = labels_1 / torch.maximum(num_positives_1, torch.tensor(1)) + + ### Count How Many Overlapping Regions We Have Across Views ### + obj_area_0 = torch.sum(same_mask_aa, dim=(2, 3)) + obj_area_1 = torch.sum(same_mask_bb, dim=(2, 3)) + # make sure we don't divide by zero + obj_area_0 = torch.maximum(obj_area_0, torch.tensor(self.eps)) + obj_area_1 = torch.maximum(obj_area_1, torch.tensor(self.eps)) + + ### Calculate Weights For The Loss ### + # last dim of num_positives is anyway 1, from the torch.sum above + weights_0 = torch.gt(num_positives_0.squeeze(-1), 1e-3).float() + weights_0 = weights_0 / obj_area_0 + weights_1 = torch.gt(num_positives_1.squeeze(-1), 1e-3).float() + weights_1 = weights_1 / obj_area_1 + + ### Arrange Logits ### + logits_abaa = torch.cat([logits_ab, logits_aa], dim=2) + logits_babb = torch.cat([logits_ba, logits_bb], dim=2) + logits_abaa = logits_abaa.view(b, m, -1) + logits_babb = logits_babb.view(b, m, -1) + + # return labels_0, logits_abaa, weights_0, labels_1, logits_babb, weights_1 + + ### Derive Cross Entropy Loss ### + # targets/labels are are a weighted float tensor of same shape as logits, + # which is why we can't use F.cross_entropy (expects integer targets) + loss_a = _torch_manual_cross_entropy(labels_0, logits_abaa, weights_0) + loss_b = _torch_manual_cross_entropy(labels_1, logits_babb, weights_1) + loss = loss_a + loss_b + return loss + + +def _same_mask(mask0: Tensor, mask1: Tensor) -> Tensor: + """Find equal masks/regions across views of the same image. + + Args: + mask0: Indices corresponding to the sampled masks for the first view, + an integer tensor of shape :math:`(B, M)` with (possibly repeated) + indices in the range :math:`[0, N)`. + mask1: Indices corresponding to the sampled masks for the second view, + an integer tensor of shape (B, M) with (possibly repeated) indices in + the range :math:`[0, N)`. + + Returns: + Tensor: A float tensor of shape :math:`(B, M, 1, M)` where the first :math:`M` + dimensions is for the regions/masks of the first view and the last :math:`M` + dimensions is for the regions/masks of the second view. For every sample + :math:`k` in the batch (separately), the tensor is effectively a 2D index + matrix where the entry :math:`(k, i, :, j)` is 1 if the masks :math:`mask0(k, i)` + and :math:`mask1(k, j)'` are the same and 0 otherwise. + """ + return (mask0[:, :, None] == mask1[:, None, :]).float()[:, :, None, :] + + +def _torch_manual_cross_entropy( + labels: Tensor, logits: Tensor, weight: Tensor +) -> Tensor: + ce = -weight * torch.sum(labels * F.log_softmax(logits, dim=-1), dim=-1) + return torch.mean(ce) diff --git a/tests/loss/test_detcon_loss.py b/tests/loss/test_detcon_loss.py new file mode 100644 index 000000000..a3cf25cdd --- /dev/null +++ b/tests/loss/test_detcon_loss.py @@ -0,0 +1,285 @@ +import numpy as np +import numpy.linalg as npl +import pytest +import scipy.special as sps +import torch +from pytest_mock import MockerFixture +from torch import Tensor +from torch import distributed as dist + +from lightly.loss import DetConBLoss, DetConSLoss + + +class TestDetConSLoss: + def test_DetConSLoss(self) -> None: + loss_fn = DetConSLoss(gather_distributed=False) + pred1 = torch.randn(4, 24, 32) + pred2 = torch.randn(4, 24, 32) + mask1 = torch.randint(0, 16, (4, 24)) + mask2 = torch.randint(0, 16, (4, 24)) + loss = loss_fn(pred1, pred2, mask1, mask2) + assert loss.shape == torch.Size([]) + + +class TestDetConBLoss: + @pytest.mark.parametrize( + "temperature,batch_size,num_classes,sampled_masks,emb_dim", + [ + [0.1, 4, 16, 24, 32], + [0.5, 8, 32, 16, 64], + [1.0, 16, 64, 32, 128], + ], + ) + def test_DetConBLoss_against_original( + self, + temperature: float, + batch_size: int, + num_classes: int, + sampled_masks: int, + emb_dim: int, + ) -> None: + pred1 = torch.randn(batch_size, sampled_masks, emb_dim) + pred2 = torch.randn(batch_size, sampled_masks, emb_dim) + target1 = torch.randn(batch_size, sampled_masks, emb_dim) + target2 = torch.randn(batch_size, sampled_masks, emb_dim) + + mask1 = torch.randint(0, num_classes, (batch_size, sampled_masks)) + mask2 = torch.randint(0, num_classes, (batch_size, sampled_masks)) + + orig_loss = byol_nce_detcon( + pred1.numpy(), + pred2.numpy(), + target1.numpy(), + target2.numpy(), + mask1.numpy(), + mask2.numpy(), + mask1.numpy(), + mask2.numpy(), + temperature=temperature, + ) + + loss_fn = DetConBLoss(temperature=temperature, gather_distributed=False) + loss = loss_fn(pred1, pred2, target1, target2, mask1, mask2) + + assert torch.allclose( + loss, torch.tensor(orig_loss, dtype=torch.float32), atol=1e-4 + ) + + @pytest.mark.parametrize( + "temperature,batch_size,num_classes,sampled_masks,emb_dim,world_size", + [ + [0.1, 4, 16, 24, 32, 1], + [0.5, 8, 32, 16, 64, 2], + [1.0, 16, 64, 32, 128, 4], + ], + ) + def test_DetConBLoss_distributed_against_original( + self, + mocker: MockerFixture, + temperature: float, + batch_size: int, + num_classes: int, + sampled_masks: int, + emb_dim: int, + world_size: int, + ) -> None: + tensors = [ + { + "pred1": torch.randn(batch_size, sampled_masks, emb_dim), + "pred2": torch.randn(batch_size, sampled_masks, emb_dim), + "target1": torch.randn(batch_size, sampled_masks, emb_dim), + "target2": torch.randn(batch_size, sampled_masks, emb_dim), + "mask1": torch.randint(0, num_classes, (batch_size, sampled_masks)), + "mask2": torch.randint(0, num_classes, (batch_size, sampled_masks)), + } + for _ in range(world_size) + ] + + # calculate non-distributed by packing all batches + loss_nondist = byol_nce_detcon( + torch.cat([t["pred1"] for t in tensors], dim=0).numpy(), + torch.cat([t["pred2"] for t in tensors], dim=0).numpy(), + torch.cat([t["target1"] for t in tensors], dim=0).numpy(), + torch.cat([t["target2"] for t in tensors], dim=0).numpy(), + torch.cat([t["mask1"] for t in tensors], dim=0).numpy(), + torch.cat([t["mask2"] for t in tensors], dim=0).numpy(), + torch.cat([t["mask1"] for t in tensors], dim=0).numpy(), + torch.cat([t["mask2"] for t in tensors], dim=0).numpy(), + temperature=temperature, + ) + + mock_is_available = mocker.patch.object(dist, "is_available", return_value=True) + mock_get_world_size = mocker.patch.object( + dist, "get_world_size", return_value=world_size + ) + + loss_fn = DetConBLoss(temperature=temperature, gather_distributed=True) + + total_loss: Tensor = torch.tensor(0.0) + for rank in range(world_size): + mock_get_rank = mocker.patch.object(dist, "get_rank", return_value=rank) + mock_gather = mocker.patch.object( + dist, + "gather", + side_effect=[ + [t["target1"] for t in tensors], + [t["target2"] for t in tensors], + ], + ) + loss_val = loss_fn( + tensors[rank]["pred1"], + tensors[rank]["pred2"], + tensors[rank]["target1"], + tensors[rank]["target2"], + tensors[rank]["mask1"], + tensors[rank]["mask2"], + ) + total_loss += loss_val + total_loss /= world_size + + assert torch.allclose( + total_loss, torch.tensor(loss_nondist, dtype=torch.float32), atol=1e-4 + ) + + +### Original JAX/haiku Implementation with Minimal Changes ### +# Source: https://github.com/google-deepmind/detcon/blob/main/utils/losses.py +# +# Changes: +# 1. change any jnp function to np +# 2. remove distributed implementation (doesn't work anyway with numpy) +# 3. use scipy.special.log_softmax instead of jax.nn.log_softmax +# 4. commented unused +# 5. change hk.one_hot to np.eye +# 6. changer helper function (norm) to numpy.linalg.norm +def byol_nce_detcon( # type: ignore + pred1, + pred2, + target1, + target2, + pind1, + pind2, + tind1, + tind2, + temperature=0.1, + use_replicator_loss=True, + local_negatives=True, +): + """Compute the NCE scores from pairs of predictions and targets. + + This implements the batched form of the loss described in + Section 3.1, Equation 3 in https://arxiv.org/pdf/2103.10957.pdf. + + Args: + pred1 (jnp.array): the prediction from first view. + pred2 (jnp.array): the prediction from second view. + target1 (jnp.array): the projection from first view. + target2 (jnp.array): the projection from second view. + pind1 (jnp.array): mask indices for first view's prediction. + pind2 (jnp.array): mask indices for second view's prediction. + tind1 (jnp.array): mask indices for first view's projection. + tind2 (jnp.array): mask indices for second view's projection. + temperature (float): the temperature to use for the NCE loss. + use_replicator_loss (bool): use cross-replica samples. + local_negatives (bool): whether to include local negatives + + Returns: + A single scalar loss for the XT-NCE objective. + + """ + batch_size = pred1.shape[0] + num_rois = pred1.shape[1] + feature_dim = pred1.shape[-1] + infinity_proxy = 1e9 # Used for masks to proxy a very large number. + + def make_same_obj(ind_0, ind_1): # type: ignore + same_obj = np.equal( + ind_0.reshape([batch_size, num_rois, 1]), + ind_1.reshape([batch_size, 1, num_rois]), + ) + same_obj2 = np.expand_dims(same_obj.astype(np.float32), axis=2) + return same_obj2 + + same_obj_aa = make_same_obj(pind1, tind1) + same_obj_ab = make_same_obj(pind1, tind2) + same_obj_ba = make_same_obj(pind2, tind1) + same_obj_bb = make_same_obj(pind2, tind2) + + # L2 normalize the tensors to use for the cosine-similarity + pred1 = pred1 / npl.norm(pred1, ord=2, axis=-1, keepdims=True) + pred2 = pred2 / npl.norm(pred2, ord=2, axis=-1, keepdims=True) + target1 = target1 / npl.norm(target1, ord=2, axis=-1, keepdims=True) + target2 = target2 / npl.norm(target2, ord=2, axis=-1, keepdims=True) + + target1_large = target1 + target2_large = target2 + labels_local = np.eye(batch_size) + # labels_ext = hk.one_hot(np.arange(batch_size), batch_size * 2) + + labels_local = np.expand_dims(np.expand_dims(labels_local, axis=2), axis=1) + # labels_ext = np.expand_dims(np.expand_dims(labels_ext, axis=2), axis=1) + + # Do our matmuls and mask out appropriately. + logits_aa = np.einsum("abk,uvk->abuv", pred1, target1_large) / temperature + logits_bb = np.einsum("abk,uvk->abuv", pred2, target2_large) / temperature + logits_ab = np.einsum("abk,uvk->abuv", pred1, target2_large) / temperature + logits_ba = np.einsum("abk,uvk->abuv", pred2, target1_large) / temperature + + labels_aa = labels_local * same_obj_aa + labels_ab = labels_local * same_obj_ab + labels_ba = labels_local * same_obj_ba + labels_bb = labels_local * same_obj_bb + + logits_aa = logits_aa - infinity_proxy * labels_local * same_obj_aa + logits_bb = logits_bb - infinity_proxy * labels_local * same_obj_bb + labels_aa = 0.0 * labels_aa + labels_bb = 0.0 * labels_bb + if not local_negatives: + logits_aa = logits_aa - infinity_proxy * labels_local * (1 - same_obj_aa) + logits_ab = logits_ab - infinity_proxy * labels_local * (1 - same_obj_ab) + logits_ba = logits_ba - infinity_proxy * labels_local * (1 - same_obj_ba) + logits_bb = logits_bb - infinity_proxy * labels_local * (1 - same_obj_bb) + + labels_abaa = np.concatenate([labels_ab, labels_aa], axis=2) + labels_babb = np.concatenate([labels_ba, labels_bb], axis=2) + + labels_0 = np.reshape(labels_abaa, [batch_size, num_rois, -1]) + labels_1 = np.reshape(labels_babb, [batch_size, num_rois, -1]) + + num_positives_0 = np.sum(labels_0, axis=-1, keepdims=True) + num_positives_1 = np.sum(labels_1, axis=-1, keepdims=True) + + labels_0 = labels_0 / np.maximum(num_positives_0, 1) + labels_1 = labels_1 / np.maximum(num_positives_1, 1) + + obj_area_0 = np.sum(make_same_obj(pind1, pind1), axis=(2, 3)) + obj_area_1 = np.sum(make_same_obj(pind2, pind2), axis=(2, 3)) + + weights_0 = np.greater(num_positives_0[..., 0], 1e-3).astype("float32") + weights_0 = weights_0 / obj_area_0 + weights_1 = np.greater(num_positives_1[..., 0], 1e-3).astype("float32") + weights_1 = weights_1 / obj_area_1 + + logits_abaa = np.concatenate([logits_ab, logits_aa], axis=2) + logits_babb = np.concatenate([logits_ba, logits_bb], axis=2) + + logits_abaa = np.reshape(logits_abaa, [batch_size, num_rois, -1]) + logits_babb = np.reshape(logits_babb, [batch_size, num_rois, -1]) + + # return labels_0, logits_abaa, weights_0, labels_1, logits_babb, weights_1 + + loss_a = manual_cross_entropy(labels_0, logits_abaa, weights_0) + loss_b = manual_cross_entropy(labels_1, logits_babb, weights_1) + loss = loss_a + loss_b + + return loss + + +def manual_cross_entropy( # type: ignore + labels, + logits, + weight, +): + ce = -weight * np.sum(labels * sps.log_softmax(logits, axis=-1), axis=-1) + mean: float = np.mean(ce) + return mean