Skip to content

Commit

Permalink
Merge pull request #101 from njbbaer/message-caching
Browse files Browse the repository at this point in the history
Implement OpenRouter message caching
  • Loading branch information
njbbaer authored Oct 26, 2024
2 parents fd4a1db + 48bb9a2 commit ebafbfa
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 120 deletions.
97 changes: 31 additions & 66 deletions src/api_client.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,55 @@
import os
from abc import ABC, abstractmethod

import httpx

from .chat_completion import AnthropicChatCompletion, OpenRouterChatCompletion
from .chat_completion import ChatCompletion
from .logger import Logger


class APIClient(ABC):
class OpenRouterAPIClient:
API_URL = "https://openrouter.ai/api/v1/chat/completions"
ENV_KEY = "OPENROUTER_API_KEY"
TIMEOUT = 30

def __init__(self):
self.api_key = os.environ.get(self.ENV_KEY)
self.logger = Logger("log.yml")

@staticmethod
def create(client_type):
clients = {"openrouter": OpenRouterAPIClient, "anthropic": AnthropicAPIClient}
if client_type not in clients:
raise ValueError(f"Unsupported client type: {client_type}")
return clients[client_type]()

async def request_completion(self, messages, parameters, pricing):
body = self.prepare_body(messages, parameters)
try:
async with httpx.AsyncClient(timeout=self.TIMEOUT) as client:
response = await client.post(
self.API_URL, headers=self.get_headers(), json=body
)
response.raise_for_status()
completion = self.create_completion(response.json(), pricing)
self.logger.log(parameters, messages, completion.content)
return completion
except httpx.ReadTimeout:
raise Exception("Request timed out")

@abstractmethod
def get_headers(self):
pass

@abstractmethod
def prepare_body(self, messages, parameters):
pass

@abstractmethod
def create_completion(self, response, pricing):
pass


class OpenRouterAPIClient(APIClient):
API_URL = "https://openrouter.ai/api/v1/chat/completions"
ENV_KEY = "OPENROUTER_API_KEY"

def get_headers(self):
return {"Authorization": f"Bearer {self.api_key}"}

def prepare_body(self, messages, parameters):
return {"messages": messages, **parameters}

def create_completion(self, response, pricing):
return OpenRouterChatCompletion(response, pricing)


class AnthropicAPIClient(APIClient):
API_URL = "https://api.anthropic.com/v1/messages"
ENV_KEY = "ANTHROPIC_API_KEY"
async def request_completion(self, messages, parameters, pricing):
body = self.prepare_body(messages, parameters)
try:
completion_data = await self.get_completion_data(body)
completion = ChatCompletion(completion_data, pricing)
self.logger.log(parameters, messages, completion.content)
return completion
except httpx.ReadTimeout:
raise Exception("Request timed out")

def get_headers(self):
return {
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01",
"anthropic-beta": "prompt-caching-2024-07-31",
}
async def get_completion_data(self, body):
async with httpx.AsyncClient(timeout=self.TIMEOUT) as client:
completion_response = await client.post(
self.API_URL, headers=self.get_headers(), json=body
)
completion_response.raise_for_status()
completion_data = completion_response.json()
if "error" in completion_data:
raise Exception(completion_data["error"])

def prepare_body(self, messages, parameters):
other_messages, system = self._transform_messages(messages)
return {"messages": other_messages, "system": system, **parameters}
details_data = await self._poll_details(client, completion_data["id"])
return {**completion_data, "details": details_data["data"]}

def create_completion(self, response, pricing):
return AnthropicChatCompletion(response, pricing)
async def _poll_details(self, client, generation_id, max_attempts=10):
details_url = f"https://openrouter.ai/api/v1/generation?id={generation_id}"

def _transform_messages(self, original_messages):
messages = [msg for msg in original_messages if msg["role"] != "system"]
system = []
for msg in original_messages:
if msg["role"] == "system":
system.extend(msg["content"])
for _ in range(max_attempts):
details_response = await client.get(details_url, headers=self.get_headers())
if details_response.status_code == 200:
return details_response.json()

