Skip to content

Commit

Permalink
Switch to ruff, format everything
Browse files Browse the repository at this point in the history
  • Loading branch information
GoldenZephyr committed Jan 21, 2025
1 parent 5401a85 commit d9b756d
Show file tree
Hide file tree
Showing 26 changed files with 58 additions and 77 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci-action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Pre-commit
run: pip install pre-commit && cd ${{ github.workspace }}/ouroboros_repo && pre-commit run --all-files
- name: Install Ouroboros
run: cd ${{ github.workspace }}/ouroboros_repo && pwd && pip install .
- name: Run test script
Expand Down
23 changes: 7 additions & 16 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,11 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.2
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 7.1.1
hooks:
- id: flake8
additional_dependencies: [flake8-docstrings, flake8-bugbear]
exclude: .*setup\.py
# Run the linter.
- id: ruff
# Run the formatter.
- id: ruff-format
2 changes: 1 addition & 1 deletion examples/place_descriptor_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
from ouroboros_salad.salad_model import get_salad_model
from spark_dataset_interfaces.rosbag_dataloader import RosbagDataLoader

import ouroboros as ob
from ouroboros_salad.salad_model import get_salad_model

embedding_model = get_salad_model()

Expand Down
4 changes: 3 additions & 1 deletion examples/salad_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pathlib

import imageio.v3 as iio
from ouroboros_salad.salad_model import get_salad_model

import ouroboros as ob
from ouroboros_salad.salad_model import get_salad_model


def resource_dir():
Expand Down
9 changes: 4 additions & 5 deletions extra/ouroboros_ros/nodes/place_matching_example.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#!/usr/bin/env python3
import cv_bridge
import rospy

from dynamic_reconfigure.server import Server
from ouroboros_ros.cfg import PlaceDescriptorDebuggingConfig
from sensor_msgs.msg import Image
import cv_bridge

from ouroboros import VlcDb, SparkImage
from ouroboros_salad.salad_model import get_salad_model
from ouroboros_ros.cfg import PlaceDescriptorDebuggingConfig
from ouroboros import SparkImage, VlcDb
from ouroboros.utils.plotting_utils import display_image_pair
from ouroboros_salad.salad_model import get_salad_model


class PlaceDescriptorExample:
Expand Down
1 change: 1 addition & 0 deletions extra/ouroboros_ros/setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ! DO NOT MANUALLY INVOKE THIS setup.py, USE CATKIN INSTEAD
from distutils.core import setup

from catkin_pkg.python_setup import generate_distutils_setup

# fetch values from package.xml
Expand Down
2 changes: 1 addition & 1 deletion extra/ouroboros_ros/src/ouroboros_ros/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import rospy
import tf2_ros
import numpy as np

from ouroboros import VlcPose

Expand Down
8 changes: 0 additions & 8 deletions extra/ouroboros_ros/src/ouroboros_ros/vlc_server_ros.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


def update_plot(line, pts, images_to_pose):

if len(images_to_pose) == 0:
return

Expand All @@ -35,7 +34,6 @@ def update_plot(line, pts, images_to_pose):


def plot_lc(lc, image_to_pose):

