Add stability to Label Relaxation loss
adelmemariani committed Jan 9, 2025
1 parent 461528d commit 84d5745
Showing 3 changed files with 31 additions and 12 deletions.
Binary file removed dicee/losses/__pycache__/__init__.cpython-310.pyc
43 changes: 31 additions & 12 deletions dicee/losses/label_relaxation.py
@@ -2,6 +2,7 @@
 from torch import nn
 from torch.nn import functional as F
 
+
 class LabelRelaxationLoss(nn.Module):
     def __init__(self, alpha=0.1, dim=-1, logits_provided=True, one_hot_encode_trgts=False, num_classes=-1):
         super(LabelRelaxationLoss, self).__init__()
@@ -11,25 +12,43 @@ def __init__(self, alpha=0.1, dim=-1, logits_provided=True, one_hot_encode_trgts
         self.logits_provided = logits_provided
         self.one_hot_encode_trgts = one_hot_encode_trgts
         self.num_classes = num_classes
+        self.eps = 1e-9  # Small epsilon for numerical stability
 
     def forward(self, pred, target):
         if self.logits_provided:
-            pred = pred.softmax(dim=self.dim)
+            pred = F.softmax(pred, dim=self.dim)
+
+        # Ensure predictions are stable (minimal clamping)
+        pred = torch.clamp(pred, min=self.eps, max=1.0)
 
         # Apply one-hot encoding to targets
         if self.one_hot_encode_trgts:
-            target = F.one_hot(target, num_classes=self.num_classes)
+            target = F.one_hot(target, num_classes=self.num_classes).float()
 
         # Construct credal set
         with torch.no_grad():
-            sum_y_hat_prime = torch.sum((torch.ones_like(target) - target) * pred, dim=-1)
-            pred_hat = self.alpha * pred / torch.unsqueeze(sum_y_hat_prime, dim=-1)
-            target_credal = torch.where(target > self.gz_threshold, torch.ones_like(target) - self.alpha, pred_hat)
+            sum_y_hat_prime = torch.sum((1.0 - target) * pred, dim=-1)
+
+            sum_y_hat_prime = sum_y_hat_prime + self.eps
+
+            pred_hat = self.alpha * pred / sum_y_hat_prime.unsqueeze(dim=-1)
+
+            target_credal = torch.where(
+                target > self.gz_threshold,
+                1.0 - self.alpha,
+                pred_hat
+            )
+
+            target_credal = torch.clamp(target_credal, min=self.eps, max=1.0)
 
-        # Calculate divergence
-        divergence = torch.sum(F.kl_div(pred.log(), target_credal, log_target=False, reduction="none"), dim=-1)
+        divergence = F.kl_div(pred.log(), target_credal, log_target=False, reduction="none")
+        divergence = torch.sum(divergence, dim=-1)
 
-        pred = torch.sum(pred * target, dim=-1)
+        pred_alignment = torch.sum(pred * target, dim=-1)
+
+        mask = pred_alignment > (1.0 - self.alpha)
+        mask = mask.float()
 
-        result = torch.where(torch.gt(pred, 1. - self.alpha), torch.zeros_like(divergence), divergence)
-        return torch.mean(result)
+        result = (1.0 - mask) * divergence
+
+        return torch.mean(result)
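For context on the change: once a softmax probability underflows to exactly zero, pred.log() returns -inf and the KL divergence becomes non-finite, and the credal-set denominator sum_y_hat_prime is zero whenever a prediction puts all of its mass on the target class. The added eps and clamping guard both cases. A minimal standalone sketch of the first failure mode (illustrative tensors, not repository code):

import torch
from torch.nn import functional as F

# A fully saturated prediction: one class holds all the probability mass.
pred = torch.tensor([1.0, 0.0, 0.0])
target_credal = torch.tensor([0.0, 0.5, 0.5])

# log(0) = -inf, so the divergence blows up.
kl = F.kl_div(pred.log(), target_credal, log_target=False, reduction="sum")
print(kl)  # tensor(inf)

# Clamping as in the commit keeps the divergence finite.
eps = 1e-9
pred_safe = torch.clamp(pred, min=eps, max=1.0)
kl_safe = F.kl_div(pred_safe.log(), target_credal, log_target=False, reduction="sum")
print(kl_safe)  # ~20.03, finite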

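And a hedged usage sketch of the updated loss; the batch size, class count, and constructor flags below are assumptions for illustration, not taken from the repository's training code:

import torch
from dicee.losses.label_relaxation import LabelRelaxationLoss

# Hypothetical batch: 4 examples, 10 classes, raw logits from a model head.
logits = torch.randn(4, 10, requires_grad=True)
labels = torch.randint(0, 10, (4,))

# logits_provided=True applies softmax inside forward();
# one_hot_encode_trgts=True turns integer labels into one-hot vectors.
loss_fn = LabelRelaxationLoss(alpha=0.1, logits_provided=True,
                              one_hot_encode_trgts=True, num_classes=10)

loss = loss_fn(logits, labels)
loss.backward()  # gradients flow through the clamped softmax probabilities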