Skip to content

Commit

Permalink
Merge pull request #236 from akashAD98/feature/lancedb
Browse files Browse the repository at this point in the history
Feature/lancedb as vectordb added
  • Loading branch information
fm1320 authored Dec 11, 2024
2 parents 51ac755 + 177a96c commit 7137548
Show file tree
Hide file tree
Showing 9 changed files with 4,422 additions and 3,456 deletions.
6 changes: 6 additions & 0 deletions adalflow/adalflow/components/retriever/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@
OptionalPackages.QDRANT,
)

# Register LanceDBRetriever through LazyImport and tag it with
# OptionalPackages.LANCEDB, the optional-package entry for `lancedb`.
# NOTE(review): presumably this defers importing lancedb until the retriever
# is first used — confirm against LazyImport's implementation.
# The module path is spelled "lancedb_retriver" to match the file name on disk.
LanceDBRetriever = LazyImport(
"adalflow.components.retriever.lancedb_retriver.LanceDBRetriever",
OptionalPackages.LANCEDB,
)

# Public names exported by this package, including the LanceDB retriever
# registered above.
__all__ = [
"BM25Retriever",
"LLMRetriever",
"FAISSRetriever",
"RerankerRetriever",
"PostgresRetriever",
"QdrantRetriever",
"LanceDBRetriever",
"split_text_by_word_fn",
"split_text_by_word_fn_then_lower_tokenized",
]
Expand Down
138 changes: 138 additions & 0 deletions adalflow/adalflow/components/retriever/lancedb_retriver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import logging
import numpy as np
import pyarrow as pa
import lancedb
from typing import List, Optional, Sequence, Union, Dict, Any
from adalflow.core.embedder import Embedder
from adalflow.core.retriever import Retriever
from adalflow.core.types import RetrieverOutput

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# A single document embedding: a plain list of floats or a NumPy array.
LanceDBRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray]
# A sequence of such embeddings, one per document.
LanceDBRetrieverDocumentsType = Sequence[LanceDBRetrieverDocumentEmbeddingType]


# Retriever implementation backed by a LanceDB table
class LanceDBRetriever(
    Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]]
):
    """Retriever that stores document embeddings in a LanceDB table and
    answers queries with a vector search over that table.

    More information on LanceDB: https://github.com/lancedb/lancedb
    Documentation: https://lancedb.github.io/lancedb/
    """

    def __init__(
        self,
        embedder: Embedder,
        dimensions: int,
        db_uri: str = "/tmp/lancedb",
        top_k: int = 5,
        overwrite: bool = True,
    ):
        """
        Connect to LanceDB and create (or reuse) the "documents" table.

        Args:
            embedder (Embedder): Used to embed both document contents and queries.
            dimensions (int): Fixed length of every embedding vector stored in the table.
            db_uri (str): The URI of the LanceDB storage (default is "/tmp/lancedb").
            top_k (int): Default number of results to retrieve per query (default is 5).
            overwrite (bool): If True, any existing "documents" table is replaced by an
                empty one; otherwise an existing table is reused so new documents are
                appended to it.
        """
        super().__init__()
        self.db = lancedb.connect(db_uri)
        self.embedder = embedder
        self.top_k = top_k
        self.dimensions = dimensions

        # Table schema: a fixed-size float32 vector plus the raw document text.
        schema = pa.schema(
            [
                pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)),
                pa.field("content", pa.string()),
            ]
        )

        # `create_table` only accepts mode "create" or "overwrite"; passing
        # "append" is rejected. To keep existing rows we stay in the default
        # "create" mode and ask LanceDB to open the table if it already exists.
        if overwrite:
            self.table = self.db.create_table(
                "documents", schema=schema, mode="overwrite"
            )
        else:
            self.table = self.db.create_table(
                "documents", schema=schema, exist_ok=True
            )

    def add_documents(self, documents: Sequence[Dict[str, Any]]):
        """
        Embed the documents' text and add the results to the LanceDB table.

        Args:
            documents (Sequence[Dict[str, Any]]): A sequence of documents, each with a
                'content' field containing the text to embed and store.

        Raises:
            KeyError: If a document has no 'content' key.
        """
        if not documents:
            log.warning("No documents provided for embedding")
            return

        # Embed document content using the configured Embedder
        doc_texts = [doc["content"] for doc in documents]
        embeddings = self.embedder(input=doc_texts).data

        # One row per document: its embedding vector and its original text
        data = [
            {"vector": embedding.embedding, "content": text}
            for embedding, text in zip(embeddings, doc_texts)
        ]

        self.table.add(data)
        log.info("Added %d documents to the index", len(documents))

    def retrieve(
        self, query: Union[str, List[str]], top_k: Optional[int] = None
    ) -> List[RetrieverOutput]:
        """
        Retrieve top-k documents from LanceDB for a given query or queries.

        Args:
            query (Union[str, List[str]]): A query string or a list of query strings.
            top_k (Optional[int]): The number of top documents to retrieve (if not
                specified, defaults to the instance's top_k).

        Returns:
            List[RetrieverOutput]: One RetrieverOutput per query, each carrying that
                query, the result indices and the `_distance` values reported by LanceDB.

        Raises:
            ValueError: If the query is empty or blank, or if the table is unavailable.
        """
        if not query:
            raise ValueError("Query cannot be empty.")

        queries = [query] if isinstance(query, str) else list(query)

        # Reject blank strings per query; checking only after wrapping a string
        # in a list would let "" or "   " slip through to the embedder.
        if any(isinstance(q, str) and q.strip() == "" for q in queries):
            raise ValueError("Query cannot be empty.")

        # Check if table (index) exists before performing search
        # NOTE(review): this relies on the table object's truthiness, which may
        # depend on its row count — confirm against the LanceDB Table API.
        if not self.table:
            raise ValueError(
                "The index has not been initialized or the table is missing."
            )

        limit = top_k or self.top_k
        query_embeddings = self.embedder(input=queries).data
        output: List[RetrieverOutput] = []

        # Run one vector search per query and pair each result with its own query
        for single_query, query_emb in zip(queries, query_embeddings):
            results = (
                self.table.search(query_emb.embedding).limit(limit).to_pandas()
            )

            # NOTE(review): these indices are the result DataFrame's index, which
            # looks like the row position within the result set rather than a
            # stable document id — confirm before relying on them.
            indices = results.index.tolist()
            scores = results["_distance"].tolist()

            output.append(
                RetrieverOutput(
                    doc_indices=indices,
                    doc_scores=scores,
                    query=single_query,
                )
            )
        return output
5 changes: 5 additions & 0 deletions adalflow/adalflow/utils/lazy_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ class OptionalPackages(Enum):
"Please install faiss with: pip install faiss-cpu (or faiss if you use GPU)",
)

LANCEDB = (
"lancedb",
"Please install lancedb with: pip install lancedb .",
)

# db library
SQLALCHEMY = (
"sqlalchemy",
Expand Down
Loading

0 comments on commit 7137548

Please sign in to comment.