From 84d5745d28702610faa121efa9203c9ca7c2d305 Mon Sep 17 00:00:00 2001
From: Adel Memariani
Date: Thu, 9 Jan 2025 15:39:52 +0100
Subject: [PATCH] Add stability to Label Relaxation loss

---
 .../__pycache__/__init__.cpython-310.pyc      | Bin 170 -> 0 bytes
 .../label_relaxation.cpython-310.pyc          | Bin 1503 -> 0 bytes
 dicee/losses/label_relaxation.py              |  43 +++++++++++++-----
 3 files changed, 31 insertions(+), 12 deletions(-)
 delete mode 100644 dicee/losses/__pycache__/__init__.cpython-310.pyc
 delete mode 100644 dicee/losses/__pycache__/label_relaxation.cpython-310.pyc

diff --git a/dicee/losses/__pycache__/__init__.cpython-310.pyc b/dicee/losses/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 553edb9be18783c9196001b53fb30edcbc1ddfe3..0000000000000000000000000000000000000000
Binary files a/dicee/losses/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/dicee/losses/__pycache__/label_relaxation.cpython-310.pyc b/dicee/losses/__pycache__/label_relaxation.cpython-310.pyc
deleted file mode 100644
Binary files a/dicee/losses/__pycache__/label_relaxation.cpython-310.pyc and /dev/null differ
diff --git a/dicee/losses/label_relaxation.py b/dicee/losses/label_relaxation.py
index 91597ec4..ed6ab19e 100644
--- a/dicee/losses/label_relaxation.py
+++ b/dicee/losses/label_relaxation.py
@@ -2,6 +2,7 @@
 from torch import nn
 from torch.nn import functional as F
 
+
 class LabelRelaxationLoss(nn.Module):
     def __init__(self, alpha=0.1, dim=-1, logits_provided=True, one_hot_encode_trgts=False, num_classes=-1):
         super(LabelRelaxationLoss, self).__init__()
@@ -11,25 +12,43 @@ def __init__(self, alpha=0.1, dim=-1, logits_provided=True, one_hot_encode_trgts
         self.logits_provided = logits_provided
         self.one_hot_encode_trgts = one_hot_encode_trgts
         self.num_classes = num_classes
+        self.eps = 1e-9  # Small epsilon for numerical stability
 
     def forward(self, pred, target):
         if self.logits_provided:
-            pred = pred.softmax(dim=self.dim)
+            pred = F.softmax(pred, dim=self.dim)
+
+        # Ensure predictions are stable (minimal clamping)
+        pred = torch.clamp(pred, min=self.eps, max=1.0)
 
-        # Apply one-hot encoding to targets
         if self.one_hot_encode_trgts:
-            target = F.one_hot(target, num_classes=self.num_classes)
+            target = F.one_hot(target, num_classes=self.num_classes).float()
 
-        # Construct credal set
         with torch.no_grad():
-            sum_y_hat_prime = torch.sum((torch.ones_like(target) - target) * pred, dim=-1)
-            pred_hat = self.alpha * pred / torch.unsqueeze(sum_y_hat_prime, dim=-1)
-            target_credal = torch.where(target > self.gz_threshold, torch.ones_like(target) - self.alpha, pred_hat)
+            sum_y_hat_prime = torch.sum((1.0 - target) * pred, dim=-1)
+
+            sum_y_hat_prime = sum_y_hat_prime + self.eps
+
+            pred_hat = self.alpha * pred / sum_y_hat_prime.unsqueeze(dim=-1)
+
+            target_credal = torch.where(
+                target > self.gz_threshold,
+                1.0 - self.alpha,
+                pred_hat
+            )
+
+            target_credal = torch.clamp(target_credal, min=self.eps, max=1.0)
+
+        divergence = F.kl_div(pred.log(), target_credal, log_target=False, reduction="none")
+        divergence = torch.sum(divergence, dim=-1)
+
+        pred_alignment = torch.sum(pred * target, dim=-1)
+
+        mask = pred_alignment > (1.0 - self.alpha)
+        mask = mask.float()
+
+        result = (1.0 - mask) * divergence
 
-        # Calculate divergence
-        divergence = torch.sum(F.kl_div(pred.log(), target_credal, log_target=False, reduction="none"), dim=-1)
+        return torch.mean(result)
 
-        pred = torch.sum(pred * target, dim=-1)
-        result = torch.where(torch.gt(pred, 1. - self.alpha), torch.zeros_like(divergence), divergence)
-        return torch.mean(result)
\ No newline at end of file
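
--
A minimal smoke test for the patched loss (illustration only, not part of the
commit): the import path and constructor arguments follow
dicee/losses/label_relaxation.py above, while the batch size, class count, and
target indices are arbitrary.

    import torch
    from dicee.losses.label_relaxation import LabelRelaxationLoss

    # Logits for a batch of 4 examples over 5 classes.
    logits = torch.randn(4, 5, requires_grad=True)
    targets = torch.tensor([0, 2, 1, 4])  # class indices

    # one_hot_encode_trgts=True one-hot encodes the indices inside forward();
    # logits_provided=True (the default) applies softmax to the logits first.
    loss_fn = LabelRelaxationLoss(alpha=0.1, one_hot_encode_trgts=True, num_classes=5)

    loss = loss_fn(logits, targets)
    loss.backward()  # gradients flow through pred; credal targets are built under no_grad

    # With the epsilon guards and clamping above, the loss should stay finite.
    assert torch.isfinite(loss)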