Skip to content

Commit

Permalink
Track multiple brightest points (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
philipqueen authored Jun 17, 2024
1 parent ef0c0ff commit e70939a
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 37 deletions.
2 changes: 1 addition & 1 deletion skellytracker/RUN_ME.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
def main(demo_tracker: str = "mediapipe_holistic_tracker"):

if demo_tracker == "brightest_point_tracker":
BrightestPointTracker().demo()
BrightestPointTracker(num_points=2).demo()

elif demo_tracker == "charuco_tracker":
CharucoTracker(
Expand Down
2 changes: 1 addition & 1 deletion skellytracker/SINGLE_IMAGE_RUN.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
image_path = Path("/Path/To/Your/Image.jpg")

if demo_tracker == "brightest_point_tracker":
BrightestPointTracker().image_demo(image_path=image_path)
BrightestPointTracker(num_points=2).image_demo(image_path=image_path)

elif demo_tracker == "charuco_tracker":
CharucoTracker(
Expand Down
84 changes: 84 additions & 0 deletions skellytracker/test/test_brightest_point_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import cv2
import pytest
import numpy as np


from skellytracker.trackers.bright_point_tracker.brightest_point_tracker import (
BrightestPointTracker,
)


@pytest.fixture
def sample_image():
"""
Create a sample image with bright spots for testing.
"""
image = np.zeros((100, 100, 3), dtype=np.uint8)
cv2.circle(image, (30, 30), 10, (255, 255, 255), -1) # Bright spot
cv2.circle(image, (70, 70), 15, (255, 255, 255), -1) # Brighter spot
return image


def test_process_image_with_one_brightest_point(sample_image):
tracker = BrightestPointTracker(num_points=1, luminance_threshold=200)
tracked_objects = tracker.process_image(sample_image)

assert len(tracked_objects) == 1
assert tracked_objects["brightest_point_0"].pixel_x == 70
assert tracked_objects["brightest_point_0"].pixel_y == 70


def test_process_image_with_two_brightest_points(sample_image):
tracker = BrightestPointTracker(num_points=2, luminance_threshold=200)
tracked_objects = tracker.process_image(sample_image)

assert len(tracked_objects) == 2
assert tracked_objects["brightest_point_0"].pixel_x == 70
assert tracked_objects["brightest_point_0"].pixel_y == 70
assert tracked_objects["brightest_point_1"].pixel_x == 30
assert tracked_objects["brightest_point_1"].pixel_y == 30


def test_process_image_with_five_brightest_points(sample_image):
tracker = BrightestPointTracker(num_points=5, luminance_threshold=200)
tracked_objects = tracker.process_image(sample_image)

assert len(tracked_objects) == 5
assert tracked_objects["brightest_point_0"].pixel_x == 70
assert tracked_objects["brightest_point_0"].pixel_y == 70
assert tracked_objects["brightest_point_1"].pixel_x == 30
assert tracked_objects["brightest_point_1"].pixel_y == 30
assert tracked_objects["brightest_point_2"].pixel_x is None
assert tracked_objects["brightest_point_2"].pixel_y is None
assert tracked_objects["brightest_point_3"].pixel_x is None
assert tracked_objects["brightest_point_3"].pixel_y is None
assert tracked_objects["brightest_point_4"].pixel_x is None
assert tracked_objects["brightest_point_4"].pixel_y is None


def test_annotate_image(sample_image):
tracker = BrightestPointTracker(num_points=2, luminance_threshold=200)
tracked_objects = tracker.process_image(sample_image)

assert tracker.annotated_image is not None

# Check if the bright spots have markers
bright_point_0 = tracked_objects["brightest_point_0"]
bright_point_1 = tracked_objects["brightest_point_1"]

assert tracker.annotated_image[
bright_point_0.pixel_y, bright_point_0.pixel_x
].tolist() == [0, 0, 255]
assert tracker.annotated_image[
bright_point_1.pixel_y, bright_point_1.pixel_x
].tolist() == [0, 0, 255]


def test_record(sample_image):
tracker = BrightestPointTracker(num_points=2, luminance_threshold=200)
tracked_objects = tracker.process_image(sample_image)
tracker.recorder.record(tracked_objects=tracked_objects)

assert len(tracker.recorder.recorded_objects) == 1

assert len(tracker.recorder.recorded_objects[0]) == 2
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@

class BrightestPointRecorder(BaseRecorder):
def record(self, tracked_objects: Dict[str, TrackedObject]) -> None:
self.recorded_objects.append(deepcopy(tracked_objects["brightest_point"]))
self.recorded_objects.append(
[
(tracked_object.pixel_x, tracked_object.pixel_y)
for tracked_object in tracked_objects.values()
if "brightest_point" in tracked_object.object_id
]
)

def process_tracked_objects(self, **kwargs) -> np.ndarray:
self.recorded_objects_array = np.zeros((len(self.recorded_objects), 1, 3))
num_frames = len(self.recorded_objects)
num_points = len(self.recorded_objects[0]) if num_frames > 0 else 0

self.recorded_objects_array = np.zeros((num_frames, num_points, 2))
for i, recorded_object in enumerate(self.recorded_objects):
self.recorded_objects_array[i, 0, 0] = recorded_object.pixel_x
self.recorded_objects_array[i, 0, 1] = recorded_object.pixel_y
self.recorded_objects_array[i, 0, 2] = recorded_object.depth_z
for j, (pixel_x, pixel_y) in enumerate(recorded_object):
self.recorded_objects_array[i, j, 0] = pixel_x
self.recorded_objects_array[i, j, 1] = pixel_y

return self.recorded_objects_array
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@

import cv2
import numpy as np
from pydantic import BaseModel

from skellytracker.trackers.base_tracker.base_tracker import BaseTracker
from skellytracker.trackers.base_tracker.tracked_object import TrackedObject
from skellytracker.trackers.bright_point_tracker.brightest_point_recorder import BrightestPointRecorder
from skellytracker.trackers.bright_point_tracker.brightest_point_recorder import (
BrightestPointRecorder,
)

UPPER_BOUND_COLOR = [255, 255, 255]

Expand All @@ -15,62 +18,93 @@
logger = logging.getLogger(__name__)


class BrightPatch(BaseModel):
area: float
centroid_x: int
centroid_y: int


class BrightestPointTracker(BaseTracker):
luminance_threshold: int = 200
def __init__(self, num_points: int = 1, luminance_threshold: int = 200):
super().__init__(
tracked_object_names=[f"brightest_point_{i}" for i in range(num_points)],
recorder=BrightestPointRecorder(),
)

def __init__(self):
super().__init__(tracked_object_names=["brightest_point"], recorder=BrightestPointRecorder())
self.num_points = num_points
self.luminance_threshold = luminance_threshold

def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]:
# Convert the image to grayscale
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# Threshold the image to get only bright regions
_, thresholded_image = cv2.threshold(gray_image, self.luminance_threshold, 255, cv2.THRESH_BINARY)
_, thresholded_image = cv2.threshold(
gray_image, self.luminance_threshold, 255, cv2.THRESH_BINARY
)

# Find contours of the bright regions
bright_patches, _ = cv2.findContours(thresholded_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
bright_patches, _ = cv2.findContours(
thresholded_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)

# Process each bright patch separately
largest_area = 0
largest_patch_centroid = None
patch_list = []
for patch in bright_patches:
# Calculate the centroid of the bright patch
patch_moments = cv2.moments(patch)
if patch_moments["m00"] != 0: # Avoid division by zero
centroid_x = int(patch_moments["m10"] / patch_moments["m00"])
centroid_y = int(patch_moments["m01"] / patch_moments["m00"])

# Keep track of the largest patch and its centroid
area = cv2.contourArea(patch)
if area > largest_area:
largest_area = area
largest_patch_centroid = (centroid_x, centroid_y)
patch_list.append(
BrightPatch(
area=cv2.contourArea(patch),
centroid_x=centroid_x,
centroid_y=centroid_y,
)
)

largest_patches = sorted(
patch_list, key=lambda patch: patch.area, reverse=True
)[: self.num_points]

for i, patch in enumerate(largest_patches):
self.tracked_objects[f"brightest_point_{i}"].pixel_x = patch.centroid_x
self.tracked_objects[f"brightest_point_{i}"].pixel_y = patch.centroid_y
self.tracked_objects[f"brightest_point_{i}"].extra["thresholdedimage"] = thresholded_image

# If a largest patch was found, update the tracked object
if largest_patch_centroid is not None:
self.tracked_objects["brightest_point"].pixel_x = largest_patch_centroid[0]
self.tracked_objects["brightest_point"].pixel_y = largest_patch_centroid[1]
self.tracked_objects["brightest_point"].extra["thresholded_image"] = thresholded_image
for i in range(len(largest_patches), self.num_points):
self.tracked_objects[f"brightest_point_{i}"].pixel_x = None # TODO: Is this the right value for missing data?
self.tracked_objects[f"brightest_point_{i}"].pixel_y = None

self.annotated_image = self.annotate_image(image=image,
tracked_objects=self.tracked_objects)
self.annotated_image = self.annotate_image(
image=image, tracked_objects=self.tracked_objects
)

return self.tracked_objects

def annotate_image(self, image: np.ndarray, tracked_objects: Dict[str, TrackedObject], **kwargs) -> np.ndarray:
# Copy the original image for annotation
def annotate_image(
self, image: np.ndarray, tracked_objects: Dict[str, TrackedObject], **kwargs
) -> np.ndarray:
annotated_image = image.copy()

# Draw a red 'X' over the largest bright patch
if tracked_objects["brightest_point"].pixel_x is not None and tracked_objects[
"brightest_point"].pixel_y is not None:
cv2.drawMarker(annotated_image,
(tracked_objects["brightest_point"].pixel_x, tracked_objects["brightest_point"].pixel_y),
(0, 0, 255), markerType=cv2.MARKER_CROSS, markerSize=20, thickness=2)
for key, tracked_object in tracked_objects.items():
if (
"brightest_point" in key
and tracked_object.pixel_x is not None
and tracked_object.pixel_y is not None
):
cv2.drawMarker(
img=annotated_image,
position=(tracked_object.pixel_x, tracked_object.pixel_y),
color=(0, 0, 255),
markerType=cv2.MARKER_CROSS,
markerSize=20,
thickness=2,
)

return annotated_image


if __name__ == "__main__":
BrightestPointTracker().demo()
BrightestPointTracker(num_points=2).demo()

0 comments on commit e70939a

Please sign in to comment.