Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mediapipe into development #53

Merged
merged 3 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skellytracker/trackers/demo_viewers/webcam_demo_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def run(self):
overlay_string += (
"Controls:\n"
f"'SPACE'/'{chr(KEY_PAUSE_P)}': pause\n"
f"'{chr(KEY_SHOW_INFO)}': {'show info' if not show_info else "hide info"}\n"
f"'{chr(KEY_SHOW_INFO)}': {'show info' if not show_info else 'hide info'}\n"
f"'{chr(KEY_SHOW_OVERLAY)}': show overlay\n"
f"'{chr(KEY_SET_AUTO_EXPOSURE)}': auto-exposure\n"
f"'{chr(KEY_INCREASE_EXPOSURE)}'/'{chr(KEY_DECREASE_EXPOSURE)}': exposure +/-\n"
Expand Down
1 change: 1 addition & 0 deletions skellytracker/trackers/mediapipe_tracker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .__mediapipe_tracker import MediapipeTracker, MediapipeTrackerConfig
50 changes: 50 additions & 0 deletions skellytracker/trackers/mediapipe_tracker/__mediapipe_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import logging
from dataclasses import dataclass, field

import numpy as np

from skellytracker.trackers.base_tracker.base_tracker import BaseTracker, BaseTrackerConfig
from skellytracker.trackers.mediapipe_tracker.mediapipe_annotator import MediapipeAnnotatorConfig, MediapipeImageAnnotator
from skellytracker.trackers.mediapipe_tracker.mediapipe_detector import MediapipeDetector, MediapipeDetectorConfig
from skellytracker.trackers.mediapipe_tracker.mediapipe_observation import MediapipeObservation

logger = logging.getLogger(__name__)

@dataclass
class MediapipeTrackerConfig(BaseTrackerConfig):
detector_config: MediapipeDetectorConfig = field(default_factory = MediapipeDetectorConfig)
annotator_config: MediapipeAnnotatorConfig = field(default_factory = MediapipeAnnotatorConfig)


@dataclass
class MediapipeTracker(BaseTracker):
config: MediapipeTrackerConfig
detector: MediapipeDetector
annotator: MediapipeImageAnnotator

@classmethod
def create(cls, config: MediapipeTrackerConfig | None = None):
if config is None:
config = MediapipeTrackerConfig()
detector = MediapipeDetector.create(config.detector_config)

return cls(
config=config,
detector=detector,
annotator=MediapipeImageAnnotator.create(config.annotator_config),

)

def process_image(self,
image: np.ndarray,
annotate_image: bool = False) -> tuple[np.ndarray, MediapipeObservation] | MediapipeObservation:

latest_observation = self.detector.detect(image)
if annotate_image:
return self.annotator.annotate_image(image=image,
latest_observation=latest_observation), latest_observation
return latest_observation


if __name__ == "__main__":
MediapipeTracker.create().demo()
71 changes: 71 additions & 0 deletions skellytracker/trackers/mediapipe_tracker/mediapipe_annotator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from dataclasses import field, dataclass

import numpy as np
import cv2
from mediapipe.python.solutions import holistic as mp_holistic
from mediapipe.python.solutions import drawing_utils

from skellytracker.trackers.base_tracker.base_tracker import BaseImageAnnotatorConfig, BaseImageAnnotator
from skellytracker.trackers.mediapipe_tracker.mediapipe_observation import MediapipeObservation


class MediapipeAnnotatorConfig(BaseImageAnnotatorConfig):
show_tracks: int | None = 15
corner_marker_type: int = cv2.MARKER_DIAMOND
corner_marker_size: int = 10
corner_marker_thickness: int = 2
corner_marker_color: tuple[int, int, int] = (0, 0, 255)

aruco_lines_thickness: int = 2
aruco_lines_color: tuple[int, int, int] = (0, 255, 0)

text_color: tuple[int, int, int] = (215, 115, 40)
text_size: float = .5
text_thickness: int = 2
text_font: int = cv2.FONT_HERSHEY_SIMPLEX


@dataclass
class MediapipeImageAnnotator(BaseImageAnnotator):
config: MediapipeAnnotatorConfig
observations: list[MediapipeObservation] = field(default_factory=list)

@classmethod
def create(cls, config: MediapipeAnnotatorConfig):
return cls(config=config)

def annotate_image(
self,
image: np.ndarray,
latest_observation: MediapipeObservation | None = None,
) -> np.ndarray:
image_height, image_width = image.shape[:2]
text_offset = int(image_height * 0.01)

if latest_observation is None:
return image.copy()
# Copy the original image for annotation
annotated_image = image.copy()

drawing_utils.draw_landmarks(
annotated_image,
latest_observation.pose_landmarks,
mp_holistic.POSE_CONNECTIONS,
)
drawing_utils.draw_landmarks(
annotated_image,
latest_observation.right_hand_landmarks,
mp_holistic.HAND_CONNECTIONS,
)
drawing_utils.draw_landmarks(
annotated_image,
latest_observation.left_hand_landmarks,
mp_holistic.HAND_CONNECTIONS,
)
drawing_utils.draw_landmarks(
annotated_image,
latest_observation.face_landmarks,
mp_holistic.FACEMESH_TESSELATION,
)

return annotated_image
42 changes: 42 additions & 0 deletions skellytracker/trackers/mediapipe_tracker/mediapipe_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from dataclasses import dataclass

import cv2
import numpy as np
import mediapipe as mp

from skellytracker.trackers.base_tracker.base_tracker import BaseDetectorConfig, BaseDetector
from skellytracker.trackers.mediapipe_tracker.mediapipe_observation import MediapipeObservation


class MediapipeDetectorConfig(BaseDetectorConfig):
model_complexity = 2
min_detection_confidence = 0.5
min_tracking_confidence = 0.5
static_image_mode = False
smooth_landmarks = True


@dataclass
class MediapipeDetector(BaseDetector):
config: MediapipeDetectorConfig
detector: mp.solutions.holistic.Holistic

@classmethod
def create(cls, config: MediapipeDetectorConfig):
detector = mp.solutions.holistic.Holistic(
model_complexity=config.model_complexity,
min_detection_confidence=config.min_detection_confidence,
min_tracking_confidence=config.min_tracking_confidence,
static_image_mode=config.static_image_mode,
smooth_landmarks=config.smooth_landmarks,
)
return cls(
config=config,
detector=detector,
)

def detect(self, image: np.ndarray) -> MediapipeObservation:
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # assumes we're always getting BGR input - check with Jon to verify
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think BGR is a very cv2 specific convention at this point, its a rather old convention (.bmp images were BGR, for example), RGB is a much much more common standard these days.

Is Basler RGB? That is, do images come out of the Basler pipeline as RGB or BGR?

We should have a policy of converting cv2 images to RGB immediately upon read, and assuming they are RGB anywhere else in their lifetime.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, the RGB camera comes out in RGB, the IR in monochrome.

I think immediate conversion sounds like a good idea

return MediapipeObservation.from_holistic_results(self.detector.process(rgb_image),
image_size=(int(image.shape[0]), int(image.shape[1])),
)
140 changes: 140 additions & 0 deletions skellytracker/trackers/mediapipe_tracker/mediapipe_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import json
from dataclasses import dataclass
from typing import NamedTuple

import numpy as np
from mediapipe.python.solutions import holistic as mp_holistic
from mediapipe.python.solutions.face_mesh import FACEMESH_NUM_LANDMARKS_WITH_IRISES
from mediapipe.framework.formats.landmark_pb2 import NormalizedLandmarkList

from skellytracker.trackers.base_tracker.base_tracker import BaseObservation


@dataclass
class MediapipeObservation(BaseObservation):
pose_landmarks: NormalizedLandmarkList
right_hand_landmarks: NormalizedLandmarkList
left_hand_landmarks: NormalizedLandmarkList
face_landmarks: NormalizedLandmarkList

image_size: tuple[int, int]

@classmethod
def from_holistic_results(cls,
mediapipe_results: NamedTuple,
image_size: tuple[int, int]):
return cls(
pose_landmarks=mediapipe_results.pose_landmarks,
right_hand_landmarks=mediapipe_results.right_hand_landmarks,
left_hand_landmarks=mediapipe_results.left_hand_landmarks,
face_landmarks=mediapipe_results.face_landmarks,
image_size=image_size
)

@property
def charuco_empty(self):
return self.detected_charuco_corner_ids is None

@property
def aruco_empty(self):
return self.detected_aruco_marker_ids is None

@property
def body_landmark_names(self) -> list[str]:
return [landmark.name.lower() for landmark in mp_holistic.PoseLandmark]

@property
def hand_landmark_names(self) -> list[str]:
return [landmark.name.lower() for landmark in mp_holistic.HandLandmark]

@property
def num_body_points(self) -> int:
return len(self.body_landmark_names)

@property
def num_single_hand_points(self) -> int:
return len(self.hand_landmark_names)

@property
def num_face_points(self) -> int:
return FACEMESH_NUM_LANDMARKS_WITH_IRISES

@property
def num_total_points(self) -> int:
return self.num_body_points + (2 * self.num_single_hand_points) + self.num_face_points

@property
def pose_trajectories(self) -> np.ndarray[..., 3]: # TODO: not sure how to type these array sizes, seems like it doesn't like the `self` reference
Copy link
Member

@jonmatthis jonmatthis Jan 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typing is tricky with numpy arrays - I think there are pacakages that help with it, but nothing particularly venerable.

I think there's a way to set up custom type annotations and validators in pydantic (or dataclasses) or something? Or just type hint them as 'np.ndarray' and then just do a shape validation on initialization or read?

if self.pose_landmarks is None:
return np.full((self.num_body_points, 3), np.nan)

return self.landmarks_to_array(self.pose_landmarks)

@property
def right_hand_trajectories(self) -> np.ndarray[..., 3]:
if self.right_hand_landmarks is None:
return np.full((self.num_single_hand_points, 3), np.nan)

return self.landmarks_to_array(self.right_hand_landmarks)

@property
def left_hand_trajectories(self) -> np.ndarray[..., 3]:
if self.left_hand_landmarks is None:
return np.full((self.num_single_hand_points, 3), np.nan)

return self.landmarks_to_array(self.left_hand_landmarks)

@property
def face_trajectories(self) -> np.ndarray[..., 3]:
if self.face_landmarks is None:
return np.full((self.num_face_points, 3), np.nan)

return self.landmarks_to_array(self.face_landmarks)

@property
def all_holistic_trajectories(self) -> np.ndarray[533, 3]:
return np.concatenate(
# this order matters, do not change
(
self.pose_trajectories,
self.right_hand_trajectories,
self.left_hand_trajectories,
self.face_trajectories,
),
axis=0,
)

def landmarks_to_array(self, landmarks: NormalizedLandmarkList) -> np.ndarray[..., 3]:
landmark_array = np.array(
[
(landmark.x, landmark.y, landmark.z)
for landmark in landmarks.landmark
]
)

# convert from normalized image coordinates to pixel coordinates
landmark_array *= np.array([self.image_size[0], self.image_size[1], self.image_size[0]]) # multiply z by image width per mediapipe docs

return landmark_array

def to_serializable_dict(self) -> dict:
d = {
"pose_trajectories": self.pose_trajectories.tolist(),
"right_hand_trajectories": self.right_hand_trajectories.tolist(),
"left_hand_trajectories": self.left_hand_trajectories.tolist(),
"face_trajectories": self.face_trajectories.tolist(),
"image_size": self.image_size
}
try:
json.dumps(d).encode("utf-8")
except Exception as e:
raise ValueError(f"Failed to serialize CharucoObservation to JSON: {e}")
return d

def to_json_string(self) -> str:
return json.dumps(self.to_serializable_dict(), indent=4)

def to_json_bytes(self) -> bytes:
return self.to_json_string().encode("utf-8")

MediapipeObservations = list[MediapipeObservation]