diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 8eb9bc7812d..16bacbbe53d 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.71" +__version__ = "8.2.72" import os diff --git a/ultralytics/models/sam2/predict.py b/ultralytics/models/sam2/predict.py index ca9438a2b7a..f40c0158213 100644 --- a/ultralytics/models/sam2/predict.py +++ b/ultralytics/models/sam2/predict.py @@ -102,28 +102,23 @@ def prompt_inference( if bboxes is not None: bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes - bboxes *= r + bboxes = bboxes.view(-1, 2, 2) * r + bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) + # NOTE: merge "boxes" and "points" into a single "points" input + # (where boxes are added at the beginning) to model.sam_prompt_encoder + if points is not None: + points = torch.cat([bboxes, points], dim=1) + labels = torch.cat([bbox_labels, labels], dim=1) + else: + points, labels = bboxes, bbox_labels if masks is not None: masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) points = (points, labels) if points is not None else None - # TODO: Embed prompts - # if bboxes is not None: - # box_coords = bboxes.reshape(-1, 2, 2) - # box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=bboxes.device) - # box_labels = box_labels.repeat(bboxes.size(0), 1) - # # we merge "boxes" and "points" into a single "concat_points" input (where - # # boxes are added at the beginning) to sam_prompt_encoder - # if concat_points is not None: - # concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) - # concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) - # concat_points = (concat_coords, concat_labels) - # else: - # concat_points = (box_coords, box_labels) sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( points=points, - boxes=bboxes, + boxes=None, masks=masks, ) # Predict masks