Skip to content

Commit

Permalink
Test pose estimation trackers (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
philipqueen authored Jun 17, 2024
1 parent 79959ad commit ef0c0ff
Show file tree
Hide file tree
Showing 18 changed files with 349 additions and 25 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/python-testing.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: SkellyTracker Tests

on:
pull_request:
branches: [ main ]
paths:
- 'skellytracker/**'

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: System Info
run: |
uname -a || true
lsb_release -a || true
gcc --version || true
env
- name: Set up Python 3.x
uses: actions/setup-python@v4
with:
# Semantic version range syntax or exact version of a Python version
python-version: '3.9'
# Optional - x64 or x86 architecture, defaults to x64
architecture: 'x64'
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install "-e.[all]"
- name: Fix OpenCV conflict
run: |
pip uninstall -y opencv-python opencv-contrib-python
pip install opencv-contrib-python==4.8.1.78
- name: Set PYTHONPATH
run: echo "PYTHONPATH=$PYTHONPATH:$(pwd)/skellytracker" >> $GITHUB_ENV
- name: Run Tests with Pytest
run: |
pip install pytest
pytest skellytracker/tests
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ keywords = [
#dynamic = ["dependencies"]
dependencies = [
"opencv-contrib-python==4.8.*",
"numpy<2",
"pydantic==1.*",
"tqdm==4.*",
]
Expand Down
4 changes: 3 additions & 1 deletion skellytracker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
print(f"adding base_package_path: {base_package_path} : to sys.path")
sys.path.insert(0, str(base_package_path)) # add parent directory to sys.path

print(f"sys path: {sys.path}")

from skellytracker.system.default_paths import get_log_file_path
from skellytracker.system.logging_configuration import configure_logging

Expand All @@ -38,4 +40,4 @@



configure_logging(log_file_path=get_log_file_path())
configure_logging(log_file_path=str(get_log_file_path()))
1 change: 1 addition & 0 deletions skellytracker/system/default_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
BASE_FOLDER_NAME = f"{__package_name__}_data"
LOGS_INFO_AND_SETTINGS_FOLDER_NAME = "logs_info_and_settings"
LOG_FILE_FOLDER_NAME = "logs"
FIGSHARE_TEST_IMAGE_URL = "https://figshare.com/ndownloader/files/47043898"

def get_base_folder_path():
base_folder = Path().home() / BASE_FOLDER_NAME
Expand Down
Empty file removed skellytracker/test/__init__.py
Empty file.
File renamed without changes.
14 changes: 14 additions & 0 deletions skellytracker/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np
import pytest
from skellytracker.utilities.download_test_image import download_test_image


class SessionInfo:
test_image: np.ndarray

def pytest_sessionstart(session):
SessionInfo.test_image = download_test_image()

@pytest.fixture
def test_image():
return SessionInfo.test_image
126 changes: 126 additions & 0 deletions skellytracker/tests/test_mediapipe_holistic_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import math
import cv2
import pytest
import numpy as np


from skellytracker.trackers.mediapipe_tracker.mediapipe_model_info import (
MediapipeModelInfo,
)
from skellytracker.trackers.mediapipe_tracker.mediapipe_holistic_tracker import (
MediapipeHolisticTracker,
)


@pytest.mark.usefixtures("test_image")
def test_process_image(test_image):
tracker = MediapipeHolisticTracker(model_complexity=0)
tracked_objects = tracker.process_image(test_image)

assert len(tracked_objects) == 4
assert tracked_objects["pose_landmarks"] is not None
assert tracked_objects["pose_landmarks"].extra["landmarks"] is not None
assert tracked_objects["right_hand_landmarks"] is not None
assert tracked_objects["right_hand_landmarks"].extra["landmarks"] is not None
assert tracked_objects["left_hand_landmarks"] is not None
assert tracked_objects["left_hand_landmarks"].extra["landmarks"] is not None
assert tracked_objects["face_landmarks"] is not None
assert tracked_objects["face_landmarks"].extra["landmarks"] is not None


@pytest.mark.usefixtures("test_image")
def test_annotate_image(test_image):
tracker = MediapipeHolisticTracker(model_complexity=0)
tracker.process_image(test_image)

assert tracker.annotated_image is not None


@pytest.mark.usefixtures("test_image")
def test_record(test_image):
tracker = MediapipeHolisticTracker(model_complexity=0)
tracked_objects = tracker.process_image(test_image)
tracker.recorder.record(tracked_objects=tracked_objects)
assert len(tracker.recorder.recorded_objects) == 1
assert len(tracker.recorder.recorded_objects[0]) == 4

processed_results = tracker.recorder.process_tracked_objects(
image_size=test_image.shape[:2]
)
assert processed_results is not None
assert processed_results.shape == (
1,
MediapipeModelInfo.num_tracked_points_total,
3,
)

expected_results = np.array(
[
[
[724.9490356445312, 80.74257552623749, -993.9854431152344],
[748.7973022460938, 76.24467730522156, -955.4129028320312],
[757.65625, 77.80313193798065, -955.689697265625],
[765.3969573974609, 79.39812362194061, -955.8861541748047],
[724.6346282958984, 73.26869666576385, -956.6817474365234],
[717.9697418212891, 72.75395393371582, -956.7975616455078],
[711.0582733154297, 72.3196667432785, -956.7622375488281],
[775.3928375244141, 86.46261692047119, -668.9157867431641],
[700.8257293701172, 76.53615295886993, -673.5513305664062],
[736.4148712158203, 91.9196355342865, -881.9746398925781],
[704.9551391601562, 87.58252501487732, -883.7062072753906],
[822.0133209228516, 152.90948510169983, -438.8511276245117],
[543.9170837402344, 128.66411805152893, -474.09820556640625],
[830.2363586425781, 222.54773139953613, -348.53546142578125],
[391.7664337158203, 185.8674716949463, -379.22679901123047],
[850.4001617431641, 281.42955780029297, -597.9225921630859],
[251.3930320739746, 234.69798803329468, -603.9112854003906],
[856.3973236083984, 297.1209526062012, -668.1196594238281],
[213.4073257446289, 246.7552900314331, -670.8364105224609],
[846.2405395507812, 297.8205370903015, -768.86962890625],
[218.36742401123047, 250.46862602233887, -770.2429962158203],
[841.2351226806641, 292.3808026313782, -645.0504302978516],
[237.8342628479004, 244.97743606567383, -649.0667724609375],
[642.2610473632812, 267.2499203681946, 19.359580278396606],
[489.5075225830078, 252.38561153411865, -18.99002432823181],
[545.5322647094727, 353.45691204071045, 95.11228561401367],
[421.98192596435547, 343.2889795303345, -111.5739917755127],
[483.64437103271484, 413.63812923431396, 892.3109436035156],
[386.21551513671875, 410.32382011413574, 586.4739990234375],
[469.1791534423828, 416.32561683654785, 958.9944458007812],
[386.70398712158203, 417.09611892700195, 642.3279571533203],
[473.2674789428711, 447.85693645477295, 729.4257354736328],
[355.9325408935547, 444.5236587524414, 379.90428924560547],
[253.68976593017578, 234.29851055145264, 0.0001452891228836961],
[240.58324813842773, 234.31915283203125, -17.537919282913208],
[225.88279724121094, 236.56160831451416, -26.281862258911133],
[216.74800872802734, 239.36949491500854, -33.34794282913208],
[207.5124740600586, 241.5495729446411, -39.871764183044434],
[203.52766036987305, 250.0307822227478, -13.84335994720459],
[184.57786560058594, 257.0957851409912, -22.172749042510986],
[172.95087814331055, 261.3565707206726, -29.567382335662842],
[163.63656997680664, 264.60676431655884, -34.58906173706055],
[210.67567825317383, 252.71584510803223, -10.44108510017395],
[192.9056167602539, 260.5885577201843, -17.135307788848877],
[181.97023391723633, 265.8717155456543, -23.215365409851074],
[173.35857391357422, 269.7653818130493, -27.345290184020996],
[219.38091278076172, 253.6610770225525, -8.99298369884491],
[204.17118072509766, 261.2510418891907, -14.699053764343262],
[195.27652740478516, 266.1923360824585, -19.270341396331787],
[187.59790420532227, 269.8603105545044, -22.399024963378906],
[230.7835578918457, 253.28086853027344, -9.420499801635742],
[223.0868148803711, 259.03817653656006, -14.548875093460083],
[218.3838653564453, 262.8340816497803, -17.04281210899353],
[214.09095764160156, 265.90606927871704, -18.470606803894043],
[831.3526916503906, 290.13811111450195, 0.00027009031327906996],
[834.6914672851562, 290.1472735404968, -28.73823642730713],
[838.1427764892578, 291.96690559387207, -45.007004737854004],
[837.1710205078125, 295.0093674659729, -56.84781551361084],
[838.1177520751953, 298.5115385055542, -67.84458637237549],
[836.1019134521484, 307.75739192962646, -34.55303907394409],
[831.5267944335938, 317.115318775177, -48.21352005004883],
]
]
)
assert np.allclose(
processed_results[:, :60, :], expected_results[:, :60, :], atol=1
)
File renamed without changes.
117 changes: 117 additions & 0 deletions skellytracker/tests/test_yolo_pose_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import math
import pytest
import numpy as np


from skellytracker.trackers.yolo_tracker.yolo_model_info import YOLOModelInfo
from skellytracker.trackers.yolo_tracker.yolo_tracker import YOLOPoseTracker


@pytest.mark.usefixtures("test_image")
def test_process_image(test_image):
tracker = YOLOPoseTracker(model_size="nano")
tracked_objects = tracker.process_image(test_image)

assert len(tracked_objects) == 1
tracked_person = tracked_objects["tracked_person"]
assert tracked_person.pixel_x is not None
assert tracked_person.pixel_y is not None
assert math.isclose(tracked_person.pixel_x, 266.48523, rel_tol=1e-2)
assert math.isclose(tracked_person.pixel_y, 273.92798, rel_tol=1e-2)

landmarks = tracked_person.extra["landmarks"]
assert landmarks is not None
assert landmarks.shape == (1, YOLOModelInfo.num_tracked_points, 2)

expected_results = np.array(
[
[
[392.56927490234375, 140.40118408203125],
[414.94940185546875, 132.90655517578125],
[386.4353332519531, 125.51483154296875],
[446.2061767578125, 157.98883056640625],
[373.6619873046875, 138.93646240234375],
[453.78662109375, 265.081787109375],
[317.9375305175781, 231.9653778076172],
[465.893310546875, 396.12274169921875],
[220.12176513671875, 325.96636962890625],
[465.23358154296875, 499.0487365722656],
[142.17066955566406, 407.7397155761719],
[352.325439453125, 468.44671630859375],
[268.8867492675781, 448.7227783203125],
[310.899658203125, 630.5478515625],
[227.7810821533203, 617.9011840820312],
[269.08587646484375, 733.4285888671875],
[213.1557159423828, 741.54541015625],
]
]
)

assert np.allclose(landmarks, expected_results)


@pytest.mark.usefixtures("test_image")
def test_annotate_image(test_image):
tracker = YOLOPoseTracker(model_size="nano")
tracker.process_image(test_image)

assert tracker.annotated_image is not None


@pytest.mark.usefixtures("test_image")
def test_record(test_image):
tracker = YOLOPoseTracker(model_size="nano")
tracked_objects = tracker.process_image(test_image)
tracker.recorder.record(tracked_objects=tracked_objects)
assert len(tracker.recorder.recorded_objects) == 1

processed_results = tracker.recorder.process_tracked_objects()
assert processed_results is not None
assert processed_results.shape == (1, YOLOModelInfo.num_tracked_points, 3)

expected_results = np.array(
[
[
[392.56927490234375, 140.40118408203125, np.nan],
[414.94940185546875, 132.90655517578125, np.nan],
[386.4353332519531, 125.51483154296875, np.nan],
[446.2061767578125, 157.98883056640625, np.nan],
[373.6619873046875, 138.93646240234375, np.nan],
[453.78662109375, 265.081787109375, np.nan],
[317.9375305175781, 231.9653778076172, np.nan],
[465.893310546875, 396.12274169921875, np.nan],
[220.12176513671875, 325.96636962890625, np.nan],
[465.23358154296875, 499.0487365722656, np.nan],
[142.17066955566406, 407.7397155761719, np.nan],
[352.325439453125, 468.44671630859375, np.nan],
[268.8867492675781, 448.7227783203125, np.nan],
[310.899658203125, 630.5478515625, np.nan],
[227.7810821533203, 617.9011840820312, np.nan],
[269.08587646484375, 733.4285888671875, np.nan],
[213.1557159423828, 741.54541015625, np.nan],
]
]
)
assert np.allclose(processed_results[:, :, :2], expected_results[:, :, :2], atol=1e-2)
assert np.isnan(processed_results[:, :, 2]).all()


class MockKeypoints:
xy: list = []


class MockResults:
keypoints: MockKeypoints = MockKeypoints()


def test_unpack_empty_results():
tracker = YOLOPoseTracker(model_size="nano")
results = [MockResults()]
tracker.unpack_results(results=results)

tracked_person = tracker.tracked_objects["tracked_person"]
assert tracked_person.pixel_x is None
assert tracked_person.pixel_y is None
assert tracked_person.extra["landmarks"].shape == (1, YOLOModelInfo.num_tracked_points, 2)
assert np.isnan(tracked_person.extra["landmarks"][0, :, 0]).all()
assert np.isnan(tracked_person.extra["landmarks"][0, :, 1]).all()
4 changes: 2 additions & 2 deletions skellytracker/trackers/base_tracker/base_recorder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
import logging
from typing import Dict
from typing import Dict, Optional

import numpy as np

Expand All @@ -20,7 +20,7 @@ def __init__(self):

@abstractmethod
def record(
self, tracked_objects: Dict[str, TrackedObject], annotated_image: np.ndarray
self, tracked_objects: Dict[str, TrackedObject], annotated_image: Optional[np.ndarray] = None
) -> None:
"""
Record the tracked objects as they are created by the tracker.
Expand Down
1 change: 0 additions & 1 deletion skellytracker/trackers/base_tracker/base_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(
):
self.recorder = recorder
self.annotated_image = None
self.raw_image = None
self.tracked_objects: Dict[str, TrackedObject] = {}

for name in tracked_object_names:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]
self.tracked_objects["brightest_point"].pixel_y = largest_patch_centroid[1]
self.tracked_objects["brightest_point"].extra["thresholded_image"] = thresholded_image

self.raw_image = image.copy()

self.annotated_image = self.annotate_image(image=image,
tracked_objects=self.tracked_objects)

Expand Down
2 changes: 0 additions & 2 deletions skellytracker/trackers/charuco_tracker/charuco_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]
self.tracked_objects[object_id].pixel_x = corner[0][0]
self.tracked_objects[object_id].pixel_y = corner[0][1]

self.raw_image = image.copy()

self.annotated_image = self.annotate_image(image=image,
tracked_objects=self.tracked_objects)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def record(self, tracked_objects: Dict[str, TrackedObject]) -> None:
def process_tracked_objects(self, **kwargs) -> np.ndarray:
image_size = kwargs.get("image_size")
if image_size is None:
raise ValueError("image_size must be provided to process tracked objects")
raise ValueError(f"image_size must be provided to process tracked objects from {__class__.__name__}")
self.recorded_objects_array = np.zeros(
(
len(self.recorded_objects),
Expand Down
Loading

0 comments on commit ef0c0ff

Please sign in to comment.