initial litellm tests for BerriAI/litellm#6592

RyanMarten committed Jan 7, 2025
1 parent 0c7cf21 commit e545db2

Showing 9 changed files with 139 additions and 142 deletions.
225 changes: 109 additions & 116 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ tqdm = "^4.67.0"
matplotlib = "^3.9.2"
nest-asyncio = "^1.6.0"
rich = "^13.7.0"
litellm = "1.55.4"
litellm = "^1.57.0"
isort = "^5.13.2"
tiktoken = ">=0.7.0,<0.8.0"
aiofiles = ">=22.0,<24.0"
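
The pin is loosened from an exact 1.55.4 to a caret range, so Poetry may resolve any litellm release from 1.57.0 up to (but not including) 2.0.0. A quick, hedged way to confirm which version actually got installed (nothing here is specific to this repo):

from importlib.metadata import version

installed = version("litellm")  # e.g. "1.57.1"
major, minor = (int(part) for part in installed.split(".")[:2])
assert (major, minor) >= (1, 57), f"litellm {installed} is older than the ^1.57.0 range"
print(f"litellm {installed} satisfies ^1.57.0")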
4 changes: 2 additions & 2 deletions src/bespokelabs/curator/llm/llm.py
@@ -93,7 +93,7 @@ def __init__(

if batch:
config_params = {
"model": model_name,
"model_name": model_name,
"base_url": base_url,
"batch_size": batch_size,
"batch_check_interval": batch_check_interval,
@@ -106,7 +106,7 @@ def __init__(
config = BatchRequestProcessorConfig(**_remove_none_values(config_params))
else:
config_params = {
"model": model_name,
"model_name": model_name,
"base_url": base_url,
"max_requests_per_minute": max_requests_per_minute,
"max_tokens_per_minute": max_tokens_per_minute,
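
For context on the rename: the keyword built here has to match the field name on the request-processor configs shown later in this diff. A minimal sketch of that flow, with a stand-in config and an assumed _remove_none_values helper (the real classes live in the curator package; only the field names and defaults visible in the config.py hunk below are copied):

from pydantic import BaseModel, Field


class RequestProcessorConfigSketch(BaseModel):
    # Stand-in mirroring the fields visible in the config.py hunk below.
    model_name: str
    base_url: str | None = None
    max_retries: int = Field(default=10, ge=0)
    request_timeout: int = Field(default=10 * 60, gt=0)
    require_all_responses: bool = Field(default=True)
    generation_params: dict = Field(default_factory=dict)


def _remove_none_values(d: dict) -> dict:
    # Assumed behaviour of the helper used above: drop unset kwargs so
    # Pydantic defaults apply instead of explicit None values.
    return {k: v for k, v in d.items() if v is not None}


config_params = {
    "model_name": "gpt-4o-mini",  # example model string
    "base_url": None,             # dropped before validation
    "generation_params": {"temperature": 0.2},
}
config = RequestProcessorConfigSketch(**_remove_none_values(config_params))
print(config.model_name, config.max_retries)  # gpt-4o-mini 10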
(Next changed file; path not shown.)
@@ -75,14 +75,14 @@ def run(
return output_dataset

logger.info(
f"Running {self.__class__.__name__} completions with model: {self.config.model}"
f"Running {self.__class__.__name__} completions with model: {self.config.model_name}"
)

self.prompt_formatter = prompt_formatter
if self.prompt_formatter.response_format:
if not self.check_structured_output_support():
raise ValueError(
f"Model {self.config.model} does not support structured output, "
f"Model {self.config.model_name} does not support structured output, "
f"response_format: {self.prompt_formatter.response_format}"
)
generic_request_files = self.create_request_files(dataset)
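
The ValueError above fires when a response_format is requested but check_structured_output_support() comes back False. As a standalone probe of the same capability, recent litellm releases expose a metadata lookup; a hedged sketch (the call reads litellm's local model metadata rather than hitting the provider, and its availability depends on the installed litellm version):

import litellm

# True/False according to litellm's model metadata; no API key required.
for model in ("gpt-4o-mini", "gpt-3.5-turbo"):
    print(model, litellm.supports_response_schema(model=model))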
(Next changed file; path not shown.)
@@ -186,7 +186,7 @@ def create_api_specific_request_batch(self, generic_request: GenericRequest) ->
"custom_id": str(generic_request.original_row_idx),
"params": {
"model": generic_request.model,
"max_tokens": litellm.get_max_tokens(self.config.model),
"max_tokens": litellm.get_max_tokens(self.config.model_name),
**kwargs, # contains 'system' and 'messages'
**generic_request.generation_params, # contains 'temperature', 'top_p', etc.
},
@@ -243,7 +243,7 @@ def parse_api_specific_response(
)

cost = litellm.completion_cost(
- model=self.config.model,
+ model=self.config.model_name,
prompt=str(generic_request.messages),
completion=response_message_raw,
)
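
Both hunks above route token-limit and cost lookups through litellm, so the renamed field has to hold a model string litellm recognises. A hedged standalone sketch of the two calls involved (the model string is an example; the returned numbers come from litellm's bundled pricing tables and vary by version):

import litellm

model = "claude-3-5-sonnet-20240620"  # example model string

# Max output tokens litellm knows for the model, as used for the batch "max_tokens".
max_tokens = litellm.get_max_tokens(model)

# Approximate USD cost for a prompt/completion pair, as used when parsing responses.
cost = litellm.completion_cost(
    model=model,
    prompt="Jason is 25 years old.",
    completion="Jason is 25.",
)
print(max_tokens, cost)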
(Next changed file; path not shown.)
@@ -171,7 +171,7 @@ def parse_api_specific_response(
)

cost = litellm.completion_cost(
- model=self.config.model,
+ model=self.config.model_name,
prompt=str(generic_request.messages),
completion=response_message_raw,
)
8 changes: 4 additions & 4 deletions src/bespokelabs/curator/request_processor/config.py
@@ -8,22 +8,22 @@
class RequestProcessorConfig(BaseModel):
"""Configuration for request processors."""

- model: str
+ model_name: str
base_url: str | None = None
max_retries: int = Field(default=10, ge=0)
request_timeout: int = Field(default=10 * 60, gt=0)
require_all_responses: bool = Field(default=True)
generation_params: dict = Field(default_factory=dict)

def __post_init__(self):
- self.supported_params = litellm.get_supported_openai_params(model=self.model)
+ self.supported_params = litellm.get_supported_openai_params(model=self.model_name)
logger.debug(
f"Automatically detected supported params using litellm for {self.model}: {self.supported_params}"
f"Automatically detected supported params using litellm for {self.model_name}: {self.supported_params}"
)

for key in self.generation_params.keys():
raise ValueError(
f"Generation parameter '{key}' is not supported for model '{self.model}'"
f"Generation parameter '{key}' is not supported for model '{self.model_name}'"
)


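
The __post_init__ above leans on litellm's parameter introspection. A hedged standalone illustration of that call and of the kind of guard the generation_params loop is driving at (the membership check here is illustrative, not a claim about the repo's exact logic; the returned list depends on the litellm version):

import litellm

model_name = "gpt-4o-mini"  # example model string
supported = litellm.get_supported_openai_params(model=model_name) or []
print(supported)  # e.g. ['temperature', 'top_p', 'max_tokens', ...]

generation_params = {"temperature": 0.2, "not_a_real_param": 1}
for key in generation_params:
    if key not in supported:
        raise ValueError(
            f"Generation parameter '{key}' is not supported for model '{model_name}'"
        )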
(Next changed file; path not shown.)
@@ -67,24 +67,26 @@ class User(BaseModel):
try:
response = run_in_event_loop(
self.client.chat.completions.create(
- model=self.config.model,
+ model=self.config.model_name,
messages=[{"role": "user", "content": "Jason is 25 years old."}],
response_model=User,
)
)
logger.info(f"Check instructor structure output response: {response}")
assert isinstance(response, User)
logger.info(
f"Model {self.config.model} supports structured output via instructor, response: {response}"
f"Model {self.config.model_name} supports structured output via instructor, response: {response}"
)
return True
except instructor.exceptions.InstructorRetryException as e:
if "litellm.AuthenticationError" in str(e):
logger.warning(f"Please provide a valid API key for model {self.config.model}.")
logger.warning(
f"Please provide a valid API key for model {self.config.model_name}."
)
raise e
else:
logger.warning(
f"Model {self.config.model} does not support structured output via instructor: {e} {type(e)} {e.__cause__}"
f"Model {self.config.model_name} does not support structured output via instructor: {e} {type(e)} {e.__cause__}"
)
return False
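
The check above goes through an instructor-wrapped litellm client. A hedged sketch of building the same kind of client outside the processor (instructor.from_litellm is instructor's documented litellm integration; the model string is an example and the call needs a valid provider API key in the environment):

import instructor
import litellm
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


client = instructor.from_litellm(litellm.completion)
user = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Jason is 25 years old."}],
    response_model=User,
)
print(user)  # e.g. name='Jason' age=25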

@@ -101,7 +103,7 @@ def estimate_output_tokens(self) -> int:
Falls back to 0 if token estimation fails
"""
try:
- return litellm.get_max_tokens(model=self.config.model) // 4
+ return litellm.get_max_tokens(model=self.config.model_name) // 4
except Exception:
return 0

@@ -117,13 +119,13 @@ def estimate_total_tokens(self, messages: list) -> int:
Returns:
int: Total estimated tokens (input + output)
"""
- input_tokens = litellm.token_counter(model=self.config.model, messages=messages)
+ input_tokens = litellm.token_counter(model=self.config.model_name, messages=messages)
output_tokens = self.estimate_output_tokens()
return input_tokens + output_tokens

def test_call(self):
completion = litellm.completion(
- model=self.config.model,
+ model=self.config.model_name,
messages=[
{"role": "user", "content": "hi"}
], # Some models (e.g. Claude) require a non-empty message to get rate limits.
@@ -133,7 +135,9 @@ def test_call(self):
litellm.completion_cost(completion_response=completion.model_dump())
except Exception as e:
# We should ideally not catch a catch-all exception here. But litellm is not throwing any specific error.
logger.warning(f"LiteLLM does not support cost estimation for model {self.model}: {e}")
logger.warning(
f"LiteLLM does not support cost estimation for model {self.config.model_name}: {e}"
)

headers = completion._hidden_params.get("additional_headers", {})
logger.info(f"Test call headers: {headers}")
@@ -149,7 +153,7 @@ def get_header_based_rate_limits(self) -> tuple[int, int]:
- Makes a test request to get rate limit information from response headers.
- Some providers (e.g., Claude) require non-empty messages
"""
logger.info(f"Getting rate limits for model: {self.config.model}")
logger.info(f"Getting rate limits for model: {self.config.model_name}")

headers = self.test_call()
rpm = int(headers.get("x-ratelimit-limit-requests", 0))
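
test_call above pulls provider rate limits out of litellm's hidden response params. A hedged standalone version of that probe (the _hidden_params attribute and the header names follow the hunk above; whether the headers are populated at all depends on the provider):

import litellm

completion = litellm.completion(
    model="gpt-4o-mini",  # example model; some providers (e.g. Claude) need a non-empty message
    messages=[{"role": "user", "content": "hi"}],
    max_tokens=1,
)
headers = completion._hidden_params.get("additional_headers", {})
rpm = int(headers.get("x-ratelimit-limit-requests", 0))
tpm = int(headers.get("x-ratelimit-limit-tokens", 0))
print(f"requests/min: {rpm}, tokens/min: {tpm}")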
(Next changed file; path not shown.)
@@ -64,7 +64,7 @@ def get_header_based_rate_limits(self) -> tuple[int, int]:
response = requests.post(
self.url,
headers={"Authorization": f"Bearer {self.api_key}"},
json={"model": self.config.model, "messages": []},
json={"model": self.config.model_name, "messages": []},
)
rpm = int(response.headers.get("x-ratelimit-limit-requests", 0))
tpm = int(response.headers.get("x-ratelimit-limit-tokens", 0))
@@ -82,7 +82,7 @@ def estimate_output_tokens(self) -> int:
Override this method for more accurate model-specific estimates.
"""
try:
- return litellm.get_max_tokens(model=self.config.model) // 4
+ return litellm.get_max_tokens(model=self.config.model_name) // 4
except Exception:
return 0

@@ -131,7 +131,7 @@ def check_structured_output_support(self) -> bool:
- gpt-4o-mini with date >= 2024-07-18 or latest
- gpt-4o with date >= 2024-08-06 or latest
"""
- model_name = self.config.model.lower()
+ model_name = self.config.model_name.lower()

# Check gpt-4o-mini support
if model_name == "gpt-4o-mini": # Latest version
@@ -231,9 +231,9 @@ async def call_single_request(

def get_token_encoding(self) -> str:
"""Get the token encoding name for a given model."""
- if self.config.model.startswith("gpt-4"):
+ if self.config.model_name.startswith("gpt-4"):
name = "cl100k_base"
- elif self.config.model.startswith("gpt-3.5"):
+ elif self.config.model_name.startswith("gpt-3.5"):
name = "cl100k_base"
else:
name = "cl100k_base" # Default to cl100k_base
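
get_token_encoding above resolves every model to cl100k_base. A hedged example of how that encoding name is then typically consumed with tiktoken, which is already a project dependency per the pyproject.toml hunk above:

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
tokens = encoding.encode("Jason is 25 years old.")
print(len(tokens), tokens[:5])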