query_pos = image_to_pose[lc.from_image_uuid].position
match_pos = image_to_pose[lc.to_image_uuid].position
plt.plot(
Expand All @@ -52,7 +50,6 @@ def build_lc_message(
from_T_to,
pose_cov,
):

relative_position = from_T_to[:3, 3]
relative_orientation = R.from_matrix(from_T_to[:3, :3])

Expand All @@ -77,9 +74,7 @@ def build_lc_message(


class VlcServerRos:

def __init__(self):

self.tf_buffer = tf2_ros.Buffer()
self.listener = tf2_ros.TransformListener(self.tf_buffer)

Expand Down Expand Up @@ -133,7 +128,6 @@ def __init__(self):
self.image_sub = rospy.Subscriber("~image_in", Image, self.image_callback)

def image_callback(self, msg):

if not (
self.last_vlc_frame_time is None
or (rospy.Time.now() - self.last_vlc_frame_time).to_sec()
Expand Down Expand Up @@ -173,7 +167,6 @@ def image_callback(self, msg):
pg = PoseGraph()
pg.header.stamp = rospy.Time.now()
for lc in loop_closures:

if self.show_plots:
plot_lc(lc, self.images_to_pose)

Expand Down Expand Up @@ -210,7 +203,6 @@ def run(self):
self.loop_rate.sleep()

def step(self):

if self.show_plots:
update_plot(self.position_line, self.position_points, self.images_to_pose)

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ profile = "black"
testpaths = "tests"
addopts = ["--cov-report=term-missing"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "F403"]

[tool.tox]
requires = ["tox>=4"]
skip_missing_interpreters = true
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import find_packages, setup

setup(
name="ouroboros",
Expand Down
1 change: 0 additions & 1 deletion src/ouroboros-superglue/superglue_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
def get_superglue_model():
return lambda x: None

1 change: 0 additions & 1 deletion src/ouroboros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@
VlcImageMetadata,
VlcPose,
)

from ouroboros.vlc_server.vlc_server import VlcServer
2 changes: 0 additions & 2 deletions src/ouroboros/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@


3 changes: 1 addition & 2 deletions src/ouroboros/utils/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import cv2

cv2.startWindowThread()


Expand Down
2 changes: 1 addition & 1 deletion src/ouroboros/vlc_db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
from ouroboros.vlc_db.vlc_image import *
from ouroboros.vlc_db.vlc_image_table import *
from ouroboros.vlc_db.vlc_lc_table import *
from ouroboros.vlc_db.vlc_session_table import *
from ouroboros.vlc_db.vlc_pose import *
from ouroboros.vlc_db.vlc_session_table import *
1 change: 1 addition & 0 deletions src/ouroboros/vlc_db/spark_loop_closure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass

import numpy as np


Expand Down
16 changes: 8 additions & 8 deletions src/ouroboros/vlc_db/vlc_db.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from datetime import datetime
from typing import Union, Callable, TypeVar, Tuple, List, Optional
from typing import Callable, List, Optional, Tuple, TypeVar, Union

import numpy as np

from ouroboros.vlc_db.spark_image import SparkImage
from ouroboros.vlc_db.spark_loop_closure import SparkLoopClosure
from ouroboros.vlc_db.utils import epoch_ns_from_datetime
from ouroboros.vlc_db.vlc_image import VlcImage

from ouroboros.vlc_db.vlc_image_table import VlcImageTable
from ouroboros.vlc_db.vlc_lc_table import LcTable
from ouroboros.vlc_db.vlc_session_table import SessionTable
from ouroboros.vlc_db.spark_loop_closure import SparkLoopClosure
from ouroboros.vlc_db.utils import epoch_ns_from_datetime
from ouroboros.vlc_db.vlc_pose import VlcPose
from ouroboros.vlc_db.vlc_session_table import SessionTable


class KeypointSizeException(BaseException):
Expand Down Expand Up @@ -137,9 +137,9 @@ def batch_query_embeddings(
Returns the top k closest matches and the match distances
"""

assert (
embeddings.ndim == 2
), f"Batch query requires an NxD array of query embeddings, not {embeddings.shape}"
assert embeddings.ndim == 2, (
f"Batch query requires an NxD array of query embeddings, not {embeddings.shape}"
)

return self._image_table.query_embeddings(embeddings, k, similarity_metric)

Expand Down
7 changes: 4 additions & 3 deletions src/ouroboros/vlc_db/vlc_image_table.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import uuid
from datetime import datetime
from typing import Union
import uuid

import numpy as np

from ouroboros.vlc_db.invertible_vector_store import InvertibleVectorStore
from ouroboros.vlc_db.spark_image import SparkImage
from ouroboros.vlc_db.utils import epoch_ns_from_datetime
from ouroboros.vlc_db.vlc_image import VlcImage, VlcImageMetadata
from ouroboros.vlc_db.vlc_pose import VlcPose
from ouroboros.vlc_db.invertible_vector_store import InvertibleVectorStore
from ouroboros.vlc_db.utils import epoch_ns_from_datetime


class VlcImageTable:
Expand Down
1 change: 1 addition & 0 deletions src/ouroboros/vlc_db/vlc_pose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass

import numpy as np
from scipy.spatial.transform import Rotation as R

Expand Down
10 changes: 4 additions & 6 deletions src/ouroboros/vlc_server/vlc_server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# TODO: move this to ouroboros, not ouroboros_ros
from typing import Tuple, List, Optional
from ouroboros.utils.plotting_utils import display_image_pair
import ouroboros as ob
from datetime import datetime
from typing import List, Optional, Tuple

import ouroboros as ob
from ouroboros.utils.plotting_utils import display_image_pair


class VlcServer:

def __init__(
self,
place_method,
Expand Down Expand Up @@ -71,7 +71,6 @@ def __init__(
def add_and_query_frame(
self, image: ob.SparkImage, time_ns: int, pose_hint: ob.VlcPose = None
) -> Tuple[str, Optional[List[ob.SparkLoopClosure]]]:

embedding = self.place_model.infer(image, pose_hint)

image_matches, similarities = self.vlc_db.query_embeddings_max_time(
Expand Down Expand Up @@ -111,7 +110,6 @@ def add_and_query_frame(

# TODO: support multiple possible place descriptor matches
if image_match is not None:

if not self.strict_keypoint_evaluation:
# Since we just added the current image, we know that no keypoints
# or descriptors have been generated for it
Expand Down
1 change: 0 additions & 1 deletion src/ouroboros_gt/gt_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ def get_gt_descriptor_model():


class GtDescriptorModel:

def __init__(self):
pass

Expand Down
1 change: 0 additions & 1 deletion src/ouroboros_gt/gt_keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ def get_gt_keypoint_model():


class GtKeypointModel:

def __init__(self):
pass

Expand Down
2 changes: 0 additions & 2 deletions src/ouroboros_gt/gt_place_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ def get_gt_place_model():


class GtPlaceModel:

def __init__(self):

self.embedding_size = 8
self.lc_recent_lockout_ns = 20 * 1e9
self.lc_distance_threshold = 5
Expand Down
3 changes: 1 addition & 2 deletions src/ouroboros_gt/gt_pose_recovery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ouroboros.vlc_db.vlc_pose import pose_from_quat_trans, invert_pose, VlcPose
import ouroboros as ob
from ouroboros.vlc_db.vlc_pose import VlcPose, invert_pose, pose_from_quat_trans


def recover_pose(query_descriptors, match_descriptors):
Expand All @@ -22,7 +22,6 @@ def get_gt_pose_model():


class GtPoseModel:

def __init__(self):
pass

Expand Down
3 changes: 1 addition & 2 deletions src/ouroboros_salad/salad_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import numpy as np
import torch
import torchvision.transforms as T

import ouroboros as ob

torch.backends.cudnn.benchmark = True


class SaladModel:

def __init__(self, model):

self.embedding_size = 8448
self.similarity_metric = "ip"
self.model = model
Expand Down
24 changes: 12 additions & 12 deletions tests/test_vlc_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,16 @@ def test_query_embedding():

imgs, sims = vlc_db.query_embeddings(np.array([0, 0.2, 0.8]), 2)
assert len(imgs) == 2, "query_embeddings returned wrong number of matches"
assert set((a_id, d_id)) == set(
[img.metadata.image_uuid for img in imgs]
), "Query did not match k closest"
assert set((a_id, d_id)) == set([img.metadata.image_uuid for img in imgs]), (
"Query did not match k closest"
)

imgs, sims = vlc_db.query_embeddings(np.array([0, 0.2, 0.8]), 3)
assert len(imgs) == 3, "query_embeddings returned wrong number of matches"

assert set((a_id, c_id, d_id)) == set(
[img.metadata.image_uuid for img in imgs]
), "Query did not match k closest"
assert set((a_id, c_id, d_id)) == set([img.metadata.image_uuid for img in imgs]), (
"Query did not match k closest"
)


def test_query_embedding_custom_similarity():
Expand Down Expand Up @@ -162,12 +162,12 @@ def test_batch_query_embeddings():
imgs, sims = vlc_db.batch_query_embeddings(np.array([[0, 1, 0], [0, 0.2, 0.2]]), 1)
assert len(imgs) == 2, "number of results different from number of query vectors"
assert len(imgs[0]) == 1, "returned wrong number of matches for vector"
assert (
imgs[0][0].metadata.image_uuid == b_id
), "Did not match correct image embedding"
assert (
imgs[1][0].metadata.image_uuid == c_id
), "Did not match correct image embedding"
assert imgs[0][0].metadata.image_uuid == b_id, (
"Did not match correct image embedding"
)
assert imgs[1][0].metadata.image_uuid == c_id, (
"Did not match correct image embedding"
)


def test_batch_query_embeddings_uuid_filter():
Expand Down

0 comments on commit d9b756d

Please sign in to comment.