Mediapipe into development #53
skellytracker/trackers/mediapipe_tracker/__init__.py

@@ -0,0 +1 @@
from .__mediapipe_tracker import MediapipeTracker, MediapipeTrackerConfig
skellytracker/trackers/mediapipe_tracker/__mediapipe_tracker.py

@@ -0,0 +1,50 @@
import logging
from dataclasses import dataclass, field

import numpy as np

from skellytracker.trackers.base_tracker.base_tracker import BaseTracker, BaseTrackerConfig
from skellytracker.trackers.mediapipe_tracker.mediapipe_annotator import MediapipeAnnotatorConfig, MediapipeImageAnnotator
from skellytracker.trackers.mediapipe_tracker.mediapipe_detector import MediapipeDetector, MediapipeDetectorConfig
from skellytracker.trackers.mediapipe_tracker.mediapipe_observation import MediapipeObservation

logger = logging.getLogger(__name__)


@dataclass
class MediapipeTrackerConfig(BaseTrackerConfig):
    detector_config: MediapipeDetectorConfig = field(default_factory=MediapipeDetectorConfig)
    annotator_config: MediapipeAnnotatorConfig = field(default_factory=MediapipeAnnotatorConfig)


@dataclass
class MediapipeTracker(BaseTracker):
    config: MediapipeTrackerConfig
    detector: MediapipeDetector
    annotator: MediapipeImageAnnotator

    @classmethod
    def create(cls, config: MediapipeTrackerConfig | None = None):
        if config is None:
            config = MediapipeTrackerConfig()
        detector = MediapipeDetector.create(config.detector_config)

        return cls(
            config=config,
            detector=detector,
            annotator=MediapipeImageAnnotator.create(config.annotator_config),
        )

    def process_image(self,
                      image: np.ndarray,
                      annotate_image: bool = False) -> tuple[np.ndarray, MediapipeObservation] | MediapipeObservation:
        latest_observation = self.detector.detect(image)
        if annotate_image:
            return self.annotator.annotate_image(image=image,
                                                 latest_observation=latest_observation), latest_observation
        return latest_observation


if __name__ == "__main__":
    MediapipeTracker.create().demo()
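For context, here is a minimal usage sketch of the tracker API this file adds, written as a hypothetical webcam loop. The import path is inferred from the package `__init__.py` above, and frames are assumed to arrive as BGR arrays from `cv2.VideoCapture`, matching the detector's current assumption; none of this loop is part of the PR itself.

import cv2

from skellytracker.trackers.mediapipe_tracker import MediapipeTracker

tracker = MediapipeTracker.create()  # default config: holistic pipeline, HEAVY model
capture = cv2.VideoCapture(0)  # hypothetical webcam source

while True:
    success, frame = capture.read()  # BGR frame, as the detector currently assumes
    if not success:
        break
    # with annotate_image=True, process_image returns (annotated_image, observation)
    annotated_frame, observation = tracker.process_image(frame, annotate_image=True)
    cv2.imshow("mediapipe tracker", annotated_frame)
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break

capture.release()
cv2.destroyAllWindows()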
skellytracker/trackers/mediapipe_tracker/mediapipe_annotator.py

@@ -0,0 +1,71 @@
from dataclasses import field, dataclass

import numpy as np
import cv2
from mediapipe.python.solutions import holistic as mp_holistic
from mediapipe.python.solutions import drawing_utils

from skellytracker.trackers.base_tracker.base_tracker import BaseImageAnnotatorConfig, BaseImageAnnotator
from skellytracker.trackers.mediapipe_tracker.mediapipe_observation import MediapipeObservation


class MediapipeAnnotatorConfig(BaseImageAnnotatorConfig):
    show_tracks: int | None = 15
    # NOTE: the corner-marker and aruco fields below look carried over from the charuco annotator and are unused here
    corner_marker_type: int = cv2.MARKER_DIAMOND
    corner_marker_size: int = 10
    corner_marker_thickness: int = 2
    corner_marker_color: tuple[int, int, int] = (0, 0, 255)

    aruco_lines_thickness: int = 2
    aruco_lines_color: tuple[int, int, int] = (0, 255, 0)

    text_color: tuple[int, int, int] = (215, 115, 40)
    text_size: float = 0.5
    text_thickness: int = 2
    text_font: int = cv2.FONT_HERSHEY_SIMPLEX


@dataclass
class MediapipeImageAnnotator(BaseImageAnnotator):
    config: MediapipeAnnotatorConfig
    observations: list[MediapipeObservation] = field(default_factory=list)

    @classmethod
    def create(cls, config: MediapipeAnnotatorConfig):
        return cls(config=config)

    def annotate_image(
            self,
            image: np.ndarray,
            latest_observation: MediapipeObservation | None = None,
    ) -> np.ndarray:
        image_height, image_width = image.shape[:2]
        text_offset = int(image_height * 0.01)  # currently unused

        if latest_observation is None:
            return image.copy()
        # Copy the original image for annotation
        annotated_image = image.copy()

        drawing_utils.draw_landmarks(
            annotated_image,
            latest_observation.pose_landmarks,
            mp_holistic.POSE_CONNECTIONS,
        )
        drawing_utils.draw_landmarks(
            annotated_image,
            latest_observation.right_hand_landmarks,
            mp_holistic.HAND_CONNECTIONS,
        )
        drawing_utils.draw_landmarks(
            annotated_image,
            latest_observation.left_hand_landmarks,
            mp_holistic.HAND_CONNECTIONS,
        )
        drawing_utils.draw_landmarks(
            annotated_image,
            latest_observation.face_landmarks,
            mp_holistic.FACEMESH_TESSELATION,
        )

        return annotated_image
skellytracker/trackers/mediapipe_tracker/mediapipe_detector.py

@@ -0,0 +1,47 @@
from dataclasses import dataclass
from enum import Enum

import cv2
import numpy as np
import mediapipe as mp

from skellytracker.trackers.base_tracker.base_tracker import BaseDetectorConfig, BaseDetector
from skellytracker.trackers.mediapipe_tracker.mediapipe_observation import MediapipeObservation


class MediapipeModelComplexity(int, Enum):
    LITE = 0
    FULL = 1
    HEAVY = 2


class MediapipeDetectorConfig(BaseDetectorConfig):
    model_complexity: MediapipeModelComplexity = MediapipeModelComplexity.HEAVY  # an int Enum, so it passes straight through to mediapipe
    min_detection_confidence: float = 0.5
    min_tracking_confidence: float = 0.5
    static_image_mode: bool = False
    smooth_landmarks: bool = True


@dataclass
class MediapipeDetector(BaseDetector):
    config: MediapipeDetectorConfig
    detector: mp.solutions.holistic.Holistic

    @classmethod
    def create(cls, config: MediapipeDetectorConfig):
        detector = mp.solutions.holistic.Holistic(
            model_complexity=config.model_complexity,
            min_detection_confidence=config.min_detection_confidence,
            min_tracking_confidence=config.min_tracking_confidence,
            static_image_mode=config.static_image_mode,
            smooth_landmarks=config.smooth_landmarks,
        )
        return cls(
            config=config,
            detector=detector,
        )

    def detect(self, image: np.ndarray) -> MediapipeObservation:
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # assumes we're always getting BGR input - check with Jon to verify
        return MediapipeObservation.from_holistic_results(self.detector.process(rgb_image),
                                                          image_size=(int(image.shape[0]), int(image.shape[1])),  # (height, width)
                                                          )
skellytracker/trackers/mediapipe_tracker/mediapipe_observation.py

@@ -0,0 +1,140 @@
import json
from dataclasses import dataclass
from typing import NamedTuple

import numpy as np
from mediapipe.python.solutions import holistic as mp_holistic
from mediapipe.python.solutions.face_mesh import FACEMESH_NUM_LANDMARKS_WITH_IRISES
from mediapipe.framework.formats.landmark_pb2 import NormalizedLandmarkList

from skellytracker.trackers.base_tracker.base_tracker import BaseObservation


@dataclass
class MediapipeObservation(BaseObservation):
    pose_landmarks: NormalizedLandmarkList
    right_hand_landmarks: NormalizedLandmarkList
    left_hand_landmarks: NormalizedLandmarkList
    face_landmarks: NormalizedLandmarkList

    image_size: tuple[int, int]  # (height, width)

    @classmethod
    def from_holistic_results(cls,
                              mediapipe_results: NamedTuple,
                              image_size: tuple[int, int]):
        return cls(
            pose_landmarks=mediapipe_results.pose_landmarks,
            right_hand_landmarks=mediapipe_results.right_hand_landmarks,
            left_hand_landmarks=mediapipe_results.left_hand_landmarks,
            face_landmarks=mediapipe_results.face_landmarks,
            image_size=image_size
        )

    @property
    def landmarks_empty(self) -> bool:
        return all(
            landmarks is None
            for landmarks in (
                self.pose_landmarks,
                self.right_hand_landmarks,
                self.left_hand_landmarks,
                self.face_landmarks,
            )
        )

    @property
    def body_landmark_names(self) -> list[str]:
        return [landmark.name.lower() for landmark in mp_holistic.PoseLandmark]

    @property
    def hand_landmark_names(self) -> list[str]:
        return [landmark.name.lower() for landmark in mp_holistic.HandLandmark]

    @property
    def num_body_points(self) -> int:
        return len(self.body_landmark_names)

    @property
    def num_single_hand_points(self) -> int:
        return len(self.hand_landmark_names)

    @property
    def num_face_points(self) -> int:
        return FACEMESH_NUM_LANDMARKS_WITH_IRISES

    @property
    def num_total_points(self) -> int:
        return self.num_body_points + (2 * self.num_single_hand_points) + self.num_face_points

    @property
    def pose_trajectories(self) -> np.ndarray[..., 3]:  # TODO: not sure how to type these array sizes, seems like it doesn't like the `self` reference
Review comment: Typing is tricky with numpy arrays - I think there are packages that help with it, but nothing particularly venerable. I think there's a way to set up custom type annotations and validators in pydantic (or dataclasses) or something? Or just type hint them as `np.ndarray` and then just do a shape validation on initialization or read?
        if self.pose_landmarks is None:
            return np.full((self.num_body_points, 3), np.nan)

        return self.landmarks_to_array(self.pose_landmarks)

    @property
    def right_hand_trajectories(self) -> np.ndarray[..., 3]:
        if self.right_hand_landmarks is None:
            return np.full((self.num_single_hand_points, 3), np.nan)

        return self.landmarks_to_array(self.right_hand_landmarks)

    @property
    def left_hand_trajectories(self) -> np.ndarray[..., 3]:
        if self.left_hand_landmarks is None:
            return np.full((self.num_single_hand_points, 3), np.nan)

        return self.landmarks_to_array(self.left_hand_landmarks)

    @property
    def face_trajectories(self) -> np.ndarray[..., 3]:
        if self.face_landmarks is None:
            return np.full((self.num_face_points, 3), np.nan)

        return self.landmarks_to_array(self.face_landmarks)

    @property
    def all_holistic_trajectories(self) -> np.ndarray[553, 3]:  # 33 body + (2 * 21) hand + 478 face points
        return np.concatenate(
            # this order matters, do not change
            (
                self.pose_trajectories,
                self.right_hand_trajectories,
                self.left_hand_trajectories,
                self.face_trajectories,
            ),
            axis=0,
        )

    def landmarks_to_array(self, landmarks: NormalizedLandmarkList) -> np.ndarray[..., 3]:
        landmark_array = np.array(
            [
                (landmark.x, landmark.y, landmark.z)
                for landmark in landmarks.landmark
            ]
        )

        # convert from normalized image coordinates to pixel coordinates;
        # image_size is (height, width), so scale x by width, y by height,
        # and z by image width per the mediapipe docs
        landmark_array *= np.array([self.image_size[1], self.image_size[0], self.image_size[1]])

        return landmark_array

    def to_serializable_dict(self) -> dict:
        d = {
            "pose_trajectories": self.pose_trajectories.tolist(),
            "right_hand_trajectories": self.right_hand_trajectories.tolist(),
            "left_hand_trajectories": self.left_hand_trajectories.tolist(),
            "face_trajectories": self.face_trajectories.tolist(),
            "image_size": self.image_size
        }
        try:
            json.dumps(d).encode("utf-8")
        except Exception as e:
            raise ValueError(f"Failed to serialize MediapipeObservation to JSON: {e}")
        return d

    def to_json_string(self) -> str:
        return json.dumps(self.to_serializable_dict(), indent=4)

    def to_json_bytes(self) -> bytes:
        return self.to_json_string().encode("utf-8")


MediapipeObservations = list[MediapipeObservation]
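On the array-typing question raised in the review comment above: a minimal sketch of the "plain `np.ndarray` hint plus shape validation on initialization" option, using a standalone dataclass for illustration (the `Trajectories` name and fields are hypothetical, not part of this PR):

from dataclasses import dataclass

import numpy as np


@dataclass
class Trajectories:
    points: np.ndarray  # plain type hint; the shape is enforced in __post_init__

    def __post_init__(self):
        # coerce to a float array, then reject anything that isn't (N, 3)
        self.points = np.asarray(self.points, dtype=float)
        if self.points.ndim != 2 or self.points.shape[1] != 3:
            raise ValueError(f"expected an (N, 3) array, got shape {self.points.shape}")


valid = Trajectories(np.zeros((33, 3)))  # ok: 33 body points x (x, y, z)
# Trajectories(np.zeros((33, 2)))  # would raise ValueError

The same kind of `__post_init__` check could live on `MediapipeObservation` itself to validate trajectory shapes at construction time.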
Review comment: I think BGR is a very cv2-specific convention at this point; it's a rather old convention (.bmp images were BGR, for example), and RGB is a much, much more common standard these days. Is Basler RGB? That is, do images come out of the Basler pipeline as RGB or BGR?

We should have a policy of converting cv2 images to RGB immediately upon read, and assuming they are RGB anywhere else in their lifetime.
Reply: Yep, the RGB camera comes out in RGB, the IR camera in monochrome. I think immediate conversion sounds like a good idea.
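A sketch of the convert-on-read policy discussed in this thread, assuming frames come from `cv2.VideoCapture` (which yields BGR); everything downstream could then assume RGB. The `read_rgb_frame` helper is hypothetical, not part of this PR.

import cv2
import numpy as np


def read_rgb_frame(capture: cv2.VideoCapture) -> np.ndarray | None:
    # convert BGR -> RGB once, at the read boundary, per the policy above
    success, bgr_frame = capture.read()
    if not success:
        return None
    return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)

Under that policy, the `cv2.cvtColor` call inside `MediapipeDetector.detect` would become redundant and could be dropped, since the detector would already receive RGB frames.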