From 6e9b6a1075c7ed880659a0dfd62f778dad4324b7 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Mon, 21 Oct 2024 22:27:26 -0700 Subject: [PATCH] Handle models like openai/bedrock/claude-3.5-... (#2869) * Handle models like openai/bedrock/claude-3.5-... * Fix log statement --- backend/danswer/llm/utils.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 3a5e40875f1..bad18214b95 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -342,12 +342,26 @@ def get_llm_max_tokens( try: model_obj = model_map.get(f"{model_provider}/{model_name}") - if not model_obj: - model_obj = model_map[model_name] - logger.debug(f"Using model object for {model_name}") - else: + if model_obj: logger.debug(f"Using model object for {model_provider}/{model_name}") + if not model_obj: + model_obj = model_map.get(model_name) + if model_obj: + logger.debug(f"Using model object for {model_name}") + + if not model_obj: + model_name_split = model_name.split("/") + if len(model_name_split) > 1: + model_obj = model_map.get(model_name_split[1]) + if model_obj: + logger.debug(f"Using model object for {model_name_split[1]}") + + if not model_obj: + raise RuntimeError( + f"No litellm entry found for {model_provider}/{model_name}" + ) + if "max_input_tokens" in model_obj: max_tokens = model_obj["max_input_tokens"] logger.info(