Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/query interface #3

Merged
merged 3 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 13 additions & 41 deletions examples/place_descriptor_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,54 +67,26 @@ def combine_images(left, right):
uid_to_pose[uid] = pose.matrix()


# Query for closest matches
query_embeddings = np.array([image.embedding for image in vlc_db.iterate_images()])

matches, similarities = vlc_db.query_embeddings(
query_embeddings, -1, similarity_metric="ip"
)


# TODO: move this to db query or utility function
# Ignore matches that are too close temporally or too far in descriptor similarity
times = []
closest = []
putative_loop_closures = []
for key, matches_for_query, similarities_for_query in tqdm(
zip(vlc_db.get_image_keys(), matches, similarities), total=len(matches)
):
ts = vlc_db.get_image(key).metadata.epoch_ns
match_uuid = None
for match_image, similarity in zip(matches_for_query, similarities_for_query):
match_ts = match_image.metadata.epoch_ns
if abs(match_ts - ts) > lc_lockout * 1e9:
match_uuid = match_image.metadata.image_uuid
break

for img in vlc_db.iterate_images():
ts = img.metadata.epoch_ns
matches = vlc_db.query_embeddings_filter(
img.embedding, 1, lambda m, s: abs(m.metadata.epoch_ns - ts) > lc_lockout * 1e9
)
times.append(ts / 1e9)
closest.append(similarity)

if similarity < place_recognition_threshold:
match_uuid = None

if match_uuid is None:
putative_loop_closures.append((key, None))
continue

putative_loop_closures.append((key, match_uuid))
closest.append(matches[0][0])
if matches[0][0] > place_recognition_threshold:
right = matches[0][1].image.rgb
else:
right = None
img = combine_images(img.image.rgb, right)
cv2.imshow("matches", img.astype(np.uint8))
cv2.waitKey(30)

t = [ti - times[0] for ti in times]
plt.ion()
plt.plot(t, closest)
plt.xlabel("Time (s)")
plt.ylabel("Best Descriptor Similarity")
plt.title("Closest Descriptor s.t. Time Constraint")
for key, match_key in tqdm(putative_loop_closures):
left = vlc_db.get_image(key).image.rgb
if match_key is None:
right = np.zeros(left.shape)
else:
right = vlc_db.get_image(match_key).image.rgb
img = combine_images(left, right)
cv2.imshow("matches", img.astype(np.uint8))
cv2.waitKey(30)
45 changes: 13 additions & 32 deletions examples/uhumans_example_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,45 +79,27 @@ def plot_heading(positions, rpy):
# Query for closest matches
query_embeddings = np.array([image.embedding for image in vlc_db.iterate_images()])

# query_fn = functools.partial(
# compute_descriptor_similarity, lc_lockout * 1e9, gt_lc_max_dist
# )

# matches, similarities = vlc_db.query_embeddings(
# np.array(query_embeddings), 1, similarity_metric=query_fn
# )


