Skip to content

Commit

Permalink
feat(zhipu): support for zhipu ai api
Browse files Browse the repository at this point in the history
  • Loading branch information
OrenZhang committed Dec 17, 2024
1 parent 04ab9a6 commit 1ec94fc
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 0 deletions.
2 changes: 2 additions & 0 deletions apps/chat/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from apps.chat.client.kimi import KimiClient
from apps.chat.client.midjourney import MidjourneyClient
from apps.chat.client.openai import OpenAIClient, OpenAIVisionClient
from apps.chat.client.zhipu import ZhipuClient

__all__ = (
"GeminiClient",
Expand All @@ -14,4 +15,5 @@
"MidjourneyClient",
"KimiClient",
"ClaudeClient",
"ZhipuClient",
)
54 changes: 54 additions & 0 deletions apps/chat/client/zhipu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# pylint: disable=R0801

from typing import List

from django.conf import settings
from openai import OpenAI
from opentelemetry.trace import SpanKind
from ovinc_client.core.logger import logger

from apps.chat.client.base import BaseClient
from apps.chat.constants import SpanType
from apps.chat.exceptions import GenerateFailed


class ZhipuClient(BaseClient):
"""
Zhipu Client
"""

# pylint: disable=R0913,R0917
def __init__(self, user: str, model: str, messages: List[dict], temperature: float, top_p: float):
super().__init__(user=user, model=model, messages=messages, temperature=temperature, top_p=top_p)
self.client = OpenAI(api_key=settings.ZHIPU_API_KEY, base_url=settings.ZHIPU_API_URL)

async def _chat(self, *args, **kwargs) -> any:
try:
with self.start_span(SpanType.API, SpanKind.CLIENT):
response = self.client.chat.completions.create(
model=self.model,
messages=self.messages,
temperature=self.temperature,
top_p=self.top_p,
stream=True,
timeout=settings.ZHIPU_API_TIMEOUT,
stream_options={"include_usage": True},
)
except Exception as err: # pylint: disable=W0718
logger.exception("[GenerateContentFailed] %s", err)
yield str(GenerateFailed())
response = []
content = ""
prompt_tokens = 0
completion_tokens = 0
with self.start_span(SpanType.CHUNK, SpanKind.SERVER):
# pylint: disable=E1133
for chunk in response:
self.log.chat_id = chunk.id
if chunk.choices:
content += chunk.choices[0].delta.content or ""
yield chunk.choices[0].delta.content or ""
if chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
await self.record(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
1 change: 1 addition & 0 deletions apps/chat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class AIModelProvider(TextChoices):
MIDJOURNEY = "midjourney", gettext_lazy("Midjourney")
MOONSHOT = "moonshot", gettext_lazy("Moonshot")
CLAUDE = "claude", gettext_lazy("Claude")
ZHIPU = "zhipu", gettext_lazy("Zhipu")


class VisionSize(TextChoices):
Expand Down
3 changes: 3 additions & 0 deletions apps/chat/consumers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
MidjourneyClient,
OpenAIClient,
OpenAIVisionClient,
ZhipuClient,
)
from apps.chat.client.base import BaseClient
from apps.chat.constants import WS_CLOSED_KEY, AIModelProvider
Expand Down Expand Up @@ -137,6 +138,8 @@ def get_model_client(self, model: AIModel) -> Type[BaseClient]:
return KimiClient
case AIModelProvider.CLAUDE:
return ClaudeClient
case AIModelProvider.ZHIPU:
return ZhipuClient
case _:
raise UnexpectedProvider()

Expand Down
5 changes: 5 additions & 0 deletions entry/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,8 @@
ANTHROPIC_BASE_URL = os.getenv("ANTHROPIC_BASE_URL", "")
ANTHROPIC_MAX_TOKENS = int(os.getenv("ANTHROPIC_MAX_TOKENS", "4096"))
ANTHROPIC_TIMEOUT = int(os.getenv("ANTHROPIC_TIMEOUT", "60"))

# Zhipu
ZHIPU_API_KEY = os.getenv("ZHIPU_API_KEY", "")
ZHIPU_API_URL = os.getenv("ZHIPU_API_URL", "https://open.bigmodel.cn/api/paas/v4/")
ZHIPU_API_TIMEOUT = int(os.getenv("ZHIPU_API_TIMEOUT", "60"))

0 comments on commit 1ec94fc

Please sign in to comment.