Skip to content

Commit

Permalink
ultralytics 8.2.72 SAM 2 multiple-bboxes support (ultralytics#14928)
Browse files Browse the repository at this point in the history
Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
Laughing-q and glenn-jocher authored Aug 3, 2024
1 parent 2187649 commit bea4c93
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = "8.2.71"
__version__ = "8.2.72"

import os

Expand Down
25 changes: 10 additions & 15 deletions ultralytics/models/sam2/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,28 +102,23 @@ def prompt_inference(
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes *= r
bboxes = bboxes.view(-1, 2, 2) * r
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
# NOTE: merge "boxes" and "points" into a single "points" input
# (where boxes are added at the beginning) to model.sam_prompt_encoder
if points is not None:
points = torch.cat([bboxes, points], dim=1)
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)

points = (points, labels) if points is not None else None
# TODO: Embed prompts
# if bboxes is not None:
# box_coords = bboxes.reshape(-1, 2, 2)
# box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=bboxes.device)
# box_labels = box_labels.repeat(bboxes.size(0), 1)
# # we merge "boxes" and "points" into a single "concat_points" input (where
# # boxes are added at the beginning) to sam_prompt_encoder
# if concat_points is not None:
# concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
# concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
# concat_points = (concat_coords, concat_labels)
# else:
# concat_points = (box_coords, box_labels)

sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=points,
boxes=bboxes,
boxes=None,
masks=masks,
)
# Predict masks
Expand Down

0 comments on commit bea4c93

Please sign in to comment.