diff --git a/backend/danswer/natural_language_processing/utils.py b/backend/danswer/natural_language_processing/utils.py
index a8250570e84..56631495a4e 100644
--- a/backend/danswer/natural_language_processing/utils.py
+++ b/backend/danswer/natural_language_processing/utils.py
@@ -89,67 +89,70 @@ def _check_tokenizer_cache(
     model_provider: EmbeddingProvider | None, model_name: str | None
 ) -> BaseTokenizer:
     global _TOKENIZER_CACHE
-
     id_tuple = (model_provider, model_name)
 
     if id_tuple not in _TOKENIZER_CACHE:
-        if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
-            if model_name is None:
-                raise ValueError(
-                    "model_name is required for OPENAI and AZURE embeddings"
-                )
+        tokenizer = None
 
-            _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
-            return _TOKENIZER_CACHE[id_tuple]
+        if model_name:
+            tokenizer = _try_initialize_tokenizer(model_name, model_provider)
 
-        try:
-            if model_name is None:
-                model_name = DOCUMENT_ENCODER_MODEL
-
-            logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
-            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
-        except Exception as primary_error:
-            logger.error(
-                f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
-            )
-            logger.warning(
+        if not tokenizer:
+            logger.info(
                 f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
             )
+            tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
 
-            try:
-                # Cache this tokenizer name to the default so we don't have to try to load it again
-                # and fail again
-                _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
-                    DOCUMENT_ENCODER_MODEL
-                )
-            except Exception as fallback_error:
-                logger.error(
-                    f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
-                )
-                raise ValueError(
-                    f"Failed to initialize tokenizer for {model_name} and fallback model"
-                ) from fallback_error
+        _TOKENIZER_CACHE[id_tuple] = tokenizer
 
     return _TOKENIZER_CACHE[id_tuple]
 
 
+def _try_initialize_tokenizer(
+    model_name: str, model_provider: EmbeddingProvider | None
+) -> BaseTokenizer | None:
+    tokenizer: BaseTokenizer | None = None
+
+    if model_provider is not None:
+        # Try using TiktokenTokenizer first if model_provider exists
+        try:
+            tokenizer = TiktokenTokenizer(model_name)
+            logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
+            return tokenizer
+        except Exception as tiktoken_error:
+            logger.debug(
+                f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
+            )
+    else:
+        # If no provider specified, try HuggingFaceTokenizer
+        try:
+            tokenizer = HuggingFaceTokenizer(model_name)
+            logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
+            return tokenizer
+        except Exception as hf_error:
+            logger.warning(
+                f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
+            )
+
+    # If both initializations fail, return None
+    return None
+
+
 _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
 
 
 def get_tokenizer(
     model_name: str | None, provider_type: EmbeddingProvider | str | None
 ) -> BaseTokenizer:
-    if provider_type is not None:
-        if isinstance(provider_type, str):
-            try:
-                provider_type = EmbeddingProvider(provider_type)
-            except ValueError:
-                logger.debug(
-                    f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
-                )
-                return _DEFAULT_TOKENIZER
-        return _check_tokenizer_cache(provider_type, model_name)
-    return _DEFAULT_TOKENIZER
+    if isinstance(provider_type, str):
+        try:
+            provider_type = EmbeddingProvider(provider_type)
+        except ValueError:
+            logger.debug(
+                f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
+            )
+            return _DEFAULT_TOKENIZER
+    return _check_tokenizer_cache(provider_type, model_name)
 
 
 def tokenizer_trim_content(
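
A minimal usage sketch of the new resolution order (not part of the diff; the model names are illustrative, and the `"openai"` provider string is assumed to map to a valid `EmbeddingProvider` value):

```python
from danswer.natural_language_processing.utils import get_tokenizer

# Known provider string: converted to EmbeddingProvider, then
# _try_initialize_tokenizer tries TiktokenTokenizer first for this model.
tiktoken_tok = get_tokenizer(
    model_name="text-embedding-3-small", provider_type="openai"
)

# No provider: HuggingFaceTokenizer is tried for the given model name; if it
# fails, _check_tokenizer_cache falls back to DOCUMENT_ENCODER_MODEL and
# caches the fallback under the same (provider, model) key.
hf_tok = get_tokenizer(model_name="intfloat/e5-base-v2", provider_type=None)

# Unrecognized provider string: logged at debug level and the module-level
# default tokenizer is returned without touching the cache.
default_tok = get_tokenizer(model_name=None, provider_type="not-a-provider")
```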