Skip to content

Commit

Permalink
formatting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fm1320 committed Dec 10, 2024
1 parent 022181d commit b709393
Show file tree
Hide file tree
Showing 4 changed files with 2,518 additions and 2,156 deletions.
59 changes: 42 additions & 17 deletions adalflow/adalflow/components/retriever/lancedb_retriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,24 @@
log = logging.getLogger(__name__)

# Defined data types
LanceDBRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray] # single embedding
LanceDBRetrieverDocumentEmbeddingType = Union[
List[float], np.ndarray
] # single embedding
LanceDBRetrieverDocumentsType = Sequence[LanceDBRetrieverDocumentEmbeddingType]


# Step 2: Define the LanceDBRetriever class
class LanceDBRetriever(Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]]):
def __init__(self, embedder: Embedder, dimensions: int, db_uri: str = "/tmp/lancedb", top_k: int = 5, overwrite: bool = True):
class LanceDBRetriever(
Retriever[LanceDBRetrieverDocumentEmbeddingType, Union[str, List[str]]]
):
def __init__(
self,
embedder: Embedder,
dimensions: int,
db_uri: str = "/tmp/lancedb",
top_k: int = 5,
overwrite: bool = True,
):
"""
LanceDBRetriever is a retriever that leverages LanceDB to efficiently store and query document embeddings.
Expand All @@ -39,13 +51,17 @@ def __init__(self, embedder: Embedder, dimensions: int, db_uri: str = "/tmp/lanc
self.dimensions = dimensions

# Define table schema with vector field for embeddings
schema = pa.schema([
pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)),
pa.field("content", pa.string())
])
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=self.dimensions)),
pa.field("content", pa.string()),
]
)

# Create or overwrite the table for storing documents and embeddings
self.table = self.db.create_table("documents", schema=schema, mode="overwrite" if overwrite else "append")
self.table = self.db.create_table(
"documents", schema=schema, mode="overwrite" if overwrite else "append"
)

def add_documents(self, documents: Sequence[Dict[str, Any]]):
"""
Expand All @@ -63,13 +79,18 @@ def add_documents(self, documents: Sequence[Dict[str, Any]]):
embeddings = self.embedder(input=doc_texts).data

# Format embeddings for LanceDB
data = [{"vector": embedding.embedding, "content": text} for embedding, text in zip(embeddings, doc_texts)]
data = [
{"vector": embedding.embedding, "content": text}
for embedding, text in zip(embeddings, doc_texts)
]

# Add data to LanceDB table
self.table.add(data)
log.info(f"Added {len(documents)} documents to the index")

def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) -> List[RetrieverOutput]:
def retrieve(
self, query: Union[str, List[str]], top_k: Optional[int] = None
) -> List[RetrieverOutput]:
""".
Retrieve top-k documents from LanceDB for a given query or queries.
Args:
Expand All @@ -83,11 +104,13 @@ def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) ->
query = [query]

if not query or (isinstance(query, str) and query.strip() == ""):
raise ValueError("Query cannot be empty.")
raise ValueError("Query cannot be empty.")

# Check if table (index) exists before performing search
if not self.table:
raise ValueError("The index has not been initialized or the table is missing.")
raise ValueError(
"The index has not been initialized or the table is missing."
)

query_embeddings = self.embedder(input=query).data
output: List[RetrieverOutput] = []
Expand All @@ -105,9 +128,11 @@ def retrieve(self, query: Union[str, List[str]], top_k: Optional[int] = None) ->
scores = results["_distance"].tolist()

# Append results to output
output.append(RetrieverOutput(
doc_indices=indices,
doc_scores=scores,
query=query[0] if len(query) == 1 else query
))
output.append(
RetrieverOutput(
doc_indices=indices,
doc_scores=scores,
query=query[0] if len(query) == 1 else query,
)
)
return output
2 changes: 1 addition & 1 deletion adalflow/adalflow/utils/lazy_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class OptionalPackages(Enum):
)

LANCEDB = (
"lancedb",
"lancedb",
"Please install lancedb with: pip install lancedb .",
)

Expand Down
40 changes: 20 additions & 20 deletions adalflow/tests/test_lancedb_retriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from unittest import mock
from adalflow.core.types import EmbedderOutput, RetrieverOutput


# Helper function to create dummy embeddings
def create_dummy_embeddings(num_embeddings, dim):
return np.random.rand(num_embeddings, dim).astype(np.float32)


class TestLanceDBRetriever(unittest.TestCase):
def setUp(self):
self.dimensions = 128
Expand All @@ -21,7 +23,10 @@ def setUp(self):
# Mock embedder to return dummy embeddings
self.dummy_embeddings = create_dummy_embeddings(10, self.dimensions)
self.embedder.return_value = EmbedderOutput(
data=[Mock(embedding=emb) for emb in self.dummy_embeddings[:len(self.single_query)]]
data=[
Mock(embedding=emb)
for emb in self.dummy_embeddings[: len(self.single_query)]
]
)

with patch("lancedb.connect") as mock_db_connect:
Expand All @@ -32,7 +37,7 @@ def setUp(self):
embedder=self.embedder,
dimensions=self.dimensions,
db_uri="/tmp/lancedb",
top_k=self.top_k
top_k=self.top_k,
)

def test_initialization(self):
Expand Down Expand Up @@ -68,11 +73,10 @@ def test_retrieve_single_query(self):
)

# Mock search results from LanceDB as pandas DataFrame
results_df = pd.DataFrame({
"index": [0, 1, 2],
"_distance": [0.1, 0.2, 0.3]
})
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = results_df
results_df = pd.DataFrame({"index": [0, 1, 2], "_distance": [0.1, 0.2, 0.3]})
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = (
results_df
)

result = self.retriever.retrieve(query)
self.assertIsInstance(result[0], RetrieverOutput)
Expand All @@ -91,11 +95,10 @@ def test_retrieve_multiple_queries(self):
)

# Mock search results for each query
results_df = pd.DataFrame({
"index": [0, 1, 2],
"_distance": [0.1, 0.2, 0.3]
})
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = results_df
results_df = pd.DataFrame({"index": [0, 1, 2], "_distance": [0.1, 0.2, 0.3]})
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = (
results_df
)

result = self.retriever.retrieve(queries)
self.assertEqual(len(result), len(queries))
Expand All @@ -106,10 +109,9 @@ def test_retrieve_multiple_queries(self):

def test_retrieve_with_empty_query(self):
# Mock the empty results DataFrame
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = pd.DataFrame({
"index": [],
"_distance": []
})
self.mock_table.search.return_value.limit.return_value.to_pandas.return_value = pd.DataFrame(
{"index": [], "_distance": []}
)

def test_retrieve_with_no_index(self):
empty_retriever = LanceDBRetriever(
Expand All @@ -128,12 +130,10 @@ def test_overwrite_table_on_initialization(self):
embedder=self.embedder,
dimensions=self.dimensions,
db_uri="/tmp/lancedb",
overwrite=True
overwrite=True,
)
mock_db.create_table.assert_called_once_with(
"documents",
schema=mock.ANY,
mode="overwrite"
"documents", schema=mock.ANY, mode="overwrite"
)


Expand Down
Loading

0 comments on commit b709393

Please sign in to comment.