Skip to content

Commit

Permalink
ruff formatted
Browse files Browse the repository at this point in the history
  • Loading branch information
toyat522 committed Jan 23, 2025
1 parent cc53ac3 commit 5e04949
Show file tree
Hide file tree
Showing 10 changed files with 888 additions and 742 deletions.
5 changes: 3 additions & 2 deletions src/ouroboros-keypoint-match/common/keypoint_matcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .matcher_base import MatcherBase


class KeypointMatcher:
def set_matcher(self, matcher: MatcherBase):
self.matcher = matcher
Expand All @@ -12,7 +13,7 @@ def load_image(self, img_path, resize=None):
resize (tuple of int, optional): Desired size to resize the images to, specified as (height, width).
"""
return self.matcher.load_image(img_path, resize)

def execute(self, img0, img1):
"""
Performs image matching using the selected strategy.
Expand All @@ -21,7 +22,7 @@ def execute(self, img0, img1):
img1 (Tensor): Second image to be matched.
max_num_keypoints (int): Maximum number of keypoints to be found.
device (str, optional): The device to run the matching on, either "cuda" or "cpu".
Returns:
kpts0 (Tensor): Keypoints detected in the first image after feature extraction.
kpts1 (Tensor): Keypoints detected in the second image after feature extraction.
Expand Down
1 change: 1 addition & 0 deletions src/ouroboros-keypoint-match/common/matcher_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod


class MatcherBase(ABC):
@abstractmethod
def __init__(self, max_num_keypoints=2048, device=None):
Expand Down
16 changes: 12 additions & 4 deletions src/ouroboros-keypoint-match/lightglue/lightglue_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
import matplotlib.pyplot as plt
import torch


class LightGlueMatcher(MatcherBase):
def __init__(self, max_num_keypoints=2048, device=None):
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device

self.extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.extractor = (
SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(self.device)
)
self.matcher = LightGlue(features="superpoint").eval().to(self.device)

def load_image(self, img_path, resize):
Expand All @@ -30,10 +33,16 @@ def execute(self, img0, img1):

# Match the features
matches01 = self.matcher({"image0": feats0, "image1": feats1})
feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension
feats0, feats1, matches01 = [
rbd(x) for x in [feats0, feats1, matches01]
] # remove batch dimension

# Keep the matching keypoints
kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
kpts0, kpts1, matches = (
feats0["keypoints"],
feats1["keypoints"],
matches01["matches"],
)
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

return kpts0, kpts1, m_kpts0, m_kpts1
Expand All @@ -49,4 +58,3 @@ def get_matcher_model(self):

def get_descriptor_model(self):
return self.extractor

4 changes: 1 addition & 3 deletions src/ouroboros-keypoint-match/lightglue/superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
keypoints = keypoints - s / 2 + 0.5
keypoints /= torch.tensor(
[(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
).to(
keypoints
)[None]
).to(keypoints)[None]
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
descriptors = torch.nn.functional.grid_sample(
Expand Down
956 changes: 478 additions & 478 deletions src/ouroboros-keypoint-match/notebooks/LightSuperGlueComparison.ipynb

Large diffs are not rendered by default.

21 changes: 11 additions & 10 deletions src/ouroboros-keypoint-match/superglue/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,29 @@


class Matching(torch.nn.Module):
""" Image Matching Frontend (SuperPoint + SuperGlue) """
"""Image Matching Frontend (SuperPoint + SuperGlue)"""

def __init__(self, config={}):
super().__init__()
self.superpoint = SuperPoint(config.get('superpoint', {}))
self.superglue = SuperGlue(config.get('superglue', {}))
self.superpoint = SuperPoint(config.get("superpoint", {}))
self.superglue = SuperGlue(config.get("superglue", {}))

@torch.no_grad()
def forward(self, data):
""" Run SuperPoint (optionally) and SuperGlue
"""Run SuperPoint (optionally) and SuperGlue
SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input
Args:
data: dictionary with minimal keys: ['image0', 'image1']
"""
pred = {}

# Extract SuperPoint (keypoints, scores, descriptors) if not provided
if 'keypoints0' not in data:
pred0 = self.superpoint({'image': data['image0']})
pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
if 'keypoints1' not in data:
pred1 = self.superpoint({'image': data['image1']})
pred = {**pred, **{k+'1': v for k, v in pred1.items()}}
if "keypoints0" not in data:
pred0 = self.superpoint({"image": data["image0"]})
pred = {**pred, **{k + "0": v for k, v in pred0.items()}}
if "keypoints1" not in data:
pred1 = self.superpoint({"image": data["image1"]})
pred = {**pred, **{k + "1": v for k, v in pred1.items()}}

# Batch all features
# We should either have i) one image per batch, or
Expand Down
Loading

0 comments on commit 5e04949

Please sign in to comment.