Skip to content

Commit

Permalink
Improved tokenizer fallback (#3132)
Browse files Browse the repository at this point in the history
* silence warning

* improved fallback logic

* k

* minor cosmetic update

* minor logic update

* nit
  • Loading branch information
pablonyx authored Nov 15, 2024
1 parent ddff7ec commit e2c460b
Showing 1 changed file with 46 additions and 43 deletions.
89 changes: 46 additions & 43 deletions backend/danswer/natural_language_processing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,67 +89,70 @@ def _check_tokenizer_cache(
model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
global _TOKENIZER_CACHE

id_tuple = (model_provider, model_name)

if id_tuple not in _TOKENIZER_CACHE:
if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
if model_name is None:
raise ValueError(
"model_name is required for OPENAI and AZURE embeddings"
)
tokenizer = None

_TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
return _TOKENIZER_CACHE[id_tuple]
if model_name:
tokenizer = _try_initialize_tokenizer(model_name, model_provider)

try:
if model_name is None:
model_name = DOCUMENT_ENCODER_MODEL

logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
except Exception as primary_error:
logger.error(
f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
)
logger.warning(
if not tokenizer:
logger.info(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

try:
# Cache this tokenizer name to the default so we don't have to try to load it again
# and fail again
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
DOCUMENT_ENCODER_MODEL
)
except Exception as fallback_error:
logger.error(
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
)
raise ValueError(
f"Failed to initialize tokenizer for {model_name} and fallback model"
) from fallback_error
_TOKENIZER_CACHE[id_tuple] = tokenizer

return _TOKENIZER_CACHE[id_tuple]


def _try_initialize_tokenizer(
model_name: str, model_provider: EmbeddingProvider | None
) -> BaseTokenizer | None:
tokenizer: BaseTokenizer | None = None

if model_provider is not None:
# Try using TiktokenTokenizer first if model_provider exists
try:
tokenizer = TiktokenTokenizer(model_name)
logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
return tokenizer
except Exception as tiktoken_error:
logger.debug(
f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
)
else:
# If no provider specified, try HuggingFaceTokenizer
try:
tokenizer = HuggingFaceTokenizer(model_name)
logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
return tokenizer
except Exception as hf_error:
logger.warning(
f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
)

# If both initializations fail, return None
return None


_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)


def get_tokenizer(
model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
if provider_type is not None:
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)
return _DEFAULT_TOKENIZER
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)


def tokenizer_trim_content(
Expand Down

0 comments on commit e2c460b

Please sign in to comment.