feat(client): change all client to openai client
OrenZhang committed Dec 18, 2024
1 parent 750510c commit 6f14dcf
Showing 18 changed files with 167 additions and 638 deletions.
8 changes: 1 addition & 7 deletions apps/cel/__init__.py
@@ -12,10 +12,4 @@
app.autodiscover_tasks()

# Schedule Tasks
app.conf.beat_schedule = {
"check_usage_limit": {
"task": "apps.chat.tasks.check_usage_limit",
"schedule": crontab(minute="*"),
"args": (),
},
}
app.conf.beat_schedule = {}
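
With the beat schedule emptied, the per-minute check_usage_limit sweep is gone; usage limits are instead settled through calculate_usage_limit, which BaseClient.record() invokes for each finished chat log (see apps/chat/client/base.py further down). A minimal sketch of that call path, assuming calculate_usage_limit keeps its log_id keyword; record_usage_for_log is a hypothetical wrapper added here only for illustration:

from channels.db import database_sync_to_async

from apps.chat.tasks import calculate_usage_limit


async def record_usage_for_log(log_id: str) -> None:
    # Hypothetical wrapper: replaces the removed crontab(minute="*") sweep
    # by settling usage for a single chat log as soon as it finishes.
    await database_sync_to_async(calculate_usage_limit)(log_id=log_id)
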
1 change: 1 addition & 0 deletions apps/chat/admin.py
@@ -74,6 +74,7 @@ class AIModelAdmin(admin.ModelAdmin):
"is_enabled",
"prompt_price",
"completion_price",
"vision_price",
]
list_filter = ["provider", "is_enabled"]

14 changes: 1 addition & 13 deletions apps/chat/client/__init__.py
@@ -1,19 +1,7 @@
from apps.chat.client.claude import ClaudeClient
from apps.chat.client.gemini import GeminiClient
from apps.chat.client.hunyuan import HunYuanClient, HunYuanVisionClient
from apps.chat.client.kimi import KimiClient
from apps.chat.client.midjourney import MidjourneyClient
from apps.chat.client.openai import OpenAIClient, OpenAIVisionClient
from apps.chat.client.zhipu import ZhipuClient
from apps.chat.client.openai import OpenAIClient

__all__ = (
"GeminiClient",
"OpenAIClient",
"OpenAIVisionClient",
"HunYuanClient",
"HunYuanVisionClient",
"MidjourneyClient",
"KimiClient",
"ClaudeClient",
"ZhipuClient",
)
101 changes: 97 additions & 4 deletions apps/chat/client/base.py
@@ -1,16 +1,23 @@
import abc
import base64
import datetime

from channels.db import database_sync_to_async
from django.conf import settings
from django.contrib.auth import get_user_model
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django.utils.translation import gettext
from httpx import Client
from openai import OpenAI
from opentelemetry import trace
from opentelemetry.sdk.trace import Span
from opentelemetry.trace import SpanKind
from ovinc_client.core.logger import logger

from apps.chat.constants import OpenAIRole, SpanType
from apps.chat.models import AIModel, ChatLog, Message
from apps.chat.constants import MessageContentType, OpenAIRole, SpanType
from apps.chat.exceptions import FileExtractFailed, GenerateFailed
from apps.chat.models import AIModel, ChatLog, Message, MessageContent

USER_MODEL = get_user_model()

@@ -62,19 +69,105 @@ async def _chat(self, *args, **kwargs) -> any:

raise NotImplementedError()

async def record(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> None:
async def record(self, prompt_tokens: int = 0, completion_tokens: int = 0, vision_count: int = 0) -> None:
if not self.log:
return
# calculate tokens
self.log.prompt_tokens = max(prompt_tokens, self.log.prompt_tokens)
vision_tokens = 0
if self.model_inst.prompt_price and self.model_inst.vision_price and vision_count:
vision_tokens = vision_count * self.model_inst.vision_price / self.model_inst.prompt_price
self.log.prompt_tokens = max(prompt_tokens, self.log.prompt_tokens) + vision_tokens
self.log.completion_tokens = max(completion_tokens, self.log.completion_tokens)
# calculate price
self.log.prompt_token_unit_price = self.model_inst.prompt_price
self.log.completion_token_unit_price = self.model_inst.completion_price
# save
self.log.finished_at = int(timezone.now().timestamp() * 1000)
await database_sync_to_async(self.log.save)()
# calculate usage
from apps.chat.tasks import calculate_usage_limit # pylint: disable=C0415

await database_sync_to_async(calculate_usage_limit)(log_id=self.log.id) # pylint: disable=E1120

def start_span(self, name: str, kind: SpanKind, **kwargs) -> Span:
span: Span = self.tracer.start_as_current_span(name=name, kind=kind, **kwargs)
return span


class OpenAIBaseClient(BaseClient, abc.ABC):
"""
OpenAI Client
"""

@property
@abc.abstractmethod
def api_key(self) -> str:
raise NotImplementedError()

@property
@abc.abstractmethod
def base_url(self) -> str:
raise NotImplementedError()

@property
def http_client(self) -> Client | None:
return None

@property
def timeout(self) -> int:
return settings.OPENAI_CHAT_TIMEOUT

@property
def api_model(self) -> str:
return self.model

async def _chat(self, *args, **kwargs) -> any:
image_count = self.format_message()
client = OpenAI(api_key=self.api_key, base_url=self.base_url, http_client=self.http_client)
try:
with self.start_span(SpanType.API, SpanKind.CLIENT):
response = client.chat.completions.create(
model=self.api_model,
messages=[message.model_dump(exclude_none=True) for message in self.messages],
temperature=self.temperature,
top_p=self.top_p,
stream=True,
timeout=self.timeout,
stream_options={"include_usage": True},
extra_headers={"HTTP-Referer": settings.PROJECT_URL, "X-Title": settings.PROJECT_NAME},
)
except Exception as err: # pylint: disable=W0718
logger.exception("[GenerateContentFailed] %s", err)
yield str(GenerateFailed())
response = []
prompt_tokens = 0
completion_tokens = 0
with self.start_span(SpanType.CHUNK, SpanKind.SERVER):
for chunk in response:
if chunk.choices:
yield chunk.choices[0].delta.content or ""
if chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
await self.record(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, vision_count=image_count)

def format_message(self) -> int:
image_count = 0
for message in self.messages:
message: Message
if not isinstance(message.content, list):
continue
for content in message.content:
content: MessageContent
if content.type != MessageContentType.IMAGE_URL or not content.image_url:
continue
content.image_url.url = self.convert_url_to_base64(content.image_url.url)
image_count += 1
return image_count

def convert_url_to_base64(self, url: str) -> str:
with Client(http2=True) as client:
response = client.get(url)
if response.status_code == 200:
return f"data:image/webp;base64,{base64.b64encode(response.content).decode()}"
raise FileExtractFailed(gettext("Parse Image To Base64 Failed"))
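
The vision surcharge in record() is folded into prompt tokens: vision_tokens = vision_count * vision_price / prompt_price, i.e. each image is billed as however many prompt tokens cost the same as one image. A quick worked example with made-up prices (the numbers are illustrative, not taken from the repository):

# Illustrative unit prices only (not real model pricing).
prompt_price = 0.00001   # price per prompt token
vision_price = 0.002     # price per image
vision_count = 3         # images sent with the request
prompt_tokens = 1200     # prompt tokens reported by the API

# One image costs as much as vision_price / prompt_price = 200 prompt tokens.
vision_tokens = vision_count * vision_price / prompt_price   # 600.0
# record() first takes max() with any previously logged value; 0 is assumed here.
billed_prompt_tokens = prompt_tokens + vision_tokens          # 1800.0
print(billed_prompt_tokens)
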
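With every chat backend now routed through the OpenAI-compatible base client, adding a provider mostly means supplying a key, a base URL, and optionally an HTTP client. A minimal sketch of a concrete subclass, assuming an OpenAI-compatible endpoint; the MoonshotClient name and the MOONSHOT_API_KEY setting are hypothetical and not part of this commit:

from django.conf import settings
from httpx import Client

from apps.chat.client.base import OpenAIBaseClient


class MoonshotClient(OpenAIBaseClient):
    """Hypothetical provider client: only credentials and endpoint differ."""

    @property
    def api_key(self) -> str:
        return settings.MOONSHOT_API_KEY  # assumed setting name

    @property
    def base_url(self) -> str:
        return "https://api.moonshot.cn/v1"  # any OpenAI-compatible endpoint

    @property
    def http_client(self) -> Client | None:
        # Return an httpx.Client configured with a proxy here if the endpoint
        # needs one; None keeps the OpenAI SDK's default transport.
        return None
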
100 changes: 0 additions & 100 deletions apps/chat/client/claude.py

This file was deleted.

73 changes: 0 additions & 73 deletions apps/chat/client/gemini.py

This file was deleted.

