Commit

minor logic update
pablonyx committed Nov 14, 2024
1 parent ade54ce commit 95773d3
Showing 1 changed file with 40 additions and 28 deletions.
backend/danswer/natural_language_processing/utils.py (40 additions & 28 deletions)
@@ -89,41 +89,53 @@ def _check_tokenizer_cache(
     model_provider: EmbeddingProvider | None, model_name: str | None
 ) -> BaseTokenizer:
     global _TOKENIZER_CACHE
 
     id_tuple = (model_provider, model_name)
 
     if id_tuple not in _TOKENIZER_CACHE:
-        # If no provider specified, try to create HuggingFaceTokenizer with model_name
-        if model_name is not None:
-            if model_provider is None:
-                try:
-                    _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
-                    logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
-                    return _TOKENIZER_CACHE[id_tuple]
-                except Exception as hf_error:
-                    logger.warning(
-                        f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
-                    )
-
-            # Try using TiktokenTokenizer if it supports the model_name
-            try:
-                _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
-                logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
-                return _TOKENIZER_CACHE[id_tuple]
-            except Exception as tiktoken_error:
-                logger.debug(
-                    f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
-                )
-
-        # Fallback to default DOCUMENT_ENCODER_MODEL
-        logger.info(
-            f"Falling back to default embedding model for model {model_name}: {DOCUMENT_ENCODER_MODEL}"
-        )
-        _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
+        tokenizer = None
+
+        if model_name:
+            tokenizer = _try_initialize_tokenizer(model_name, model_provider)
+
+        if not tokenizer:
+            logger.info(
+                f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
+            )
+            tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
+
+        _TOKENIZER_CACHE[id_tuple] = tokenizer
 
     return _TOKENIZER_CACHE[id_tuple]
 
 
+def _try_initialize_tokenizer(
+    model_name: str, model_provider: EmbeddingProvider | None
+) -> BaseTokenizer | None:
+    if model_provider is not None:
+        # Try using TiktokenTokenizer first if model_provider exists
+        try:
+            tokenizer = TiktokenTokenizer(model_name)
+            logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
+            return tokenizer
+        except Exception as tiktoken_error:
+            logger.debug(
+                f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
+            )
+    else:
+        # If no provider specified, try HuggingFaceTokenizer
+        try:
+            tokenizer = HuggingFaceTokenizer(model_name)
+            logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
+            return tokenizer
+        except Exception as hf_error:
+            logger.warning(
+                f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
+            )
+
+    # If both initializations fail, return None
+    return None
+
+
 _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
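For review context, a minimal usage sketch of the selection flow introduced above, written against the two functions shown in this hunk. The import path mirrors the changed file; the model name and the direct call to these private helpers are illustrative assumptions, not part of the commit.

# Sketch only (not part of this commit): exercises the new tokenizer fallback.
from danswer.natural_language_processing.utils import _check_tokenizer_cache

# No provider given: _try_initialize_tokenizer attempts a HuggingFaceTokenizer
# for the named model; if that raises, _check_tokenizer_cache falls back to the
# default DOCUMENT_ENCODER_MODEL tokenizer and caches whichever one succeeded.
tokenizer = _check_tokenizer_cache(
    model_provider=None,
    model_name="intfloat/e5-base-v2",  # hypothetical model choice
)
print(type(tokenizer).__name__)  # HuggingFaceTokenizer on success

# With a provider set, TiktokenTokenizer is tried first, with the same fallback
# when tiktoken does not recognize the model name.

Moving the try/except ladders into _try_initialize_tokenizer leaves _check_tokenizer_cache responsible only for the cache lookup and the default fallback; note that with no provider, a failed HuggingFaceTokenizer init now goes straight to the default model rather than also trying tiktoken.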


