diff --git a/.github/workflows/ci-action.yml b/.github/workflows/ci-action.yml index c96a881..d8dd0ef 100644 --- a/.github/workflows/ci-action.yml +++ b/.github/workflows/ci-action.yml @@ -22,6 +22,9 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + + - name: Pre-commit + run: pip install pre-commit && cd ${{ github.workspace }}/ouroboros_repo && pre-commit run --all-files - name: Install Ouroboros run: cd ${{ github.workspace }}/ouroboros_repo && pwd && pip install . - name: Run test script diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86658ba..cf1908e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,20 +2,11 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.9.2 hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - - id: check-yaml - - id: check-added-large-files - - repo: https://github.com/psf/black - rev: 22.10.0 - hooks: - - id: black - - repo: https://github.com/pycqa/flake8 - rev: 7.1.1 - hooks: - - id: flake8 - additional_dependencies: [flake8-docstrings, flake8-bugbear] - exclude: .*setup\.py + # Run the linter. + - id: ruff + # Run the formatter. + - id: ruff-format diff --git a/examples/place_descriptor_debugging.py b/examples/place_descriptor_debugging.py index 2a0be9b..51788cb 100644 --- a/examples/place_descriptor_debugging.py +++ b/examples/place_descriptor_debugging.py @@ -3,10 +3,10 @@ import cv2 import matplotlib.pyplot as plt import numpy as np -from ouroboros_salad.salad_model import get_salad_model from spark_dataset_interfaces.rosbag_dataloader import RosbagDataLoader import ouroboros as ob +from ouroboros_salad.salad_model import get_salad_model embedding_model = get_salad_model() diff --git a/examples/salad_example.py b/examples/salad_example.py index aa8b6bf..bcf2d5a 100644 --- a/examples/salad_example.py +++ b/examples/salad_example.py @@ -1,7 +1,9 @@ import pathlib + import imageio.v3 as iio -from ouroboros_salad.salad_model import get_salad_model + import ouroboros as ob +from ouroboros_salad.salad_model import get_salad_model def resource_dir(): diff --git a/extra/ouroboros_ros/nodes/place_matching_example.py b/extra/ouroboros_ros/nodes/place_matching_example.py index 70acd92..6050cc1 100755 --- a/extra/ouroboros_ros/nodes/place_matching_example.py +++ b/extra/ouroboros_ros/nodes/place_matching_example.py @@ -1,14 +1,13 @@ #!/usr/bin/env python3 +import cv_bridge import rospy - from dynamic_reconfigure.server import Server +from ouroboros_ros.cfg import PlaceDescriptorDebuggingConfig from sensor_msgs.msg import Image -import cv_bridge -from ouroboros import VlcDb, SparkImage -from ouroboros_salad.salad_model import get_salad_model -from ouroboros_ros.cfg import PlaceDescriptorDebuggingConfig +from ouroboros import SparkImage, VlcDb from ouroboros.utils.plotting_utils import display_image_pair +from ouroboros_salad.salad_model import get_salad_model class PlaceDescriptorExample: diff --git a/extra/ouroboros_ros/setup.py b/extra/ouroboros_ros/setup.py index 3a1e21a..c12f894 100644 --- a/extra/ouroboros_ros/setup.py +++ b/extra/ouroboros_ros/setup.py @@ -1,5 +1,6 @@ # ! DO NOT MANUALLY INVOKE THIS setup.py, USE CATKIN INSTEAD from distutils.core import setup + from catkin_pkg.python_setup import generate_distutils_setup # fetch values from package.xml diff --git a/extra/ouroboros_ros/src/ouroboros_ros/utils.py b/extra/ouroboros_ros/src/ouroboros_ros/utils.py index a8967d2..de64d40 100644 --- a/extra/ouroboros_ros/src/ouroboros_ros/utils.py +++ b/extra/ouroboros_ros/src/ouroboros_ros/utils.py @@ -1,6 +1,6 @@ +import numpy as np import rospy import tf2_ros -import numpy as np from ouroboros import VlcPose diff --git a/extra/ouroboros_ros/src/ouroboros_ros/vlc_server_ros.py b/extra/ouroboros_ros/src/ouroboros_ros/vlc_server_ros.py index 9c66c5f..0b93adf 100644 --- a/extra/ouroboros_ros/src/ouroboros_ros/vlc_server_ros.py +++ b/extra/ouroboros_ros/src/ouroboros_ros/vlc_server_ros.py @@ -15,7 +15,6 @@ def update_plot(line, pts, images_to_pose): - if len(images_to_pose) == 0: return @@ -35,7 +34,6 @@ def update_plot(line, pts, images_to_pose): def plot_lc(lc, image_to_pose): - query_pos = image_to_pose[lc.from_image_uuid].position match_pos = image_to_pose[lc.to_image_uuid].position plt.plot( @@ -52,7 +50,6 @@ def build_lc_message( from_T_to, pose_cov, ): - relative_position = from_T_to[:3, 3] relative_orientation = R.from_matrix(from_T_to[:3, :3]) @@ -77,9 +74,7 @@ def build_lc_message( class VlcServerRos: - def __init__(self): - self.tf_buffer = tf2_ros.Buffer() self.listener = tf2_ros.TransformListener(self.tf_buffer) @@ -133,7 +128,6 @@ def __init__(self): self.image_sub = rospy.Subscriber("~image_in", Image, self.image_callback) def image_callback(self, msg): - if not ( self.last_vlc_frame_time is None or (rospy.Time.now() - self.last_vlc_frame_time).to_sec() @@ -173,7 +167,6 @@ def image_callback(self, msg): pg = PoseGraph() pg.header.stamp = rospy.Time.now() for lc in loop_closures: - if self.show_plots: plot_lc(lc, self.images_to_pose) @@ -210,7 +203,6 @@ def run(self): self.loop_rate.sleep() def step(self): - if self.show_plots: update_plot(self.position_line, self.position_points, self.images_to_pose) diff --git a/pyproject.toml b/pyproject.toml index e14cb4c..d1e1a9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,9 @@ profile = "black" testpaths = "tests" addopts = ["--cov-report=term-missing"] +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401", "F403"] + [tool.tox] requires = ["tox>=4"] skip_missing_interpreters = true diff --git a/setup.py b/setup.py index 105c532..f50bd3d 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name="ouroboros", diff --git a/src/ouroboros-superglue/superglue_model.py b/src/ouroboros-superglue/superglue_model.py index 7af42d9..2ba7f66 100644 --- a/src/ouroboros-superglue/superglue_model.py +++ b/src/ouroboros-superglue/superglue_model.py @@ -1,3 +1,2 @@ def get_superglue_model(): return lambda x: None - diff --git a/src/ouroboros/__init__.py b/src/ouroboros/__init__.py index 4e116ab..6ce7bf5 100644 --- a/src/ouroboros/__init__.py +++ b/src/ouroboros/__init__.py @@ -9,5 +9,4 @@ VlcImageMetadata, VlcPose, ) - from ouroboros.vlc_server.vlc_server import VlcServer diff --git a/src/ouroboros/utils/__init__.py b/src/ouroboros/utils/__init__.py index 139597f..e69de29 100644 --- a/src/ouroboros/utils/__init__.py +++ b/src/ouroboros/utils/__init__.py @@ -1,2 +0,0 @@ - - diff --git a/src/ouroboros/utils/plotting_utils.py b/src/ouroboros/utils/plotting_utils.py index 5e2ffe0..8115050 100644 --- a/src/ouroboros/utils/plotting_utils.py +++ b/src/ouroboros/utils/plotting_utils.py @@ -1,9 +1,8 @@ +import cv2 import matplotlib import matplotlib.pyplot as plt import numpy as np -import cv2 - cv2.startWindowThread() diff --git a/src/ouroboros/vlc_db/__init__.py b/src/ouroboros/vlc_db/__init__.py index 6b2dd05..72d3eb9 100644 --- a/src/ouroboros/vlc_db/__init__.py +++ b/src/ouroboros/vlc_db/__init__.py @@ -5,5 +5,5 @@ from ouroboros.vlc_db.vlc_image import * from ouroboros.vlc_db.vlc_image_table import * from ouroboros.vlc_db.vlc_lc_table import * -from ouroboros.vlc_db.vlc_session_table import * from ouroboros.vlc_db.vlc_pose import * +from ouroboros.vlc_db.vlc_session_table import * diff --git a/src/ouroboros/vlc_db/spark_loop_closure.py b/src/ouroboros/vlc_db/spark_loop_closure.py index 71e06bb..18758af 100644 --- a/src/ouroboros/vlc_db/spark_loop_closure.py +++ b/src/ouroboros/vlc_db/spark_loop_closure.py @@ -1,4 +1,5 @@ from dataclasses import dataclass + import numpy as np diff --git a/src/ouroboros/vlc_db/vlc_db.py b/src/ouroboros/vlc_db/vlc_db.py index 66a7d1d..803a4c3 100644 --- a/src/ouroboros/vlc_db/vlc_db.py +++ b/src/ouroboros/vlc_db/vlc_db.py @@ -1,16 +1,16 @@ from datetime import datetime -from typing import Union, Callable, TypeVar, Tuple, List, Optional +from typing import Callable, List, Optional, Tuple, TypeVar, Union + import numpy as np from ouroboros.vlc_db.spark_image import SparkImage +from ouroboros.vlc_db.spark_loop_closure import SparkLoopClosure +from ouroboros.vlc_db.utils import epoch_ns_from_datetime from ouroboros.vlc_db.vlc_image import VlcImage - from ouroboros.vlc_db.vlc_image_table import VlcImageTable from ouroboros.vlc_db.vlc_lc_table import LcTable -from ouroboros.vlc_db.vlc_session_table import SessionTable -from ouroboros.vlc_db.spark_loop_closure import SparkLoopClosure -from ouroboros.vlc_db.utils import epoch_ns_from_datetime from ouroboros.vlc_db.vlc_pose import VlcPose +from ouroboros.vlc_db.vlc_session_table import SessionTable class KeypointSizeException(BaseException): @@ -137,9 +137,9 @@ def batch_query_embeddings( Returns the top k closest matches and the match distances """ - assert ( - embeddings.ndim == 2 - ), f"Batch query requires an NxD array of query embeddings, not {embeddings.shape}" + assert embeddings.ndim == 2, ( + f"Batch query requires an NxD array of query embeddings, not {embeddings.shape}" + ) return self._image_table.query_embeddings(embeddings, k, similarity_metric) diff --git a/src/ouroboros/vlc_db/vlc_image_table.py b/src/ouroboros/vlc_db/vlc_image_table.py index 86181de..663b11f 100644 --- a/src/ouroboros/vlc_db/vlc_image_table.py +++ b/src/ouroboros/vlc_db/vlc_image_table.py @@ -1,13 +1,14 @@ +import uuid from datetime import datetime from typing import Union -import uuid + import numpy as np +from ouroboros.vlc_db.invertible_vector_store import InvertibleVectorStore from ouroboros.vlc_db.spark_image import SparkImage +from ouroboros.vlc_db.utils import epoch_ns_from_datetime from ouroboros.vlc_db.vlc_image import VlcImage, VlcImageMetadata from ouroboros.vlc_db.vlc_pose import VlcPose -from ouroboros.vlc_db.invertible_vector_store import InvertibleVectorStore -from ouroboros.vlc_db.utils import epoch_ns_from_datetime class VlcImageTable: diff --git a/src/ouroboros/vlc_db/vlc_pose.py b/src/ouroboros/vlc_db/vlc_pose.py index ccfb3d3..db51999 100644 --- a/src/ouroboros/vlc_db/vlc_pose.py +++ b/src/ouroboros/vlc_db/vlc_pose.py @@ -1,4 +1,5 @@ from dataclasses import dataclass + import numpy as np from scipy.spatial.transform import Rotation as R diff --git a/src/ouroboros/vlc_server/vlc_server.py b/src/ouroboros/vlc_server/vlc_server.py index 6f84476..00e1f40 100644 --- a/src/ouroboros/vlc_server/vlc_server.py +++ b/src/ouroboros/vlc_server/vlc_server.py @@ -1,12 +1,12 @@ # TODO: move this to ouroboros, not ouroboros_ros -from typing import Tuple, List, Optional -from ouroboros.utils.plotting_utils import display_image_pair -import ouroboros as ob from datetime import datetime +from typing import List, Optional, Tuple + +import ouroboros as ob +from ouroboros.utils.plotting_utils import display_image_pair class VlcServer: - def __init__( self, place_method, @@ -71,7 +71,6 @@ def __init__( def add_and_query_frame( self, image: ob.SparkImage, time_ns: int, pose_hint: ob.VlcPose = None ) -> Tuple[str, Optional[List[ob.SparkLoopClosure]]]: - embedding = self.place_model.infer(image, pose_hint) image_matches, similarities = self.vlc_db.query_embeddings_max_time( @@ -111,7 +110,6 @@ def add_and_query_frame( # TODO: support multiple possible place descriptor matches if image_match is not None: - if not self.strict_keypoint_evaluation: # Since we just added the current image, we know that no keypoints # or descriptors have been generated for it diff --git a/src/ouroboros_gt/gt_descriptors.py b/src/ouroboros_gt/gt_descriptors.py index 2676d63..560b95c 100644 --- a/src/ouroboros_gt/gt_descriptors.py +++ b/src/ouroboros_gt/gt_descriptors.py @@ -8,7 +8,6 @@ def get_gt_descriptor_model(): class GtDescriptorModel: - def __init__(self): pass diff --git a/src/ouroboros_gt/gt_keypoint_detection.py b/src/ouroboros_gt/gt_keypoint_detection.py index d70ca70..82f5404 100644 --- a/src/ouroboros_gt/gt_keypoint_detection.py +++ b/src/ouroboros_gt/gt_keypoint_detection.py @@ -8,7 +8,6 @@ def get_gt_keypoint_model(): class GtKeypointModel: - def __init__(self): pass diff --git a/src/ouroboros_gt/gt_place_recognition.py b/src/ouroboros_gt/gt_place_recognition.py index 038c80b..bc5892c 100644 --- a/src/ouroboros_gt/gt_place_recognition.py +++ b/src/ouroboros_gt/gt_place_recognition.py @@ -8,9 +8,7 @@ def get_gt_place_model(): class GtPlaceModel: - def __init__(self): - self.embedding_size = 8 self.lc_recent_lockout_ns = 20 * 1e9 self.lc_distance_threshold = 5 diff --git a/src/ouroboros_gt/gt_pose_recovery.py b/src/ouroboros_gt/gt_pose_recovery.py index 8bfd034..e4f0291 100644 --- a/src/ouroboros_gt/gt_pose_recovery.py +++ b/src/ouroboros_gt/gt_pose_recovery.py @@ -1,5 +1,5 @@ -from ouroboros.vlc_db.vlc_pose import pose_from_quat_trans, invert_pose, VlcPose import ouroboros as ob +from ouroboros.vlc_db.vlc_pose import VlcPose, invert_pose, pose_from_quat_trans def recover_pose(query_descriptors, match_descriptors): @@ -22,7 +22,6 @@ def get_gt_pose_model(): class GtPoseModel: - def __init__(self): pass diff --git a/src/ouroboros_salad/salad_model.py b/src/ouroboros_salad/salad_model.py index 48e5bbe..8a9f14f 100644 --- a/src/ouroboros_salad/salad_model.py +++ b/src/ouroboros_salad/salad_model.py @@ -1,15 +1,14 @@ import numpy as np import torch import torchvision.transforms as T + import ouroboros as ob torch.backends.cudnn.benchmark = True class SaladModel: - def __init__(self, model): - self.embedding_size = 8448 self.similarity_metric = "ip" self.model = model diff --git a/tests/test_vlc_db.py b/tests/test_vlc_db.py index 4ec9317..1fc5c15 100644 --- a/tests/test_vlc_db.py +++ b/tests/test_vlc_db.py @@ -81,16 +81,16 @@ def test_query_embedding(): imgs, sims = vlc_db.query_embeddings(np.array([0, 0.2, 0.8]), 2) assert len(imgs) == 2, "query_embeddings returned wrong number of matches" - assert set((a_id, d_id)) == set( - [img.metadata.image_uuid for img in imgs] - ), "Query did not match k closest" + assert set((a_id, d_id)) == set([img.metadata.image_uuid for img in imgs]), ( + "Query did not match k closest" + ) imgs, sims = vlc_db.query_embeddings(np.array([0, 0.2, 0.8]), 3) assert len(imgs) == 3, "query_embeddings returned wrong number of matches" - assert set((a_id, c_id, d_id)) == set( - [img.metadata.image_uuid for img in imgs] - ), "Query did not match k closest" + assert set((a_id, c_id, d_id)) == set([img.metadata.image_uuid for img in imgs]), ( + "Query did not match k closest" + ) def test_query_embedding_custom_similarity(): @@ -162,12 +162,12 @@ def test_batch_query_embeddings(): imgs, sims = vlc_db.batch_query_embeddings(np.array([[0, 1, 0], [0, 0.2, 0.2]]), 1) assert len(imgs) == 2, "number of results different from number of query vectors" assert len(imgs[0]) == 1, "returned wrong number of matches for vector" - assert ( - imgs[0][0].metadata.image_uuid == b_id - ), "Did not match correct image embedding" - assert ( - imgs[1][0].metadata.image_uuid == c_id - ), "Did not match correct image embedding" + assert imgs[0][0].metadata.image_uuid == b_id, ( + "Did not match correct image embedding" + ) + assert imgs[1][0].metadata.image_uuid == c_id, ( + "Did not match correct image embedding" + ) def test_batch_query_embeddings_uuid_filter():