return messages, system
raise TimeoutError("Details not available after maximum attempts")
64 changes: 15 additions & 49 deletions src/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,74 +12,40 @@ def validate(self):
if self.content == "":
raise Exception("Response was empty")

@property
def cost(self):
if self.pricing:
return (self.prompt_tokens / 1_000_000 * self.pricing[0]) + (
self.completion_tokens / 1_000_000 * self.pricing[1]
)
return 0

@property
def error_message(self):
return self.response.get("error", {}).get("message", "")

@property
def cache_creation_input_tokens(self):
return 0

@property
def cache_read_input_tokens(self):
return 0


class AnthropicChatCompletion(ChatCompletion):
@property
def choice(self):
return self.response["content"][0]
return self.response["choices"][0]

@property
def content(self):
return self.choice["text"]
return self.choice["message"]["content"]

@property
def prompt_tokens(self):
return self.response["usage"]["input_tokens"]
return self.response["usage"]["prompt_tokens"]

@property
def completion_tokens(self):
return self.response["usage"]["output_tokens"]

@property
def cache_creation_input_tokens(self):
return self.response["usage"]["cache_creation_input_tokens"]

@property
def cache_read_input_tokens(self):
return self.response["usage"]["cache_read_input_tokens"]
return self.response["usage"]["completion_tokens"]

@property
def finish_reason(self):
return self.response["stop_reason"]


class OpenRouterChatCompletion(ChatCompletion):
@property
def choice(self):
return self.response["choices"][0]
return self.choice.get("finish_reason")

@property
def content(self):
return self.choice["message"]["content"]
def error_message(self):
return self.response.get("error", {}).get("message", "")

@property
def prompt_tokens(self):
return self.response["usage"]["prompt_tokens"]
def cache_discount(self):
return self.response["details"]["cache_discount"]

@property
def completion_tokens(self):
return self.response["usage"]["completion_tokens"]
def cache_discount_string(self):
sign = "-" if self.cache_discount < 0 else ""
amount = f"${abs(self.cache_discount):.2f}"
return f"{sign}{amount}"

@property
def finish_reason(self):
return self.choice.get("finish_reason")
def cost(self):
return self.response["details"]["total_cost"]
6 changes: 3 additions & 3 deletions src/lm_executors/chat_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jinja2
import yaml

from ..api_client import APIClient
from ..api_client import OpenRouterAPIClient
from ..resolve_vars import resolve_vars


Expand All @@ -12,9 +12,9 @@ def __init__(self, context):
self.context = context

async def execute(self):
client = APIClient.create(self.context.api_provider)
client = OpenRouterAPIClient()

params = {"max_tokens": 1000}
params = {"max_tokens": 1024}
if self.context.model is not None:
params["model"] = self.context.model

Expand Down
4 changes: 4 additions & 0 deletions src/lm_executors/chat_executor_template.j2
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
- type: text
text: |-
{{ message.content | indent(8) }}
{% if loop.index in [(messages|length), (messages|length - 2)] %}
cache_control:
type: ephemeral
{% endif %}
{% endfor %}
{% if reinforcement_chat_prompt %}
- role: {{ 'system' if not 'claude' in model else 'assistant' }}
Expand Down
3 changes: 1 addition & 2 deletions src/telegram/telegram_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ async def stats_command_handler(self, ctx):
last_message_stats += "\n".join(
[
f"`Cost: ${lc.cost:.2f}`",
f"`Cache discount: {lc.cache_discount_string}`",
f"`Prompt tokens: {lc.prompt_tokens}`",
f"`Completion tokens: {lc.completion_tokens}`",
f"`Cache creation tokens: {lc.cache_creation_input_tokens}`",
f"`Cache read tokens: {lc.cache_read_input_tokens}`",
]
)
else:
Expand Down

0 comments on commit ebafbfa

Please sign in to comment.