Skip to content

Commit

Permalink
Add handling for rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves committed Nov 27, 2024
1 parent 28e2b78 commit 22e0af3
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 38 deletions.
4 changes: 4 additions & 0 deletions backend/danswer/natural_language_processing/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class ModelServerRateLimitError(Exception):
"""
Exception raised for rate limiting errors from the model server.
"""
45 changes: 29 additions & 16 deletions backend/danswer/natural_language_processing/search_nlp_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import requests
from httpx import HTTPError
from requests import Response
from retry import retry

from danswer.configs.app_configs import LARGE_CHUNK_RATIO
Expand All @@ -16,6 +17,7 @@
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.exceptions import ModelServerRateLimitError
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -99,28 +101,39 @@ def __init__(
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_request() -> EmbedResponse:
def _make_request() -> Response:
response = requests.post(
self.embed_server_endpoint, json=embed_request.model_dump()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# signify that this is a rate limit error
if response.status_code == 429:
raise ModelServerRateLimitError(response.text)

return EmbedResponse(**response.json())
response.raise_for_status()
return response

# only perform retries for the non-realtime embedding of passages (e.g. for indexing)
final_make_request_func = _make_request

# if the text type is a passage, add some default
# retries + handling for rate limiting
if embed_request.text_type == EmbedTextType.PASSAGE:
return retry(tries=3, delay=5)(_make_request)()
else:
return _make_request()
final_make_request_func = retry(tries=3, delay=5)(_make_request)
# use 10 second delay as per Azure suggestion
final_make_request_func = retry(
tries=10, delay=10, exceptions=ModelServerRateLimitError
)(final_make_request_func)

try:
response = final_make_request_func()
return EmbedResponse(**response.json())
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e

def _batch_encode_texts(
self,
Expand Down
44 changes: 22 additions & 22 deletions backend/model_server/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import embedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
Expand Down Expand Up @@ -205,28 +206,22 @@ def embed(
model_name: str | None = None,
deployment_name: str | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)

embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)

embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")

@staticmethod
def create(
Expand Down Expand Up @@ -430,6 +425,11 @@ async def process_embed_request(
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except RateLimitError as e:
raise HTTPException(
status_code=429,
detail=str(e),
)
except Exception as e:
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
Expand Down
41 changes: 41 additions & 0 deletions backend/tests/daily/embedding/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbeddingProvider

VALID_SHORT_SAMPLE = ["hi"]
VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
VALID_LONG_SAMPLE = ["hi " * 999]
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
# seem to be true
TOO_LONG_SAMPLE = ["a"] * 2500
Expand Down Expand Up @@ -99,3 +101,42 @@ def local_nomic_embedding_model() -> EmbeddingModel:
def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
_run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
_run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)


@pytest.fixture
def azure_embedding_model() -> EmbeddingModel:
return EmbeddingModel(
server_host="localhost",
server_port=9000,
model_name="text-embedding-3-large",
normalize=True,
query_prefix=None,
passage_prefix=None,
api_key=os.getenv("AZURE_API_KEY"),
provider_type=EmbeddingProvider.AZURE,
api_url=os.getenv("AZURE_API_URL"),
)


# NOTE (chris): this test doesn't work, and I do not know why
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
# """NOTE: this test relies on a very low rate limit for the Azure API +
# this test only being run once in a 1 minute window"""
# # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
# # limits assuming the limit is 1000 tokens per minute
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# assert len(result) == 1
# assert len(result[0]) == 1536

# # this should fail
# with pytest.raises(ModelServerRateLimitError):
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)

# # this should succeed, since passage requests retry up to 10 times
# start = time.time()
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
# assert len(result) == 1
# assert len(result[0]) == 1536
# assert time.time() - start > 30 # make sure we waited, even though we hit rate limits

0 comments on commit 22e0af3

Please sign in to comment.