refactor: Remove langchain from index operations #827

Merged · 10 commits · May 9, 2024
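At a high level, the PR drops langchain's AzureSearch vector store and talks to the azure-search-documents SDK directly for index creation and document uploads. A minimal sketch of the pattern the diff moves to (endpoint, index name, and key below are placeholders, not values from this repo):

from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient

# Placeholder endpoint/index/key; in the diff the real values come from EnvHelper.
search_client = SearchClient(
    endpoint="https://<search-service>.search.windows.net",
    index_name="<index-name>",
    credential=AzureKeyCredential("<search-key>"),
)
search_client.upload_documents(documents=[{"id": "1", "content": "hello"}])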
29 changes: 14 additions & 15 deletions code/backend/batch/utilities/common/SourceDocument.py
@@ -3,7 +3,6 @@
import json
from urllib.parse import urlparse, quote
from ..helpers.AzureBlobStorageClient import AzureBlobStorageClient
from langchain.docstore.document import Document


class SourceDocument:
@@ -30,6 +29,20 @@ def __init__(
def __str__(self):
return f"SourceDocument(id={self.id}, title={self.title}, source={self.source}, chunk={self.chunk}, offset={self.offset}, page_number={self.page_number}, chunk_id={self.chunk_id})"

def __eq__(self, other):
if isinstance(self, other.__class__):
return (
self.id == other.id
and self.content == other.content
and self.source == other.source
and self.title == other.title
and self.chunk == other.chunk
and self.offset == other.offset
and self.page_number == other.page_number
and self.chunk_id == other.chunk_id
)
return False

def to_json(self):
return json.dumps(self, cls=SourceDocumentEncoder)

@@ -80,20 +93,6 @@ def from_metadata(
chunk_id=metadata.get("chunk_id"),
)

def convert_to_langchain_document(self):
return Document(
page_content=self.content,
metadata={
"id": self.id,
"source": self.source,
"title": self.title,
"chunk": self.chunk,
"offset": self.offset,
"page_number": self.page_number,
"chunk_id": self.chunk_id,
},
)

def get_filename(self, include_path=False):
filename = self.source.replace("_SAS_TOKEN_PLACEHOLDER_", "").replace(
"http://", ""
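The new __eq__ gives SourceDocument value equality across all indexed fields, which makes chunked documents easy to compare in tests. A small illustration (the constructor keywords here are assumed from the fields compared above; the real signature is truncated in this diff):

# Illustration only: keyword arguments are assumed from the compared fields.
a = SourceDocument(id="1", content="chunk text", source="https://host/doc.pdf")
b = SourceDocument(id="1", content="chunk text", source="https://host/doc.pdf")
assert a == b                       # equal field-by-field via the new __eq__
assert a != "not a SourceDocument"  # isinstance check makes this False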
124 changes: 107 additions & 17 deletions code/backend/batch/utilities/helpers/AzureSearchHelper.py
@@ -1,13 +1,34 @@
import logging
from typing import Union
from langchain.vectorstores.azuresearch import AzureSearch
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
HnswAlgorithmConfiguration,
HnswParameters,
SearchableField,
SearchField,
SearchFieldDataType,
SearchIndex,
SemanticConfiguration,
SemanticField,
SemanticPrioritizedFields,
SemanticSearch,
SimpleField,
VectorSearch,
VectorSearchAlgorithmKind,
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
from .LLMHelper import LLMHelper
from .EnvHelper import EnvHelper

logger = logging.getLogger(__name__)


class AzureSearchHelper:
_search_dimension: int | None = None
@@ -16,6 +37,37 @@ def __init__(self):
self.llm_helper = LLMHelper()
self.env_helper = EnvHelper()

search_credential = self._search_credential()
self.search_client = self._create_search_client(search_credential)
self.search_index_client = self._create_search_index_client(search_credential)
self.create_index()

def _search_credential(self):
if self.env_helper.is_auth_type_keys():
return AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
else:
return DefaultAzureCredential()

def _create_search_client(
self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
) -> SearchClient:
return SearchClient(
endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
index_name=self.env_helper.AZURE_SEARCH_INDEX,
credential=search_credential,
)

def _create_search_index_client(
self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
):
return SearchIndexClient(
endpoint=self.env_helper.AZURE_SEARCH_SERVICE, credential=search_credential
)

def get_search_client(self) -> SearchClient:
self.create_index()
return self.search_client

@property
def search_dimensions(self) -> int:
if AzureSearchHelper._search_dimension is None:
@@ -24,7 +76,7 @@ def search_dimensions(self) -> int:
)
return AzureSearchHelper._search_dimension

def get_vector_store(self):
def create_index(self):
fields = [
SimpleField(
name="id",
@@ -70,25 +122,63 @@ def get_vector_store(self):
),
]

return AzureSearch(
azure_search_endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
azure_search_key=(
self.env_helper.AZURE_SEARCH_KEY
if self.env_helper.AZURE_AUTH_TYPE == "keys"
else None
),
index_name=self.env_helper.AZURE_SEARCH_INDEX,
embedding_function=self.llm_helper.get_embedding_model().embed_query,
index = SearchIndex(
name=self.env_helper.AZURE_SEARCH_INDEX,
fields=fields,
search_type=(
"semantic_hybrid"
if self.env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH
else "hybrid"
semantic_search=(
SemanticSearch(
configurations=[
SemanticConfiguration(
name=self.env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG,
prioritized_fields=SemanticPrioritizedFields(
title_field=None,
content_fields=[SemanticField(field_name="content")],
),
)
]
)
),
vector_search=VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(
name="default",
parameters=HnswParameters(
metric=VectorSearchAlgorithmMetric.COSINE
),
kind=VectorSearchAlgorithmKind.HNSW,
),
ExhaustiveKnnAlgorithmConfiguration(
name="default_exhaustive_knn",
kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
parameters=ExhaustiveKnnParameters(
metric=VectorSearchAlgorithmMetric.COSINE
),
),
],
profiles=[
VectorSearchProfile(
name="myHnswProfile",
algorithm_configuration_name="default",
),
VectorSearchProfile(
name="myExhaustiveKnnProfile",
algorithm_configuration_name="default_exhaustive_knn",
),
],
),
semantic_configuration_name=self.env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG,
user_agent="langchain chatwithyourdata-sa",
)

if self._index_not_exists(self.env_helper.AZURE_SEARCH_INDEX):
logger.info(
f"Creating or updating index {self.env_helper.AZURE_SEARCH_INDEX}"
)
self.search_index_client.create_index(index)

def _index_not_exists(self, index_name: str) -> bool:
return index_name not in [
name for name in self.search_index_client.list_index_names()
]

def get_conversation_logger(self):
fields = [
SimpleField(
@@ -152,7 +242,7 @@ def get_conversation_logger(self):
azure_search_endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
azure_search_key=(
self.env_helper.AZURE_SEARCH_KEY
if self.env_helper.AZURE_AUTH_TYPE == "keys"
if self.env_helper.is_auth_type_keys()
else None
),
index_name=self.env_helper.AZURE_SEARCH_CONVERSATIONS_LOG_INDEX,
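With these changes the helper owns index management end to end: constructing AzureSearchHelper resolves credentials, builds the SearchClient and SearchIndexClient, and ensures the index exists, so callers get a plain SDK client back instead of a langchain vector store. A rough usage sketch (assumes the AZURE_SEARCH_* settings are available through EnvHelper; the query and printed fields are examples only):

# Rough sketch, not part of the diff.
helper = AzureSearchHelper()          # credentials resolved, index created if missing
client = helper.get_search_client()   # azure-search-documents SearchClient
for result in client.search(search_text="*", top=5):
    print(result["id"], result["title"])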
11 changes: 10 additions & 1 deletion code/backend/batch/utilities/helpers/LLMHelper.py
@@ -1,5 +1,5 @@
from openai import AzureOpenAI
from typing import cast
from typing import List, Union, cast
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
@@ -98,6 +98,15 @@ def get_embedding_model(self):
azure_ad_token_provider=self.token_provider,
)

def generate_embeddings(self, input: Union[str, list[int]]) -> List[float]:
return (
self.openai_client.embeddings.create(
input=[input], model=self.embedding_model
)
.data[0]
.embedding
)

def get_chat_completion_with_functions(
self, messages: list[dict], functions: list[dict], function_call: str = "auto"
):
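generate_embeddings calls the OpenAI SDK's embeddings endpoint directly instead of going through langchain's embed_query. An equivalent standalone call, for illustration (endpoint, key, and deployment name are placeholders; LLMHelper builds its client from EnvHelper settings):

from openai import AzureOpenAI

# Placeholder credentials and deployment name.
client = AzureOpenAI(
    azure_endpoint="https://<aoai-resource>.openai.azure.com/",
    api_key="<api-key>",
    api_version="2024-02-01",
)
vector = client.embeddings.create(
    input=["some chunk text"], model="<embedding-deployment>"
).data[0].embedding
print(len(vector))  # e.g. 1536 for text-embedding-ada-002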
76 changes: 53 additions & 23 deletions code/backend/batch/utilities/helpers/embedders/PushEmbedder.py
@@ -1,6 +1,9 @@
import json
import logging
from typing import List

from ...helpers.LLMHelper import LLMHelper

from ..AzureBlobStorageClient import AzureBlobStorageClient

from ..config.EmbeddingConfig import EmbeddingConfig
@@ -17,39 +20,66 @@

class PushEmbedder(EmbedderBase):
def __init__(self, blob_client: AzureBlobStorageClient):
self.llm_helper = LLMHelper()
self.azure_search_helper = AzureSearchHelper()
self.document_loading = DocumentLoading()
self.document_chunking = DocumentChunking()
self.blob_client = blob_client
config = ConfigHelper.get_active_config_or_default()
self.processor_map = {}
self.embedding_configs = {}
for processor in config.document_processors:
ext = processor.document_type.lower()
self.processor_map[ext] = processor
self.embedding_configs[ext] = processor

def embed_file(self, source_url: str, file_name: str):
file_extension = file_name.split(".")[-1]
processor = self.processor_map.get(file_extension)
self.__embed(source_url=source_url, processor=processor)
embedding_config = self.embedding_configs.get(file_extension)
self.__embed(source_url=source_url, embedding_config=embedding_config)
if file_extension != "url":
self.blob_client.upsert_blob_metadata(
file_name, {"embeddings_added": "true"}
)

def __embed(self, source_url: str, processor: EmbeddingConfig):
vector_store_helper = AzureSearchHelper()
vector_store = vector_store_helper.get_vector_store()
if not processor.use_advanced_image_processing:
try:
document_loading = DocumentLoading()
document_chunking = DocumentChunking()
documents: List[SourceDocument] = []
documents = document_loading.load(source_url, processor.loading)
documents = document_chunking.chunk(documents, processor.chunking)
keys = list(map(lambda x: x.id, documents))
documents = [
document.convert_to_langchain_document() for document in documents
]
return vector_store.add_documents(documents=documents, keys=keys)
except Exception as e:
logger.error(f"Error adding embeddings for {source_url}: {e}")
raise e
def __embed(self, source_url: str, embedding_config: EmbeddingConfig):
documents_to_upload: List[SourceDocument] = []
if not embedding_config.use_advanced_image_processing:
documents: List[SourceDocument] = self.document_loading.load(
source_url, embedding_config.loading
)
documents = self.document_chunking.chunk(
documents, embedding_config.chunking
)

for document in documents:
documents_to_upload.append(self._convert_to_search_document(document))

response = self.azure_search_helper.get_search_client().upload_documents(
documents_to_upload
)
if not all([r.succeeded for r in response]):
raise Exception(response)

else:
logger.warn("Advanced image processing is not supported yet")
logger.warning("Advanced image processing is not supported yet")

def _convert_to_search_document(self, document: SourceDocument):
embedded_content = self.llm_helper.generate_embeddings(document.content)
metadata = {
"id": document.id,
"source": document.source,
"title": document.title,
"chunk": document.chunk,
"offset": document.offset,
"page_number": document.page_number,
"chunk_id": document.chunk_id,
}
return {
"id": document.id,
"content": document.content,
"content_vector": embedded_content,
"metadata": json.dumps(metadata),
"title": document.title,
"source": document.source,
"chunk": document.chunk,
"offset": document.offset,
}
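Each uploaded document now matches the search index schema produced by _convert_to_search_document. A short example of the payload shape and upload call (all values are placeholders; the real fields come from a chunked SourceDocument):

import json

# Placeholder payload in the shape _convert_to_search_document produces.
doc = {
    "id": "doc-1_chunk-0",
    "content": "chunk text",
    "content_vector": [0.01, -0.02, 0.03],  # from LLMHelper.generate_embeddings
    "metadata": json.dumps({"id": "doc-1_chunk-0", "source": "https://host/doc.pdf"}),
    "title": "doc.pdf",
    "source": "https://host/doc.pdf",
    "chunk": 0,
    "offset": 0,
}
response = AzureSearchHelper().get_search_client().upload_documents([doc])
assert all(r.succeeded for r in response)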