-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #450 from realpython/chromadb
Materials for Embeddings and Vector Databases With ChromaDB
- Loading branch information
Showing
11 changed files
with
468 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Embeddings and Vector Databases With ChromaDB | ||
|
||
Supporting code for the Real Python tutorial [Embeddings and Vector Databases With ChromaDB](https://realpython.com/embeddings-and-vector-databases-with-chromadb/). | ||
|
||
To run the code in this tutorial, you should have `numpy`, `spacy`, `sentence-transformers`, `chromadb`, `polars`, `more-itertools`, and `openai` installed in your environment. | ||
|
||
You can install the dependencies manually, or by running: | ||
|
||
``` | ||
(venv) $ python -m pip install -r requirements.txt | ||
``` |
62 changes: 62 additions & 0 deletions
62
embeddings-and-vector-databases-with-chromadb/car_data_etl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import pathlib | ||
|
||
import polars as pl | ||
|
||
|
||
def prepare_car_reviews_data( | ||
data_path: pathlib.Path, vehicle_years: list[int] = [2017] | ||
): | ||
"""Prepare the car reviews dataset for ChromaDB""" | ||
|
||
# Define the schema to ensure proper data types are enforced | ||
dtypes = { | ||
"": pl.Int64, | ||
"Review_Date": pl.Utf8, | ||
"Author_Name": pl.Utf8, | ||
"Vehicle_Title": pl.Utf8, | ||
"Review_Title": pl.Utf8, | ||
"Review": pl.Utf8, | ||
"Rating": pl.Float64, | ||
} | ||
|
||
# Scan the car reviews dataset(s) | ||
car_reviews = pl.scan_csv(data_path, dtypes=dtypes) | ||
|
||
# Extract the vehicle title and year as new columns | ||
# Filter on selected years | ||
car_review_db_data = ( | ||
car_reviews.with_columns( | ||
[ | ||
( | ||
pl.col("Vehicle_Title") | ||
.str.split(by=" ") | ||
.list.get(0) | ||
.cast(pl.Int64) | ||
).alias("Vehicle_Year"), | ||
(pl.col("Vehicle_Title").str.split(by=" ").list.get(1)).alias( | ||
"Vehicle_Model" | ||
), | ||
] | ||
) | ||
.filter(pl.col("Vehicle_Year").is_in(vehicle_years)) | ||
.select( | ||
[ | ||
"Review_Title", | ||
"Review", | ||
"Rating", | ||
"Vehicle_Year", | ||
"Vehicle_Model", | ||
] | ||
) | ||
.sort(["Vehicle_Model", "Rating"]) | ||
.collect() | ||
) | ||
|
||
# Create ids, documents, and metadatas data in the format chromadb expects | ||
ids = [f"review{i}" for i in range(car_review_db_data.shape[0])] | ||
documents = car_review_db_data["Review"].to_list() | ||
metadatas = car_review_db_data.drop("Review").to_dicts() | ||
|
||
chroma_data = {"ids": ids, "documents": documents, "metadatas": metadatas} | ||
|
||
return chroma_data |
41 changes: 41 additions & 0 deletions
41
embeddings-and-vector-databases-with-chromadb/chroma_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import pathlib | ||
|
||
import chromadb | ||
from chromadb.utils import embedding_functions | ||
from more_itertools import batched | ||
|
||
|
||
def build_chroma_collection( | ||
chroma_path: pathlib.Path, | ||
collection_name: str, | ||
embbeding_func_name: str, | ||
ids: list[str], | ||
documents: list[str], | ||
metadatas: list[dict], | ||
distance_func_name: str = "cosine", | ||
): | ||
"""Create a ChromaDB collection""" | ||
|
||
chroma_client = chromadb.PersistentClient(chroma_path) | ||
|
||
embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction( | ||
model_name=embbeding_func_name | ||
) | ||
|
||
collection = chroma_client.create_collection( | ||
name=collection_name, | ||
embedding_function=embedding_func, | ||
metadata={"hnsw:space": distance_func_name}, | ||
) | ||
|
||
document_indices = list(range(len(documents))) | ||
|
||
for batch in batched(document_indices, 166): | ||
start_idx = batch[0] | ||
end_idx = batch[-1] | ||
|
||
collection.add( | ||
ids=ids[start_idx:end_idx], | ||
documents=documents[start_idx:end_idx], | ||
metadatas=metadatas[start_idx:end_idx], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"openai-secret-key": "your-api-key" | ||
} |
7 changes: 7 additions & 0 deletions
7
embeddings-and-vector-databases-with-chromadb/cosine_similarity.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import numpy as np | ||
|
||
|
||
def compute_cosine_similarity(u: np.ndarray, v: np.ndarray) -> float: | ||
"""Compute the cosine similarity between two vectors""" | ||
|
||
return u.dot(v) / (np.linalg.norm(u) * np.linalg.norm(v)) |
39 changes: 39 additions & 0 deletions
39
embeddings-and-vector-databases-with-chromadb/create_car_review_collection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import chromadb | ||
from chromadb.utils import embedding_functions | ||
|
||
from car_data_etl import prepare_car_reviews_data | ||
from chroma_utils import build_chroma_collection | ||
|
||
DATA_PATH = "data/archive/*" | ||
CHROMA_PATH = "car_review_embeddings" | ||
EMBEDDING_FUNC_NAME = "multi-qa-MiniLM-L6-cos-v1" | ||
COLLECTION_NAME = "car_reviews" | ||
|
||
chroma_car_reviews_dict = prepare_car_reviews_data(DATA_PATH) | ||
|
||
build_chroma_collection( | ||
CHROMA_PATH, | ||
COLLECTION_NAME, | ||
EMBEDDING_FUNC_NAME, | ||
chroma_car_reviews_dict["ids"], | ||
chroma_car_reviews_dict["documents"], | ||
chroma_car_reviews_dict["metadatas"], | ||
) | ||
|
||
client = chromadb.PersistentClient(CHROMA_PATH) | ||
embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction( | ||
model_name=EMBEDDING_FUNC_NAME | ||
) | ||
collection = client.get_collection( | ||
name=COLLECTION_NAME, embedding_function=embedding_func | ||
) | ||
|
||
great_reviews = collection.query( | ||
query_texts=[ | ||
"Find me some positive reviews that discuss the car's performance" | ||
], | ||
n_results=5, | ||
include=["documents", "distances", "metadatas"], | ||
) | ||
|
||
print(great_reviews["documents"][0][0]) |
23 changes: 23 additions & 0 deletions
23
embeddings-and-vector-databases-with-chromadb/intro_to_vectors.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import numpy as np | ||
|
||
# Create vectors with NumPy | ||
vector1 = np.array([1, 0]) | ||
vector2 = np.array([0, 1]) | ||
print(vector1) | ||
print(vector2) | ||
|
||
v1 = np.array([1, 0]) | ||
v2 = np.array([0, 1]) | ||
v3 = np.array([np.sqrt(2), np.sqrt(2)]) | ||
|
||
# Dimension | ||
print(v1.shape) | ||
|
||
# Magnitude | ||
print(np.sqrt(np.sum(v1**2))) | ||
print(np.linalg.norm(v1)) | ||
print(np.linalg.norm(v3)) | ||
|
||
# Dot product | ||
print(np.sum(v1 * v2)) | ||
print(v1.dot(v3)) |
114 changes: 114 additions & 0 deletions
114
embeddings-and-vector-databases-with-chromadb/llm_car_review_context.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import json | ||
import os | ||
|
||
import chromadb | ||
import openai | ||
from chromadb.utils import embedding_functions | ||
|
||
os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||
|
||
DATA_PATH = "data/archive/*" | ||
CHROMA_PATH = "car_review_embeddings" | ||
EMBEDDING_FUNC_NAME = "multi-qa-MiniLM-L6-cos-v1" | ||
COLLECTION_NAME = "car_reviews" | ||
|
||
with open("config.json", "r") as json_file: | ||
config_data = json.load(json_file) | ||
|
||
openai.api_key = config_data.get("openai-secret-key") | ||
|
||
client = chromadb.PersistentClient(CHROMA_PATH) | ||
embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction( | ||
model_name=EMBEDDING_FUNC_NAME | ||
) | ||
|
||
collection = client.get_collection( | ||
name=COLLECTION_NAME, embedding_function=embedding_func | ||
) | ||
|
||
context = """ | ||
You are a customer success employee at a large | ||
car dealership. Use the following car reviews | ||
to answer questions: {} | ||
""" | ||
|
||
question = """ | ||
What's the key to great customer satisfaction | ||
based on detailed positive reviews? | ||
""" | ||
|
||
good_reviews = collection.query( | ||
query_texts=[question], | ||
n_results=10, | ||
include=["documents"], | ||
where={"Rating": {"$gte": 3}}, | ||
) | ||
|
||
reviews_str = ",".join(good_reviews["documents"][0]) | ||
|
||
good_review_summaries = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo", | ||
messages=[ | ||
{"role": "system", "content": context.format(reviews_str)}, | ||
{"role": "user", "content": question}, | ||
], | ||
temperature=0, | ||
n=1, | ||
) | ||
|
||
reviews_str = ",".join(good_reviews["documents"][0]) | ||
|
||
print("Good reviews: ") | ||
print(reviews_str) | ||
print("###########################################") | ||
|
||
good_review_summaries = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo", | ||
messages=[ | ||
{"role": "system", "content": context.format(reviews_str)}, | ||
{"role": "user", "content": question}, | ||
], | ||
temperature=0, | ||
n=1, | ||
) | ||
|
||
print("AI-Generated summary of good reviews: ") | ||
print(good_review_summaries["choices"][0]["message"]["content"]) | ||
print("###########################################") | ||
|
||
|
||
context = """ | ||
You are a customer success employee at a large car dealership. | ||
Use the following car reivews to answer questions: {} | ||
""" | ||
question = """ | ||
Which of these poor reviews has the worst implications about | ||
our dealership? Explain why. | ||
""" | ||
|
||
poor_reviews = collection.query( | ||
query_texts=[question], | ||
n_results=5, | ||
include=["documents"], | ||
where={"Rating": {"$lte": 3}}, | ||
) | ||
|
||
reviews_str = ",".join(poor_reviews["documents"][0]) | ||
|
||
print("Worst reviews: ") | ||
print(poor_reviews["documents"][0][0]) | ||
print("###########################################") | ||
|
||
poor_review_analysis = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo", | ||
messages=[ | ||
{"role": "system", "content": context.format(reviews_str)}, | ||
{"role": "user", "content": question}, | ||
], | ||
temperature=0, | ||
n=1, | ||
) | ||
|
||
print("AI-Generated summary of the single worst review: ") | ||
print(poor_review_analysis["choices"][0]["message"]["content"]) | ||
print("###########################################") |
Oops, something went wrong.