matches, similarities = vlc_db.query_embeddings(
query_embeddings, -1, similarity_metric="ip"
matches = vlc_db.batch_query_embeddings_uuid_filter(
vlc_db.get_image_keys(),
1,
lambda q, v, s: q.metadata.epoch_ns > v.metadata.epoch_ns + lc_lockout * 1e9,
)


# TODO: move this to db query or utility function
# Ignore matches that are too close temporally or too far in descriptor similarity
putative_loop_closures = []
for key, matches_for_query, similarities_for_query in zip(
vlc_db.get_image_keys(), matches, similarities
):
ts = vlc_db.get_image(key).metadata.epoch_ns
match_uuid = None
for match_image, similarity in zip(matches_for_query, similarities_for_query):
match_ts = match_image.metadata.epoch_ns
if abs(match_ts - ts) > lc_lockout * 1e9 and ts > match_ts:
match_uuid = match_image.metadata.image_uuid
break

if similarity < place_recognition_threshold:
break

if match_uuid is None:
for matches_to_frame in matches:
if len(matches_to_frame) == 0:
continue
similarity = matches_to_frame[0][0]
if similarity < place_recognition_threshold:
continue
putative_loop_closures.append((matches_to_frame[0][1], matches_to_frame[0][2]))

putative_loop_closures.append((key, match_uuid))

# Recover poses from matching loop closures
for key_from, key_to in putative_loop_closures:
img_from = vlc_db.get_image(key_from)
for img_from, img_to in putative_loop_closures:
key_from = img_from.metadata.image_uuid
key_to = img_to.metadata.image_uuid
if img_from.keypoints is None or img_from.descriptors is None:
# TODO: replace with real keypoints and descriptors
# keypoints = generate_keypoints(img_from.image)
Expand All @@ -135,7 +117,6 @@ def plot_heading(positions, rpy):
vlc_db.update_keypoints(key_from, keypoints, descriptors)
img_from = vlc_db.get_image(key_from)

img_to = vlc_db.get_image(key_to)
if img_to.keypoints is None or img_to.descriptors is None:
# TODO: replace with real keypoints and descriptors
# keypoints = generate_keypoints(img_to.image)
Expand Down
2 changes: 1 addition & 1 deletion examples/vlc_server_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def insert_image(db, image):
print(vlc_db.get_image(a_id))

print("Querying 0,1,1")
imgs, dists = vlc_db.query_embeddings(np.array([[0, 1, 1]]), 2)
imgs, dists = vlc_db.query_embeddings(np.array([0, 1, 1]), 2)

computed_ts = datetime.now()
loop_closure = ob.SparkLoopClosure(
Expand Down
133 changes: 132 additions & 1 deletion src/ouroboros/vlc_db/vlc_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Union
from typing import Union, Callable, TypeVar, Tuple, List, Optional
import numpy as np

from ouroboros.vlc_db.spark_image import SparkImage
Expand All @@ -9,12 +9,16 @@
from ouroboros.vlc_db.vlc_lc_table import LcTable
from ouroboros.vlc_db.vlc_session_table import SessionTable
from ouroboros.vlc_db.spark_loop_closure import SparkLoopClosure
from ouroboros.vlc_db.utils import epoch_ns_from_datetime


class KeypointSizeException(Exception):
    """Raised when keypoint and descriptor arrays have inconsistent sizes.

    Bug fix: the original class did not inherit from Exception, so it could
    never be raised (``raise KeypointSizeException()`` would fail with a
    TypeError) nor caught in an ``except`` clause.
    """


T = TypeVar("T")


class VlcDb:
def __init__(self, image_embedding_dimension):
self._image_table = VlcImageTable(image_embedding_dimension)
Expand Down Expand Up @@ -43,6 +47,80 @@ def iterate_images(self):
yield image

def query_embeddings(
    self,
    embedding: np.ndarray,
    k: int,
    similarity_metric: Union[str, callable] = "ip",
) -> ([VlcImage], [float]):
    """Query the stored images for the k closest matches to a single embedding.

    embedding: a 1-D descriptor vector of length D (the table's embedding
        dimension). Note: this is a *single* query — for an NxD batch use
        `batch_query_embeddings`.
    k: number of matches to return (forwarded to the batch query).
    similarity_metric: "ip" (inner product) or a custom callable.

    Returns the top k closest matches and the match similarities.
    """

    assert embedding.ndim == 1, "Query embedding must be 1d vector"

    # Delegate to the batch API with a single-row batch and unwrap its result.
    matches, similarities = self.batch_query_embeddings(
        np.array([embedding]), k, similarity_metric
    )
    return matches[0], similarities[0]

def query_embeddings_filter(
    self,
    embedding: np.ndarray,
    k: int,
    filter_function: Callable[[VlcImage, float], bool],
    similarity_metric: Union[str, callable] = "ip",
):
    """Query for the k best matches to `embedding` that pass `filter_function`.

    The filter receives (candidate VlcImage, similarity) and returns True to
    keep the candidate. Returns a list of (similarity, VlcImage) tuples.
    """

    # Adapt the single-query filter to the batch API, which additionally
    # passes per-query metadata as the first argument (unused here).
    def adapted_filter(_, candidate, similarity):
        return filter_function(candidate, similarity)

    batch_results = self.batch_query_embeddings_filter(
        np.array([embedding]),
        k,
        adapted_filter,
        similarity_metric,
    )

    # Unwrap the single query and drop the (unused) metadata slot from each
    # (similarity, metadata, image) triple.
    return [(similarity, image) for similarity, _, image in batch_results[0]]

def query_embeddings_max_time(
    self,
    embedding: np.ndarray,
    k: int,
    max_time: Union[float, int, datetime],
    similarity_metric: Union[str, callable] = "ip",
) -> ([VlcImage], [float]):
    """Query image embeddings to find the k closest vectors with timestamp older than max_time.

    embedding: a 1-D query descriptor.
    max_time: cutoff time; a datetime is converted to epoch nanoseconds,
        otherwise the value is assumed to already be epoch nanoseconds.
    similarity_metric: "ip" (inner product) or a custom callable.

    Returns ([VlcImage], [float]): the matches and their similarities.
    """

    if isinstance(max_time, datetime):
        max_time = epoch_ns_from_datetime(max_time)
    # NOTE: This is a placeholder implementation. Ideally, we re-implement
    # this to be more efficient and not iterate through the full set of
    # vectors

    def time_filter(_, vlc_image, similarity):
        # Keep only images strictly older than the cutoff.
        return vlc_image.metadata.epoch_ns < max_time

    # Bug fix: forward similarity_metric — the original dropped it and
    # always queried with the default metric.
    ret = self.batch_query_embeddings_filter(
        np.array([embedding]), k, time_filter, similarity_metric
    )
    matches = [t[2] for t in ret[0]]
    similarities = [t[0] for t in ret[0]]
    return matches, similarities

def batch_query_embeddings_uuid_filter(
    self,
    uuids: [str],
    k: int,
    filter_function: Callable[[VlcImage, VlcImage, float], bool],
    similarity_metric: Union[str, callable] = "ip",
):
    """Query using the embeddings of the stored images identified by `uuids`.

    filter_function receives (query VlcImage, candidate VlcImage, similarity)
    and returns True to keep the candidate. Returns, per query, a list of
    (similarity, query VlcImage, match VlcImage) tuples.
    """

    # Fetch each image once and reuse it for both the embedding matrix and
    # the per-query filter metadata (the original called get_image twice
    # per uuid).
    images = [self.get_image(u) for u in uuids]
    embeddings = np.array([image.embedding for image in images])
    return self.batch_query_embeddings_filter(
        embeddings, k, filter_function, similarity_metric, filter_metadata=images
    )

def batch_query_embeddings(
self,
embeddings: np.ndarray,
k: int,
Expand All @@ -54,8 +132,61 @@ def query_embeddings(
Returns the top k closest matches and the match distances
"""

assert (
embeddings.ndim == 2
), f"Batch query requires an NxD array of query embeddings, not {embeddings.shape}"

return self._image_table.query_embeddings(embeddings, k, similarity_metric)

def batch_query_embeddings_filter(
    self,
    embeddings: np.ndarray,
    k: int,
    filter_function: Callable[[T, VlcImage, float], bool],
    similarity_metric: Union[str, callable] = "ip",
    filter_metadata: Optional[Union[T, List[T]]] = None,
) -> [[Tuple[float, T, VlcImage]]]:
    """Query image embeddings to find, per query, the k closest vectors that
    satisfy the filter function. Note that this query may be much slower than
    `query_embeddings_with_max_time` because it requires iterating over all
    stored images.

    embeddings: NxD array of N query descriptors.
    k: matches to keep per query; a negative k keeps every passing match.
    filter_function: receives (per-query metadata, candidate VlcImage,
        similarity) and returns True to keep the candidate.
    filter_metadata: per-query metadata — a list of length N, a single
        (callable) value broadcast to every query, or None.

    Returns one list per query of (similarity, metadata, VlcImage) tuples.
    """

    if isinstance(filter_metadata, list):
        assert len(filter_metadata) == len(embeddings)
    elif filter_metadata is None:
        filter_metadata = [None] * len(embeddings)
    elif callable(filter_metadata):
        filter_metadata = [filter_metadata] * len(embeddings)
    else:
        # Bug fix: the message blamed the filter function; the value being
        # validated is filter_metadata.
        raise Exception("filter_metadata must be a list, callable, or None")

    # Retrieve every match (-1 = all) so that filtering can discard
    # arbitrarily many candidates and still find k that pass.
    # Bug fix: forward similarity_metric — it was hard-coded to "ip".
    matches, similarities = self.batch_query_embeddings(
        embeddings, -1, similarity_metric=similarity_metric
    )

    filtered_matches_out = []
    for metadata, matches_for_query, similarities_for_query in zip(
        filter_metadata, matches, similarities
    ):
        filtered_matches_for_query = []
        for match_image, similarity in zip(
            matches_for_query, similarities_for_query
        ):
            if not filter_function(metadata, match_image, similarity):
                continue
            filtered_matches_for_query.append(
                (similarity, metadata, match_image)
            )
            # Bug fix: a negative k means "no limit" (matching the -1
            # convention used above); previously k == -1 stopped after the
            # first passing match because n_matches >= -1 was always true.
            if k >= 0 and len(filtered_matches_for_query) >= k:
                break

        filtered_matches_out.append(filtered_matches_for_query)

    return filtered_matches_out

def update_embedding(self, image_uuid: str, embedding):
    """Replace the stored embedding for the image identified by `image_uuid`.

    Delegates directly to the underlying image table.
    """
    self._image_table.update_embedding(image_uuid, embedding)

Expand Down
29 changes: 0 additions & 29 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,3 @@
# BSD 3-Clause License
#
# Copyright (c) 2021-2024, Massachusetts Institute of Technology.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
"""Fixtures for unit tests."""

import pathlib
Expand Down
Loading
Loading