Skip to content

Commit

Permalink
ultralytics 8.1.5 add OBB Tracking support (ultralytics#7731)
Browse files Browse the repository at this point in the history
Signed-off-by: Glenn Jocher <[email protected]>
Co-authored-by: UltralyticsAssistant <[email protected]>
Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: Hassaan Farooq <[email protected]>
  • Loading branch information
4 people authored Jan 23, 2024
1 parent 12a741c commit f56dd0f
Show file tree
Hide file tree
Showing 11 changed files with 92 additions and 44 deletions.
4 changes: 4 additions & 0 deletions docs/en/reference/utils/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli

<br><br>

## ::: ultralytics.utils.ops.regularize_rboxes

<br><br>

## ::: ultralytics.utils.ops.masks2segments

<br><br>
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = "8.1.4"
__version__ = "8.1.5"

from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO
Expand Down
4 changes: 3 additions & 1 deletion ultralytics/engine/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,16 @@ def __len__(self):
if v is not None:
return len(v)

def update(self, boxes=None, masks=None, probs=None):
def update(self, boxes=None, masks=None, probs=None, obb=None):
"""Update the boxes, masks, and probs attributes of the Results object."""
if boxes is not None:
self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
if masks is not None:
self.masks = Masks(masks, self.orig_shape)
if probs is not None:
self.probs = probs
if obb is not None:
self.obb = OBB(obb, self.orig_shape)

def _apply(self, fn, *args, **kwargs):
"""
Expand Down
6 changes: 3 additions & 3 deletions ultralytics/hub/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,14 @@ def retry_request():
break # Timeout reached, exit loop

response = request_func(*args, **kwargs)
if progress_total:
self._show_upload_progress(progress_total, response)

if response is None:
LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
time.sleep(2**i) # Exponential backoff before retrying
continue # Skip further processing and retry

if progress_total:
self._show_upload_progress(progress_total, response)

if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
return response # Success, no need to retry

Expand Down
5 changes: 3 additions & 2 deletions ultralytics/models/yolo/obb/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ def postprocess(self, preds, img, orig_imgs):

results = []
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
# xywh, r, conf, cls
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
return results
58 changes: 34 additions & 24 deletions ultralytics/trackers/byte_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .basetrack import BaseTrack, TrackState
from .utils import matching
from .utils.kalman_filter import KalmanFilterXYAH
from ..utils.ops import xywh2ltwh
from ..utils import LOGGER


class STrack(BaseTrack):
Expand Down Expand Up @@ -35,26 +37,27 @@ class STrack(BaseTrack):
activate(kalman_filter, frame_id): Activate a new tracklet.
re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
update(new_track, frame_id): Update the state of a matched track.
convert_coords(tlwh): Convert bounding box to x-y-angle-height format.
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
tlbr_to_tlwh(tlbr): Convert tlbr bounding box to tlwh format.
tlwh_to_tlbr(tlwh): Convert tlwh bounding box to tlbr format.
"""

shared_kalman = KalmanFilterXYAH()

def __init__(self, tlwh, score, cls):
def __init__(self, xywh, score, cls):
"""Initialize new STrack instance."""
super().__init__()
self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
# xywh+idx or xywha+idx
assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False

self.score = score
self.tracklet_len = 0
self.cls = cls
self.idx = tlwh[-1]
self.idx = xywh[-1]
self.angle = xywh[4] if len(xywh) == 6 else None

def predict(self):
"""Predicts mean and covariance using Kalman filter."""
Expand Down Expand Up @@ -123,6 +126,7 @@ def re_activate(self, new_track, frame_id, new_id=False):
self.track_id = self.next_id()
self.score = new_track.score
self.cls = new_track.cls
self.angle = new_track.angle
self.idx = new_track.idx

def update(self, new_track, frame_id):
Expand All @@ -145,10 +149,11 @@ def update(self, new_track, frame_id):

self.score = new_track.score
self.cls = new_track.cls
self.angle = new_track.angle
self.idx = new_track.idx

def convert_coords(self, tlwh):
"""Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent."""
"""Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
return self.tlwh_to_xyah(tlwh)

@property
Expand All @@ -162,7 +167,7 @@ def tlwh(self):
return ret

@property
def tlbr(self):
def xyxy(self):
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
Expand All @@ -178,19 +183,26 @@ def tlwh_to_xyah(tlwh):
ret[2] /= ret[3]
return ret

@staticmethod
def tlbr_to_tlwh(tlbr):
"""Converts top-left bottom-right format to top-left width height format."""
ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2]
@property
def xywh(self):
"""Get current position in bounding box format (center x, center y, width, height)."""
ret = np.asarray(self.tlwh).copy()
ret[:2] += ret[2:] / 2
return ret

@staticmethod
def tlwh_to_tlbr(tlwh):
"""Converts tlwh bounding box format to tlbr format."""
ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2]
return ret
@property
def xywha(self):
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
if self.angle is None:
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
return self.xywh
return np.concatenate([self.xywh, self.angle[None]])

@property
def result(self):
"""Get current tracking results."""
coords = self.xyxy if self.angle is None else self.xywha
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]

def __repr__(self):
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
Expand Down Expand Up @@ -247,7 +259,7 @@ def update(self, results, img=None):
removed_stracks = []

scores = results.conf
bboxes = results.xyxy
bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
# Add index
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
cls = results.cls
Expand Down Expand Up @@ -349,10 +361,8 @@ def update(self, results, img=None):
self.removed_stracks.extend(removed_stracks)
if len(self.removed_stracks) > 1000:
self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum
return np.asarray(
[x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
dtype=np.float32,
)

return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)

def get_kalmanfilter(self):
"""Returns a Kalman filter object for tracking bounding boxes."""
Expand Down
10 changes: 6 additions & 4 deletions ultralytics/trackers/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
"""
if predictor.args.task == "obb":
raise NotImplementedError("ERROR ❌ OBB task does not support track mode!")
if hasattr(predictor, "trackers") and persist:
return

Expand Down Expand Up @@ -54,19 +52,23 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
bs = predictor.dataset.bs
path, im0s = predictor.batch[:2]

is_obb = predictor.args.task == "obb"
for i in range(bs):
if not persist and predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name): # new video
predictor.trackers[i].reset()

det = predictor.results[i].boxes.cpu().numpy()
det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
if len(det) == 0:
continue
tracks = predictor.trackers[i].update(det, im0s[i])
if len(tracks) == 0:
continue
idx = tracks[:, -1].astype(int)
predictor.results[i] = predictor.results[i][idx]
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))

update_args = dict()
update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
predictor.results[i].update(**update_args)


def register_tracker(model: object, persist: bool) -> None:
Expand Down
20 changes: 14 additions & 6 deletions ultralytics/trackers/utils/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import scipy
from scipy.spatial.distance import cdist

from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.metrics import bbox_ioa, batch_probiou

try:
import lap # for linear_assignment
Expand Down Expand Up @@ -74,14 +74,22 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlbr for track in atracks]
btlbrs = [track.tlbr for track in btracks]
atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]

ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if len(atlbrs) and len(btlbrs):
ious = bbox_ioa(
np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32), iou=True
)
if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
ious = batch_probiou(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
).numpy()
else:
ious = bbox_ioa(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
iou=True,
)
return 1 - ious # cost matrix


Expand Down
2 changes: 1 addition & 1 deletion ultralytics/utils/callbacks/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def on_model_save(trainer):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}")
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}")
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers["ckpt"] = time() # reset timer

Expand Down
7 changes: 5 additions & 2 deletions ultralytics/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,16 @@ def batch_probiou(obb1, obb2, eps=1e-7):
Calculate the prob iou between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
Args:
obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
obb2 (torch.Tensor): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing obb similarities.
"""
obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2

x1, y1 = obb1[..., :2].split(1, dim=-1)
x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
a1, b1, c1 = _get_covariance_matrix(obb1)
Expand Down
18 changes: 18 additions & 0 deletions ultralytics/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,24 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
return coords


def regularize_rboxes(rboxes):
"""
Regularize rotated boxes in range [0, pi/2].
Args:
rboxes (torch.Tensor): (N, 5), xywhr.
Returns:
(torch.Tensor): The regularized boxes.
"""
x, y, w, h, t = rboxes.unbind(dim=-1)
# Swap edge and angle if h >= w
w_ = torch.where(w > h, w, h)
h_ = torch.where(w > h, h, w)
t = torch.where(w > h, t, t + math.pi / 2) % math.pi
return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes


def masks2segments(masks, strategy="largest"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
Expand Down

0 comments on commit f56dd0f

Please sign in to comment.