Skip to content

Commit

Permalink
Minor refactoring and improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
njbbaer committed May 5, 2024
1 parent c9449c2 commit f638018
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 30 deletions.
32 changes: 14 additions & 18 deletions src/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


class ApiClient:
TIMEOUT = 120
CONFIG = {
"openai": {
"api_url": "https://api.openai.com",
Expand All @@ -15,30 +16,25 @@ class ApiClient:
},
}

def __init__(self, name, instruction_template=None):
self.name = name or "openai"
self.instruction_template = instruction_template

@property
def api_url(self):
return self.CONFIG[self.name]["api_url"]

@property
def api_key(self):
return os.environ.get(self.CONFIG[self.name]["api_key_env"])
def __init__(self, provider):
self.provider = provider

async def call_api(self, messages, parameters):
body = {"messages": messages, **parameters}
if self.instruction_template:
body["instruction_template"] = self.instruction_template

headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}

async with httpx.AsyncClient(timeout=120) as client:
async with httpx.AsyncClient(timeout=self.TIMEOUT) as client:
url = f"{self.api_url}/v1/chat/completions"
response = await client.post(url, headers=headers, json=body)
response.raise_for_status()
response_json = response.json()
if "error" in response_json:
raise Exception(response_json["error"]["message"])
return response_json
return response.json()

@property
def api_url(self):
return self.CONFIG[self.provider]["api_url"]

@property
def api_key(self):
var = self.CONFIG[self.provider]["api_key_env"]
return os.environ.get(var)
17 changes: 13 additions & 4 deletions src/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


class ChatCompletion:
API_CALL_TIMEOUT = 120
MODEL_PRICES = {
"gpt-4": [0.03, 0.06],
"gpt-4-0314": [0.03, 0.06],
Expand All @@ -20,16 +19,22 @@ def __init__(self, response, model):
self.pricing = self.MODEL_PRICES.get(model, [0, 0])
self.logger = Logger("log.yml")

if self.finish_reason == "length":
raise Exception("Response exceeded maximum length")

@classmethod
async def generate(cls, client, content, parameters):
response = await client.call_api(content, parameters)
completion = cls(response, parameters.get("model"))
completion.validate()
completion.logger.log(parameters, content, completion.content)
return completion

def validate(self):
if self.error_message:
raise Exception(self.error_message)
if self.finish_reason == "length":
raise Exception("Response exceeded maximum length")
if self.content == "":
raise Exception("Response was empty")

@property
def choice(self):
return self.response["choices"][0]
Expand All @@ -55,3 +60,7 @@ def cost(self):
return (self.prompt_tokens / 1000 * self.pricing[0]) + (
self.completion_tokens / 1000 * self.pricing[1]
)

@property
def error_message(self):
return self.response.get("error", {}).get("message", "")
4 changes: 0 additions & 4 deletions src/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ def model(self):
def conversation_id(self):
return self._data.setdefault("conversation_id", self._next_conversation_id())

@property
def instruction_template(self):
return self._data.get("instruction_template")

def _load_conversation(self):
path = f"{self.conversations_dir}/{self.char_name}_{self.conversation_id}.yml"
self._conversation = Conversation(path)
Expand Down
5 changes: 1 addition & 4 deletions src/lm_executors/chat_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ def __init__(self, context):
self.context = context

async def execute(self):
client = ApiClient(
name=self.context.api_provider,
instruction_template=self.context.instruction_template,
)
client = ApiClient(self.context.api_provider)

params = {"max_tokens": 1000}
if self.context.model is not None:
Expand Down

0 comments on commit f638018

Please sign in to comment.