From 16863de0aa58c5abf4c69127a6a8744d23a0871e Mon Sep 17 00:00:00 2001
From: Chris Weaver <25087905+Weves@users.noreply.github.com>
Date: Fri, 29 Nov 2024 20:42:56 -0800
Subject: [PATCH] Improve model token limit detection (#3292)

* Properly find context window for ollama llama

* Better ollama support + upgrade litellm

* Upgrade OpenAI as well

* Fix mypy
---
 backend/danswer/configs/model_configs.py |  4 +-
 backend/danswer/llm/chat_llm.py          | 11 ++-
 backend/danswer/llm/factory.py           | 13 ++++
 backend/danswer/llm/utils.py             | 91 +++++++++++++++++++-----
 backend/requirements/default.txt         |  4 +-
 5 files changed, 100 insertions(+), 23 deletions(-)

diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 0618bf5f684..b71762a4c88 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -70,7 +70,9 @@
 )
 
 # Typically, GenAI models nowadays are at least 4K tokens
-GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
+GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
+    os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
+)
 
 # Number of tokens from chat history to include at maximum
 # 3000 should be enough context regardless of use, no need to include as much as possible
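[Editor's note, not part of the patch: because the new override uses `or` rather than a plain default, an unset or empty GEN_AI_MODEL_FALLBACK_MAX_TOKENS environment variable still falls back to 4096. A minimal sketch of that parsing behavior, with example values:]

# Minimal sketch of the fallback parsing introduced above; the env var name
# matches model_configs.py, the values are illustrative only.
import os

os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = ""  # empty behaves like unset

GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
    os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)
print(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)  # 4096

os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = "8192"
print(int(os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096))  # 8192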
diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py
index 031fcd7163a..f4b09d261fd 100644
--- a/backend/danswer/llm/chat_llm.py
+++ b/backend/danswer/llm/chat_llm.py
@@ -26,7 +26,9 @@
 from langchain_core.prompt_values import PromptValue
 
 from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
-from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
+from danswer.configs.model_configs import (
+    DISABLE_LITELLM_STREAMING,
+)
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.configs.model_configs import LITELLM_EXTRA_BODY
 from danswer.llm.interfaces import LLM
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(
 
     if role == "user":
         return HumanMessageChunk(content=content)
-    elif role == "assistant":
+    # NOTE: if tool calls are present, then it's an assistant.
+    # In Ollama, the role will be None for tool-calls
+    elif role == "assistant" or tool_calls:
         if tool_calls:
             tool_call = tool_calls[0]
             tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
@@ -236,6 +240,7 @@ def __init__(
         custom_config: dict[str, str] | None = None,
         extra_headers: dict[str, str] | None = None,
         extra_body: dict | None = LITELLM_EXTRA_BODY,
+        model_kwargs: dict[str, Any] | None = None,
         long_term_logger: LongTermLogger | None = None,
     ):
         self._timeout = timeout
@@ -268,7 +273,7 @@ def __init__(
             for k, v in custom_config.items():
                 os.environ[k] = v
 
-        model_kwargs: dict[str, Any] = {}
+        model_kwargs = model_kwargs or {}
         if extra_headers:
             model_kwargs.update({"extra_headers": extra_headers})
         if extra_body:
diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py
index 9a2ae66d396..9f0f70f92e8 100644
--- a/backend/danswer/llm/factory.py
+++ b/backend/danswer/llm/factory.py
@@ -1,5 +1,8 @@
+from typing import Any
+
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
+from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.db.engine import get_session_context_manager
 from danswer.db.llm import fetch_default_provider
@@ -13,6 +16,15 @@
 from danswer.utils.long_term_log import LongTermLogger
 
 
+def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
+    """Ollama requires us to specify the max context window.
+
+    For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
+    TODO: allow model-specific values to be configured via the UI.
+    """
+    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
+
+
 def get_main_llm_from_tuple(
     llms: tuple[LLM, LLM],
 ) -> LLM:
@@ -132,5 +144,6 @@ def get_llm(
         temperature=temperature,
         custom_config=custom_config,
         extra_headers=build_llm_extra_headers(additional_headers),
+        model_kwargs=_build_extra_model_kwargs(provider),
         long_term_logger=long_term_logger,
     )
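[Editor's note, not part of the patch: the new model_kwargs plumbing exists so _build_extra_model_kwargs can hand Ollama a context-window size, since, per the docstring above, Ollama needs the max context window specified explicitly. A rough sketch of the litellm call this produces for an Ollama provider is below; the model name, api_base, and the 4096 value (the default fallback) are placeholders, and forwarding num_ctx to Ollama's options is litellm behavior this patch relies on rather than something defined here.]

# Rough sketch of the call DefaultMultiLLM ends up making for an Ollama
# provider once model_kwargs are threaded through; values are illustrative.
import litellm

model_kwargs = {"num_ctx": 4096}  # what _build_extra_model_kwargs("ollama") returns by default

response = litellm.completion(
    model="ollama/llama3.2",                      # placeholder model name
    api_base="http://localhost:11434",            # default local Ollama endpoint
    messages=[{"role": "user", "content": "Hello"}],
    **model_kwargs,                               # num_ctx is forwarded to Ollama's options
)
print(response.choices[0].message.content)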
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index 343f93147d8..e5564e88db0 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -1,3 +1,4 @@
+import copy
 import io
 import json
 from collections.abc import Callable
@@ -385,6 +386,62 @@ def test_llm(llm: LLM) -> str | None:
     return error_msg
 
 
+def get_model_map() -> dict:
+    starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
+
+    # NOTE: we could add additional models here in the future,
+    # but for now there is no point. Ollama allows the user to
+    # specify their desired max context window, and it's
+    # unlikely to be standard across users even for the same model
+    # (it heavily depends on their hardware). For now, we'll just
+    # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
+    # for model_name in [
+    #     "llama3.2",
+    #     "llama3.2:1b",
+    #     "llama3.2:3b",
+    #     "llama3.2:11b",
+    #     "llama3.2:90b",
+    # ]:
+    #     starting_map[f"ollama/{model_name}"] = {
+    #         "max_tokens": 128000,
+    #         "max_input_tokens": 128000,
+    #         "max_output_tokens": 128000,
+    #     }
+
+    return starting_map
+
+
+def _strip_extra_provider_from_model_name(model_name: str) -> str:
+    return model_name.split("/")[1] if "/" in model_name else model_name
+
+
+def _strip_colon_from_model_name(model_name: str) -> str:
+    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
+
+
+def _find_model_obj(
+    model_map: dict, provider: str, model_names: list[str | None]
+) -> dict | None:
+    # Filter out None values
+    filtered_model_names = [name for name in model_names if name]
+
+    # First try all model names with provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(f"{provider}/{model_name}")
+        if model_obj:
+            logger.debug(f"Using model object for {provider}/{model_name}")
+            return model_obj
+
+    # Then try all model names without provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(model_name)
+        if model_obj:
+            logger.debug(f"Using model object for {model_name}")
+            return model_obj
+
+    return None
+
+
 def get_llm_max_tokens(
     model_map: dict,
     model_name: str,
@@ -397,22 +454,22 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS
 
     try:
-        model_obj = model_map.get(f"{model_provider}/{model_name}")
-        if model_obj:
-            logger.debug(f"Using model object for {model_provider}/{model_name}")
-
-        if not model_obj:
-            model_obj = model_map.get(model_name)
-            if model_obj:
-                logger.debug(f"Using model object for {model_name}")
-
-        if not model_obj:
-            model_name_split = model_name.split("/")
-            if len(model_name_split) > 1:
-                model_obj = model_map.get(model_name_split[1])
-            if model_obj:
-                logger.debug(f"Using model object for {model_name_split[1]}")
-
+        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
+            model_name
+        )
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            [
+                model_name,
+                # Remove leading extra provider. Usually for cases where user has a
+                # custom model proxy which appends another prefix
+                extra_provider_stripped_model_name,
+                # remove :XXXX from the end, if present. Needed for ollama.
+                _strip_colon_from_model_name(model_name),
+                _strip_colon_from_model_name(extra_provider_stripped_model_name),
+            ],
+        )
         if not model_obj:
             raise RuntimeError(
                 f"No litellm entry found for {model_provider}/{model_name}"
@@ -488,7 +545,7 @@ def get_max_input_tokens(
     # `model_cost` dict is a named public interface:
     # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
     # model_map is litellm.model_cost
-    litellm_model_map = litellm.model_cost
+    litellm_model_map = get_model_map()
 
     input_toks = (
         get_llm_max_tokens(
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 8a787aa3233..8a13bb8a74f 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -29,7 +29,7 @@ trafilatura==1.12.2
 langchain==0.1.17
 langchain-core==0.1.50
 langchain-text-splitters==0.0.1
-litellm==1.50.2
+litellm==1.53.1
 lxml==5.3.0
 lxml_html_clean==0.2.2
 llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
 nltk==3.8.1
 Office365-REST-Python-Client==2.5.9
 oauthlib==3.2.2
-openai==1.52.2
+openai==1.55.3
 openpyxl==3.1.2
 playwright==1.41.2
 psutil==5.9.5
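[Editor's note, not part of the patch: the new lookup in get_llm_max_tokens tries progressively more generic candidate names, first with the provider prefix and then without, so the most specific litellm entry wins. The sketch below copies the two helpers from the diff above; the "custom-proxy/llama3.2:3b" model name is a made-up example of a proxied Ollama-style tag.]

# Standalone sketch of the candidate names generated for the model-map lookup.
def _strip_extra_provider_from_model_name(model_name: str) -> str:
    return model_name.split("/")[1] if "/" in model_name else model_name

def _strip_colon_from_model_name(model_name: str) -> str:
    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name

model_name = "custom-proxy/llama3.2:3b"  # hypothetical proxied Ollama model
stripped = _strip_extra_provider_from_model_name(model_name)  # "llama3.2:3b"

candidates = [
    model_name,                                # "custom-proxy/llama3.2:3b"
    stripped,                                  # "llama3.2:3b"
    _strip_colon_from_model_name(model_name),  # "custom-proxy/llama3.2"
    _strip_colon_from_model_name(stripped),    # "llama3.2"
]
# _find_model_obj first tries "ollama/<candidate>" for each entry, then each
# bare candidate, returning the first hit in litellm's model map.
print(candidates)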