From 21b8af038101388d72f9a6bbfc5b418c61d29d7b Mon Sep 17 00:00:00 2001 From: Nate Baer Date: Fri, 25 Oct 2024 22:50:40 -0700 Subject: [PATCH 1/5] Add cache control --- src/api_client.py | 35 +++++++++++++++++----- src/chat_completion.py | 26 ++++++---------- src/lm_executors/chat_executor_template.j2 | 4 +++ src/telegram/telegram_bot.py | 3 +- 4 files changed, 41 insertions(+), 27 deletions(-) diff --git a/src/api_client.py b/src/api_client.py index 043c68b..a2ddce9 100644 --- a/src/api_client.py +++ b/src/api_client.py @@ -24,17 +24,36 @@ def create(client_type): async def request_completion(self, messages, parameters, pricing): body = self.prepare_body(messages, parameters) try: - async with httpx.AsyncClient(timeout=self.TIMEOUT) as client: - response = await client.post( - self.API_URL, headers=self.get_headers(), json=body - ) - response.raise_for_status() - completion = self.create_completion(response.json(), pricing) - self.logger.log(parameters, messages, completion.content) - return completion + completion_data = await self.get_completion_data(body) + completion = self.create_completion(completion_data, pricing) + self.logger.log(parameters, messages, completion.content) + return completion except httpx.ReadTimeout: raise Exception("Request timed out") + async def get_completion_data(self, body): + async with httpx.AsyncClient(timeout=self.TIMEOUT) as client: + completion_response = await client.post( + self.API_URL, headers=self.get_headers(), json=body + ) + completion_response.raise_for_status() + completion_data = completion_response.json() + if "error" in completion_data: + raise Exception(completion_data["error"]) + + details_data = await self._poll_details(client, completion_data["id"]) + return {**completion_data, "details": details_data["data"]} + + async def _poll_details(self, client, generation_id, max_attempts=10): + details_url = f"https://openrouter.ai/api/v1/generation?id={generation_id}" + + for _ in range(max_attempts): + details_response = await client.get(details_url, headers=self.get_headers()) + if details_response.status_code == 200: + return details_response.json() + + raise TimeoutError("Details not available after maximum attempts") + @abstractmethod def get_headers(self): pass diff --git a/src/chat_completion.py b/src/chat_completion.py index f647b27..c523a16 100644 --- a/src/chat_completion.py +++ b/src/chat_completion.py @@ -25,11 +25,7 @@ def error_message(self): return self.response.get("error", {}).get("message", "") @property - def cache_creation_input_tokens(self): - return 0 - - @property - def cache_read_input_tokens(self): + def cache_discount(self): return 0 @@ -46,18 +42,6 @@ def content(self): def prompt_tokens(self): return self.response["usage"]["input_tokens"] - @property - def completion_tokens(self): - return self.response["usage"]["output_tokens"] - - @property - def cache_creation_input_tokens(self): - return self.response["usage"]["cache_creation_input_tokens"] - - @property - def cache_read_input_tokens(self): - return self.response["usage"]["cache_read_input_tokens"] - @property def finish_reason(self): return self.response["stop_reason"] @@ -83,3 +67,11 @@ def completion_tokens(self): @property def finish_reason(self): return self.choice.get("finish_reason") + + @property + def cache_discount(self): + return self.response["details"]["cache_discount"] + + @property + def cost(self): + return self.response["details"]["total_cost"] diff --git a/src/lm_executors/chat_executor_template.j2 b/src/lm_executors/chat_executor_template.j2 index 16c30c3..daa7680 100644 --- a/src/lm_executors/chat_executor_template.j2 +++ b/src/lm_executors/chat_executor_template.j2 @@ -25,6 +25,10 @@ - type: text text: |- {{ message.content | indent(8) }} + {% if loop.index in [(messages|length), (messages|length - 2)] %} + cache_control: + type: ephemeral + {% endif %} {% endfor %} {% if reinforcement_chat_prompt %} - role: {{ 'system' if not 'claude' in model else 'assistant' }} diff --git a/src/telegram/telegram_bot.py b/src/telegram/telegram_bot.py index cb8f2a3..41e6eee 100644 --- a/src/telegram/telegram_bot.py +++ b/src/telegram/telegram_bot.py @@ -104,8 +104,7 @@ async def stats_command_handler(self, ctx): f"`Cost: ${lc.cost:.2f}`", f"`Prompt tokens: {lc.prompt_tokens}`", f"`Completion tokens: {lc.completion_tokens}`", - f"`Cache creation tokens: {lc.cache_creation_input_tokens}`", - f"`Cache read tokens: {lc.cache_read_input_tokens}`", + f"`Cache discount: ${lc.cache_discount:.2f}`", ] ) else: From 9ef54aea135a005ebee065176e852e22de2f5bd7 Mon Sep 17 00:00:00 2001 From: Nate Baer Date: Fri, 25 Oct 2024 23:13:54 -0700 Subject: [PATCH 2/5] Remove AnthropicAPIClient --- src/api_client.py | 74 +++++-------------------------- src/lm_executors/chat_executor.py | 4 +- 2 files changed, 12 insertions(+), 66 deletions(-) diff --git a/src/api_client.py b/src/api_client.py index a2ddce9..2415f7f 100644 --- a/src/api_client.py +++ b/src/api_client.py @@ -1,31 +1,31 @@ import os -from abc import ABC, abstractmethod import httpx -from .chat_completion import AnthropicChatCompletion, OpenRouterChatCompletion +from .chat_completion import OpenRouterChatCompletion from .logger import Logger -class APIClient(ABC): +class OpenRouterAPIClient: + API_URL = "https://openrouter.ai/api/v1/chat/completions" + ENV_KEY = "OPENROUTER_API_KEY" TIMEOUT = 30 def __init__(self): self.api_key = os.environ.get(self.ENV_KEY) self.logger = Logger("log.yml") - @staticmethod - def create(client_type): - clients = {"openrouter": OpenRouterAPIClient, "anthropic": AnthropicAPIClient} - if client_type not in clients: - raise ValueError(f"Unsupported client type: {client_type}") - return clients[client_type]() + def get_headers(self): + return {"Authorization": f"Bearer {self.api_key}"} + + def prepare_body(self, messages, parameters): + return {"messages": messages, **parameters} async def request_completion(self, messages, parameters, pricing): body = self.prepare_body(messages, parameters) try: completion_data = await self.get_completion_data(body) - completion = self.create_completion(completion_data, pricing) + completion = OpenRouterChatCompletion(completion_data, pricing) self.logger.log(parameters, messages, completion.content) return completion except httpx.ReadTimeout: @@ -53,57 +53,3 @@ async def _poll_details(self, client, generation_id, max_attempts=10): return details_response.json() raise TimeoutError("Details not available after maximum attempts") - - @abstractmethod - def get_headers(self): - pass - - @abstractmethod - def prepare_body(self, messages, parameters): - pass - - @abstractmethod - def create_completion(self, response, pricing): - pass - - -class OpenRouterAPIClient(APIClient): - API_URL = "https://openrouter.ai/api/v1/chat/completions" - ENV_KEY = "OPENROUTER_API_KEY" - - def get_headers(self): - return {"Authorization": f"Bearer {self.api_key}"} - - def prepare_body(self, messages, parameters): - return {"messages": messages, **parameters} - - def create_completion(self, response, pricing): - return OpenRouterChatCompletion(response, pricing) - - -class AnthropicAPIClient(APIClient): - API_URL = "https://api.anthropic.com/v1/messages" - ENV_KEY = "ANTHROPIC_API_KEY" - - def get_headers(self): - return { - "x-api-key": self.api_key, - "anthropic-version": "2023-06-01", - "anthropic-beta": "prompt-caching-2024-07-31", - } - - def prepare_body(self, messages, parameters): - other_messages, system = self._transform_messages(messages) - return {"messages": other_messages, "system": system, **parameters} - - def create_completion(self, response, pricing): - return AnthropicChatCompletion(response, pricing) - - def _transform_messages(self, original_messages): - messages = [msg for msg in original_messages if msg["role"] != "system"] - system = [] - for msg in original_messages: - if msg["role"] == "system": - system.extend(msg["content"]) - - return messages, system diff --git a/src/lm_executors/chat_executor.py b/src/lm_executors/chat_executor.py index e50a5ff..c24a769 100644 --- a/src/lm_executors/chat_executor.py +++ b/src/lm_executors/chat_executor.py @@ -1,7 +1,7 @@ import jinja2 import yaml -from ..api_client import APIClient +from ..api_client import OpenRouterAPIClient from ..resolve_vars import resolve_vars @@ -12,7 +12,7 @@ def __init__(self, context): self.context = context async def execute(self): - client = APIClient.create(self.context.api_provider) + client = OpenRouterAPIClient() params = {"max_tokens": 1000} if self.context.model is not None: From 534b52e9df8b234a26ee99be53a8664c53b3c860 Mon Sep 17 00:00:00 2001 From: Nate Baer Date: Fri, 25 Oct 2024 23:22:27 -0700 Subject: [PATCH 3/5] Remove AnthropicChatCompletion --- src/api_client.py | 4 ++-- src/chat_completion.py | 40 ++++------------------------------------ 2 files changed, 6 insertions(+), 38 deletions(-) diff --git a/src/api_client.py b/src/api_client.py index 2415f7f..1185647 100644 --- a/src/api_client.py +++ b/src/api_client.py @@ -2,7 +2,7 @@ import httpx -from .chat_completion import OpenRouterChatCompletion +from .chat_completion import ChatCompletion from .logger import Logger @@ -25,7 +25,7 @@ async def request_completion(self, messages, parameters, pricing): body = self.prepare_body(messages, parameters) try: completion_data = await self.get_completion_data(body) - completion = OpenRouterChatCompletion(completion_data, pricing) + completion = ChatCompletion(completion_data, pricing) self.logger.log(parameters, messages, completion.content) return completion except httpx.ReadTimeout: diff --git a/src/chat_completion.py b/src/chat_completion.py index c523a16..de9c9bf 100644 --- a/src/chat_completion.py +++ b/src/chat_completion.py @@ -12,42 +12,6 @@ def validate(self): if self.content == "": raise Exception("Response was empty") - @property - def cost(self): - if self.pricing: - return (self.prompt_tokens / 1_000_000 * self.pricing[0]) + ( - self.completion_tokens / 1_000_000 * self.pricing[1] - ) - return 0 - - @property - def error_message(self): - return self.response.get("error", {}).get("message", "") - - @property - def cache_discount(self): - return 0 - - -class AnthropicChatCompletion(ChatCompletion): - @property - def choice(self): - return self.response["content"][0] - - @property - def content(self): - return self.choice["text"] - - @property - def prompt_tokens(self): - return self.response["usage"]["input_tokens"] - - @property - def finish_reason(self): - return self.response["stop_reason"] - - -class OpenRouterChatCompletion(ChatCompletion): @property def choice(self): return self.response["choices"][0] @@ -68,6 +32,10 @@ def completion_tokens(self): def finish_reason(self): return self.choice.get("finish_reason") + @property + def error_message(self): + return self.response.get("error", {}).get("message", "") + @property def cache_discount(self): return self.response["details"]["cache_discount"] From 2039f38be3f808331511753ad572a4a697ae482e Mon Sep 17 00:00:00 2001 From: Nate Baer Date: Fri, 25 Oct 2024 23:45:29 -0700 Subject: [PATCH 4/5] Update max_tokens --- src/lm_executors/chat_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lm_executors/chat_executor.py b/src/lm_executors/chat_executor.py index c24a769..ca16759 100644 --- a/src/lm_executors/chat_executor.py +++ b/src/lm_executors/chat_executor.py @@ -14,7 +14,7 @@ def __init__(self, context): async def execute(self): client = OpenRouterAPIClient() - params = {"max_tokens": 1000} + params = {"max_tokens": 1024} if self.context.model is not None: params["model"] = self.context.model From 48bb9a2489a8b9204427e7d238cacc1b2c35bac2 Mon Sep 17 00:00:00 2001 From: Nate Baer Date: Fri, 25 Oct 2024 23:45:51 -0700 Subject: [PATCH 5/5] Improve cache discount stat --- src/chat_completion.py | 6 ++++++ src/telegram/telegram_bot.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/chat_completion.py b/src/chat_completion.py index de9c9bf..4237e9a 100644 --- a/src/chat_completion.py +++ b/src/chat_completion.py @@ -40,6 +40,12 @@ def error_message(self): def cache_discount(self): return self.response["details"]["cache_discount"] + @property + def cache_discount_string(self): + sign = "-" if self.cache_discount < 0 else "" + amount = f"${abs(self.cache_discount):.2f}" + return f"{sign}{amount}" + @property def cost(self): return self.response["details"]["total_cost"] diff --git a/src/telegram/telegram_bot.py b/src/telegram/telegram_bot.py index 41e6eee..29faa5e 100644 --- a/src/telegram/telegram_bot.py +++ b/src/telegram/telegram_bot.py @@ -102,9 +102,9 @@ async def stats_command_handler(self, ctx): last_message_stats += "\n".join( [ f"`Cost: ${lc.cost:.2f}`", + f"`Cache discount: {lc.cache_discount_string}`", f"`Prompt tokens: {lc.prompt_tokens}`", f"`Completion tokens: {lc.completion_tokens}`", - f"`Cache discount: ${lc.cache_discount:.2f}`", ] ) else: