Skip to content

Commit

Permalink
added embedding-based retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
alkidbaci committed Nov 14, 2024
1 parent 8f9d6ae commit 6c2001f
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions general_working_directory/embedding-based_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from openai import OpenAI
import pandas as pd
import numpy as np
from numpy.linalg import norm

query = 'Do you have any T-shirt with a oversize fit and graphics written in front'

df = pd.read_csv("embeddings.csv", index_col=0, nrows=None)
iris = df.index.values.tolist()

client = OpenAI(base_url="http://tentris-ml.cs.upb.de:8502/v1", api_key="token-tentris-upb")

docs = np.array(df.values)
qr = np.array(client.embeddings.create(input=[query], model="tentris").data[0].embedding)

docs_norms = docs / norm(docs, axis=1, keepdims=True)
qr_norms = qr / norm(qr)

cosine_similarities = (docs_norms @ qr_norms).flatten()

best_match_index = np.argmax(cosine_similarities)
best_similarity = cosine_similarities[best_match_index]

print(cosine_similarities)
print(f"The best scoring image is the image with iri: {iris[best_match_index]} and score: {best_similarity}")

0 comments on commit 6c2001f

Please sign in to comment.