Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/lancedb as vectordb added #236

Merged
merged 10 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions adalflow/adalflow/components/retriever/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@
OptionalPackages.QDRANT,
)

LanceDBRetriever = LazyImport(
"adalflow.components.retriever.lancedb_retriver.LanceDBRetriever",
OptionalPackages.LANCEDB,
)

__all__ = [
"BM25Retriever",
"LLMRetriever",
"FAISSRetriever",
"RerankerRetriever",
"PostgresRetriever",
"QdrantRetriever",
"LanceDBRetriever",
"split_text_by_word_fn",
"split_text_by_word_fn_then_lower_tokenized",
]
Expand Down
138 changes: 138 additions & 0 deletions adalflow/adalflow/components/retriever/lancedb_retriver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import logging
import numpy as np
import pyarrow as pa
import lancedb
from typing import List, Optional, Sequence, Union, Dict, Any
from adalflow.core.embedder import Embedder
from adalflow.core.retriever import Retriever
from adalflow.core.types import RetrieverOutput

# Initialize logging
log = logging.getLogger(__name__)

# Defined data types
LanceDBRetrieverDocumentEmbeddingType = Union[
List[float], np.ndarray
] # single embedding
LanceDBRetrieverDocumentsType = Sequence[LanceDBRetrieverDocumentEmbeddingType]


# Step 2: Define the LanceDBRetriever class
class LanceDBRetriever(
Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]]
):
def __init__(
self,
embedder: Embedder,
dimensions: int,
db_uri: str = "/tmp/lancedb",
top_k: int = 5,
overwrite: bool = True,
):
"""
LanceDBRetriever is a retriever that leverages LanceDB to efficiently store and query document embeddings.

Attributes:
embedder (Embedder): An instance of the Embedder class used for computing embeddings.
dimensions (int): The dimensionality of the embeddings used.
db_uri (str): The URI of the LanceDB storage (default is "/tmp/lancedb").
top_k (int): The number of top results to retrieve for a given query (default is 5).
overwrite (bool): If True, the existing table is overwritten; otherwise, new documents are appended.

This retriever supports adding documents with their embeddings to a LanceDB storage and retrieving relevant documents based on a given query.

More information on LanceDB can be found here:(https://github.com/lancedb/lancedb)
Documentations: https://lancedb.github.io/lancedb/
"""
super().__init__()
self.db = lancedb.connect(db_uri)
self.embedder = embedder
self.top_k = top_k
self.dimensions = dimensions

# Define table schema with vector field for embeddings
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)),
pa.field("content", pa.string()),
]
)

# Create or overwrite the table for storing documents and embeddings
self.table = self.db.create_table(
"documents", schema=schema, mode="overwrite" if overwrite else "append"
)

def add_documents(self, documents: Sequence[Dict[str, Any]]):
"""
Adds documents with pre-computed embeddings to LanceDB.
Args:
documents (Sequence[Dict[str, Any]]): A sequence of documents, each with a 'content' field containing text.

"""
if not documents:
log.warning("No documents provided for embedding")
return

# Embed document content using Embedder
doc_texts = [doc["content"] for doc in documents]
embeddings = self.embedder(input=doc_texts).data

# Format embeddings for LanceDB
data = [
{"vector": embedding.embedding, "content": text}
for embedding, text in zip(embeddings, doc_texts)
]

# Add data to LanceDB table
self.table.add(data)
log.info(f"Added {len(documents)} documents to the index")

def retrieve(
self, query: Union[str, List[str]], top_k: Optional[int] = None
) -> List[RetrieverOutput]:
""".
Retrieve top-k documents from LanceDB for a given query or queries.
Args:
query (Union[str, List[str]]): A query string or a list of query strings.
top_k (Optional[int]): The number of top documents to retrieve (if not specified, defaults to the instance's top_k).

Returns:
List[RetrieverOutput]: A list of RetrieverOutput containing the indices and scores of the retrieved documents.
"""
if isinstance(query, str):
query = [query]

if not query or (isinstance(query, str) and query.strip() == ""):
raise ValueError("Query cannot be empty.")

# Check if table (index) exists before performing search
if not self.table:
raise ValueError(
"The index has not been initialized or the table is missing."
)

query_embeddings = self.embedder(input=query).data
output: List[RetrieverOutput] = []

# Perform search in LanceDB for each query
for query_emb in query_embeddings:
results = (
self.table.search(query_emb.embedding)
.limit(top_k or self.top_k)
.to_pandas()
)

# Gather indices and scores from search results
indices = results.index.tolist()
scores = results["_distance"].tolist()

# Append results to output
output.append(
RetrieverOutput(
doc_indices=indices,
doc_scores=scores,
query=query[0] if len(query) == 1 else query,
)
)
return output
5 changes: 5 additions & 0 deletions adalflow/adalflow/utils/lazy_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ class OptionalPackages(Enum):
"Please install faiss with: pip install faiss-cpu (or faiss if you use GPU)",
)

LANCEDB = (
"lancedb",
"Please install lancedb with: pip install lancedb .",
)

# db library
SQLALCHEMY = (
"sqlalchemy",
Expand Down
Loading
Loading