Remove langchain from index operations
- This is preparation for adding an additional image vector field for
  advanced image processing
- I've tried to keep changes to a minimum
- Still using langchain in the question and answer tool for now
- Increased unit test coverage

Required by #748
adamdougal committed May 7, 2024
1 parent 413ccd6 commit cbb3d08
Showing 22 changed files with 1,093 additions and 472 deletions.
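At a high level, the commit swaps langchain's AzureSearch vector store for direct use of the azure-search-documents SDK: AzureSearchHelper now creates the index itself, and PushEmbedder uploads documents with embeddings computed up front. A minimal sketch of the new write path, assuming placeholder endpoint, key, deployment, and index names rather than this repo's EnvHelper configuration:

from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from openai import AzureOpenAI

# Placeholder configuration; in this repo the values come from EnvHelper.
endpoint = "https://<search-service>.search.windows.net"
index_name = "<index-name>"
credential = AzureKeyCredential("<search-key>")

openai_client = AzureOpenAI(
    azure_endpoint="https://<openai-resource>.openai.azure.com",
    api_key="<openai-key>",
    api_version="2024-02-01",
)

# Embed the chunk up front instead of passing an embedding_function to langchain.
content = "Example chunk of a source document."
vector = (
    openai_client.embeddings.create(input=[content], model="<embedding-deployment>")
    .data[0]
    .embedding
)

# Upload straight to the index; langchain is no longer in this path.
search_client = SearchClient(endpoint=endpoint, index_name=index_name, credential=credential)
results = search_client.upload_documents(
    [{"id": "1", "content": content, "content_vector": vector}]
)
assert all(r.succeeded for r in results)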
30 changes: 14 additions & 16 deletions code/backend/batch/utilities/common/SourceDocument.py
@@ -29,6 +29,20 @@ def __init__(
def __str__(self):
return f"SourceDocument(id={self.id}, title={self.title}, source={self.source}, chunk={self.chunk}, offset={self.offset}, page_number={self.page_number}, chunk_id={self.chunk_id})"

def __eq__(self, other):
if isinstance(self, other.__class__):
return (
self.id == other.id
and self.content == other.content
and self.source == other.source
and self.title == other.title
and self.chunk == other.chunk
and self.offset == other.offset
and self.page_number == other.page_number
and self.chunk_id == other.chunk_id
)
return False

def to_json(self):
return json.dumps(self, cls=SourceDocumentEncoder)
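The new __eq__ gives SourceDocument value semantics, which the expanded unit tests rely on to compare whole documents rather than individual fields. A sketch of the idea; the constructor keywords below are hypothetical, since the full __init__ signature is collapsed above:

# Hypothetical constructor arguments, mirroring the fields __eq__ compares.
left = SourceDocument(id="1", content="hello", source="https://host/doc.txt", title="doc.txt", chunk=0, offset=0, page_number=1, chunk_id="c0")
right = SourceDocument(id="1", content="hello", source="https://host/doc.txt", title="doc.txt", chunk=0, offset=0, page_number=1, chunk_id="c0")
assert left == right  # compares field by field, not identity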

@@ -79,22 +93,6 @@ def from_metadata(
chunk_id=metadata.get("chunk_id"),
)

-    def convert_to_langchain_document(self):
-        from langchain.docstore.document import Document
-
-        return Document(
-            page_content=self.content,
-            metadata={
-                "id": self.id,
-                "source": self.source,
-                "title": self.title,
-                "chunk": self.chunk,
-                "offset": self.offset,
-                "page_number": self.page_number,
-                "chunk_id": self.chunk_id,
-            },
-        )

def get_filename(self, include_path=False):
filename = self.source.replace("_SAS_TOKEN_PLACEHOLDER_", "").replace(
"http://", ""
128 changes: 111 additions & 17 deletions code/backend/batch/utilities/helpers/AzureSearchHelper.py
@@ -1,13 +1,34 @@
import logging
from typing import Union
from langchain.vectorstores.azuresearch import AzureSearch
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
HnswAlgorithmConfiguration,
HnswParameters,
SearchableField,
SearchField,
SearchFieldDataType,
SearchIndex,
SemanticConfiguration,
SemanticField,
SemanticPrioritizedFields,
SemanticSearch,
SimpleField,
VectorSearch,
VectorSearchAlgorithmKind,
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
from .LLMHelper import LLMHelper
from .EnvHelper import EnvHelper

logger = logging.getLogger(__name__)


class AzureSearchHelper:
_search_dimension: int | None = None
@@ -16,6 +37,37 @@ def __init__(self):
self.llm_helper = LLMHelper()
self.env_helper = EnvHelper()

search_credential = self._search_credential()
self.search_client = self._create_search_client(search_credential)
self.search_index_client = self._create_search_index_client(search_credential)
self.create_index()

def _search_credential(self):
if self.env_helper.is_auth_type_keys():
return AzureKeyCredential(self.env_helper.AZURE_SEARCH_KEY)
else:
return DefaultAzureCredential()

def _create_search_client(
self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
) -> SearchClient:
return SearchClient(
endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
index_name=self.env_helper.AZURE_SEARCH_INDEX,
credential=search_credential,
)

def _create_search_index_client(
self, search_credential: Union[AzureKeyCredential, DefaultAzureCredential]
):
return SearchIndexClient(
endpoint=self.env_helper.AZURE_SEARCH_SERVICE, credential=search_credential
)

def get_search_client(self) -> SearchClient:
self.create_index()
return self.search_client

@property
def search_dimensions(self) -> int:
if AzureSearchHelper._search_dimension is None:
@@ -24,7 +76,7 @@ def search_dimensions(self) -> int:
)
return AzureSearchHelper._search_dimension

-    def get_vector_store(self):
+    def create_index(self):
fields = [
SimpleField(
name="id",
@@ -70,25 +122,67 @@ def get_vector_store(self):
),
]

-        return AzureSearch(
-            azure_search_endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
-            azure_search_key=(
-                self.env_helper.AZURE_SEARCH_KEY
-                if self.env_helper.AZURE_AUTH_TYPE == "keys"
-                else None
-            ),
-            index_name=self.env_helper.AZURE_SEARCH_INDEX,
-            embedding_function=self.llm_helper.get_embedding_model().embed_query,
-            search_type=(
-                "semantic_hybrid"
-                if self.env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH
-                else "hybrid"
-            ),
-            semantic_configuration_name=self.env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG,
-            user_agent="langchain chatwithyourdata-sa",
-        )
+        index = SearchIndex(
+            name=self.env_helper.AZURE_SEARCH_INDEX,
+            fields=fields,
+            semantic_search=(
+                SemanticSearch(
+                    configurations=[
+                        SemanticConfiguration(
+                            name=self.env_helper.AZURE_SEARCH_SEMANTIC_SEARCH_CONFIG,
+                            prioritized_fields=SemanticPrioritizedFields(
+                                title_field=None,
+                                content_fields=[SemanticField(field_name="content")],
+                            ),
+                        )
+                    ]
+                )
+            ),
+            vector_search=VectorSearch(
+                algorithms=[
+                    HnswAlgorithmConfiguration(
+                        name="default",
+                        parameters=HnswParameters(
+                            metric=VectorSearchAlgorithmMetric.COSINE
+                        ),
+                        kind=VectorSearchAlgorithmKind.HNSW,
+                    ),
+                    ExhaustiveKnnAlgorithmConfiguration(
+                        name="default_exhaustive_knn",
+                        kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
+                        parameters=ExhaustiveKnnParameters(
+                            metric=VectorSearchAlgorithmMetric.COSINE
+                        ),
+                    ),
+                ],
+                profiles=[
+                    VectorSearchProfile(
+                        name="myHnswProfile",
+                        algorithm_configuration_name="default",
+                    ),
+                    VectorSearchProfile(
+                        name="myExhaustiveKnnProfile",
+                        algorithm_configuration_name="default_exhaustive_knn",
+                    ),
+                ],
+            ),
+        )

if self._index_not_exists(self.env_helper.AZURE_SEARCH_INDEX):
try:
logger.info(
f"Creating or updating index {self.env_helper.AZURE_SEARCH_INDEX}"
)
self.search_index_client.create_index(index)
except Exception as e:
logger.exception("Error Creating index")
raise e

def _index_not_exists(self, index_name: str) -> bool:
return index_name not in [
name for name in self.search_index_client.list_index_names()
]

def get_conversation_logger(self):
fields = [
SimpleField(
@@ -152,7 +246,7 @@ def get_conversation_logger(self):
azure_search_endpoint=self.env_helper.AZURE_SEARCH_SERVICE,
azure_search_key=(
self.env_helper.AZURE_SEARCH_KEY
-                if self.env_helper.AZURE_AUTH_TYPE == "keys"
+                if self.env_helper.is_auth_type_keys()
else None
),
index_name=self.env_helper.AZURE_SEARCH_CONVERSATIONS_LOG_INDEX,
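Note that create_index runs both in the constructor and again in get_search_client; the _index_not_exists guard keeps the call idempotent, so callers can simply do:

# Sketch: the helper guarantees the index exists before handing out a client.
helper = AzureSearchHelper()          # creates the index on first use
client = helper.get_search_client()   # re-checks, then returns the cached SearchClient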
11 changes: 10 additions & 1 deletion code/backend/batch/utilities/helpers/LLMHelper.py
@@ -1,5 +1,5 @@
from openai import AzureOpenAI
-from typing import cast
+from typing import List, Union, cast
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
@@ -98,6 +98,15 @@ def get_embedding_model(self):
azure_ad_token_provider=self.token_provider,
)

def generate_embeddings(self, input: Union[str, list[int]]) -> List[float]:
return (
self.openai_client.embeddings.create(
input=[input], model=self.embedding_model
)
.data[0]
.embedding
)

def get_chat_completion_with_functions(
self, messages: list[dict], functions: list[dict], function_call: str = "auto"
):
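For reference, a sketch of calling the new helper directly, assuming the usual env-driven Azure OpenAI configuration is in place:

llm_helper = LLMHelper()
vector = llm_helper.generate_embeddings("What is in this document?")
print(len(vector))  # embedding dimension, e.g. 1536 for text-embedding-ada-002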
87 changes: 63 additions & 24 deletions code/backend/batch/utilities/helpers/embedders/PushEmbedder.py
@@ -1,6 +1,9 @@
import json
import logging
from typing import List

from ...helpers import LLMHelper

from ..AzureBlobStorageClient import AzureBlobStorageClient

from ..config.EmbeddingConfig import EmbeddingConfig
@@ -17,39 +20,75 @@

class PushEmbedder(EmbedderBase):
def __init__(self, blob_client: AzureBlobStorageClient):
self.llm_helper = LLMHelper()
self.azure_search_helper = AzureSearchHelper()
self.document_loading = DocumentLoading()
self.document_chunking = DocumentChunking()
self.blob_client = blob_client
config = ConfigHelper.get_active_config_or_default()
-        self.processor_map = {}
+        self.embedding_configs = {}
         for processor in config.document_processors:
             ext = processor.document_type.lower()
-            self.processor_map[ext] = processor
+            self.embedding_configs[ext] = processor

def embed_file(self, source_url: str, file_name: str):
file_extension = file_name.split(".")[-1]
-        processor = self.processor_map.get(file_extension)
-        self.__embed(source_url=source_url, processor=processor)
+        embedding_config = self.embedding_configs.get(file_extension)
+        self.__embed(source_url=source_url, embedding_config=embedding_config)
if file_extension != "url":
self.blob_client.upsert_blob_metadata(
file_name, {"embeddings_added": "true"}
)

-    def __embed(self, source_url: str, processor: EmbeddingConfig):
-        vector_store_helper = AzureSearchHelper()
-        vector_store = vector_store_helper.get_vector_store()
-        if not processor.use_advanced_image_processing:
-            try:
-                document_loading = DocumentLoading()
-                document_chunking = DocumentChunking()
-                documents: List[SourceDocument] = []
-                documents = document_loading.load(source_url, processor.loading)
-                documents = document_chunking.chunk(documents, processor.chunking)
-                keys = list(map(lambda x: x.id, documents))
-                documents = [
-                    document.convert_to_langchain_document() for document in documents
-                ]
-                return vector_store.add_documents(documents=documents, keys=keys)
-            except Exception as e:
-                logger.error(f"Error adding embeddings for {source_url}: {e}")
-                raise e
-        else:
-            logger.warn("Advanced image processing is not supported yet")
+    def __embed(self, source_url: str, embedding_config: EmbeddingConfig):
+        documents_to_upload: List[SourceDocument] = []
+        try:
+            if not embedding_config.use_advanced_image_processing:
+                documents: List[SourceDocument] = self.document_loading.load(
+                    source_url, embedding_config.loading
+                )
+                documents = self.document_chunking.chunk(
+                    documents, embedding_config.chunking
+                )
+
+                for document in documents:
+                    documents_to_upload.append(
+                        self._convert_to_search_document(document)
+                    )
+
+                response = (
+                    self.azure_search_helper.get_search_client().upload_documents(
+                        documents_to_upload
+                    )
+                )
+                if not all([r.succeeded for r in response]):
+                    raise Exception(response)
+
+            else:
+                logger.warning("Advanced image processing is not supported yet")
+
+        except Exception as e:
+            logger.error(f"Error adding embeddings for {source_url}: {e}")
+            raise e

def _convert_to_search_document(self, document: SourceDocument):
embedded_content = self.llm_helper.generate_embeddings(document.content)
metadata = {
"id": document.id,
"source": document.source,
"title": document.title,
"chunk": document.chunk,
"offset": document.offset,
"page_number": document.page_number,
"chunk_id": document.chunk_id,
}
return {
"id": document.id,
"content": document.content,
"content_vector": embedded_content,
"metadata": json.dumps(metadata),
"title": document.title,
"source": document.source,
"chunk": document.chunk,
"offset": document.offset,
}
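The dict returned by _convert_to_search_document lines up with the index fields declared in AzureSearchHelper.create_index, with the full metadata additionally serialized into a single JSON string field. Roughly, with all values illustrative:

# Approximate shape of one uploaded search document (values are made up).
example_search_document = {
    "id": "chunk-0",
    "content": "Example chunk text.",
    "content_vector": [0.01, -0.02, 0.03],  # truncated; real vectors use the full embedding dimension
    "metadata": '{"id": "chunk-0", "source": "https://host/doc.txt", "title": "doc.txt", "chunk": 0, "offset": 0, "page_number": 1, "chunk_id": "c0"}',
    "title": "doc.txt",
    "source": "https://host/doc.txt",
    "chunk": 0,
    "offset": 0,
}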
