Skip to content

Commit

Permalink
ultralytics 8.3.47 fix softmax and Classify head commonality (u…
Browse files Browse the repository at this point in the history
…ltralytics#18085)

Signed-off-by: Glenn Jocher <[email protected]>
Co-authored-by: UltralyticsAssistant <[email protected]>
Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
3 people authored Dec 7, 2024
1 parent 95a9796 commit 3d52c4f
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = "8.3.44"
__version__ = "8.3.47"

import os

Expand Down
4 changes: 3 additions & 1 deletion ultralytics/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
from ultralytics.utils import (
ARM64,
Expand Down Expand Up @@ -287,6 +287,8 @@ def __call__(self, model=None) -> str:

model = FXModel(model)
for m in model.modules():
if isinstance(m, Classify):
m.export = True
if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB
m.dynamic = self.args.dynamic
m.export = True
Expand Down
3 changes: 2 additions & 1 deletion ultralytics/models/yolo/classify/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def postprocess(self, preds, img, orig_imgs):
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

preds = preds[0] if isinstance(preds, (list, tuple)) else preds
return [
Results(orig_img, path=img_path, names=self.model.names, probs=pred.softmax(0))
Results(orig_img, path=img_path, names=self.model.names, probs=pred)
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
]
4 changes: 4 additions & 0 deletions ultralytics/models/yolo/classify/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def finalize_metrics(self, *args, **kwargs):
self.metrics.confusion_matrix = self.confusion_matrix
self.metrics.save_dir = self.save_dir

def postprocess(self, preds):
"""Preprocesses the classification predictions."""
return preds[0] if isinstance(preds, (list, tuple)) else preds

def get_stats(self):
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
self.metrics.process(self.targets, self.pred)
Expand Down
7 changes: 6 additions & 1 deletion ultralytics/nn/modules/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ def kpts_decode(self, bs, kpts):
class Classify(nn.Module):
"""YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""

export = False # export mode

def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
"""Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
super().__init__()
Expand All @@ -296,7 +298,10 @@ def forward(self, x):
if isinstance(x, list):
x = torch.cat(x, 1)
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
return x
if self.training:
return x
y = x.softmax(1) # get final output
return y if self.export else (y, x)


class WorldDetect(Detect):
Expand Down
1 change: 1 addition & 0 deletions ultralytics/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ class v8ClassificationLoss:

def __call__(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
loss_items = loss.detach()
return loss, loss_items
Expand Down

0 comments on commit 3d52c4f

Please sign in to comment.