Skip to content

Commit

Permalink
Merge pull request #23 from OVINC-CN/feat_ernie_bot
Browse files Browse the repository at this point in the history
feat: support for baidu ernie-bot #22
  • Loading branch information
OrenZhang authored Feb 4, 2024
2 parents 78eda22 + b528ad1 commit 732deab
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 4 deletions.
65 changes: 62 additions & 3 deletions apps/chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import hashlib
import hmac
import json
from typing import List
from typing import List, Union

import google.generativeai as genai
import openai
import qianfan
import requests
import tiktoken
from django.conf import settings
Expand All @@ -16,6 +17,7 @@
from django.utils import timezone
from google.generativeai.types import GenerateContentResponse
from openai.openai_object import OpenAIObject
from qianfan import QfMessages, QfResponse
from requests import Response
from rest_framework.request import Request

Expand All @@ -40,12 +42,14 @@ class BaseClient:
"""

# pylint: disable=R0913
def __init__(self, request: Request, model: str, messages: List[Message], temperature: float, top_p: float):
def __init__(
self, request: Request, model: str, messages: Union[List[Message], QfMessages], temperature: float, top_p: float
):
self.log: ChatLog = None
self.request: Request = request
self.user: USER_MODEL = request.user
self.model: str = model
self.messages: List[Message] = messages
self.messages: Union[List[Message], QfMessages] = messages
self.temperature: float = temperature
self.top_p: float = top_p
self.created_at: int = int()
Expand Down Expand Up @@ -293,3 +297,58 @@ def post_chat(self) -> None:
self.log.finished_at = self.finished_at
self.log.save()
self.log.remove_content()


class QianfanClient(BaseClient):
"""
Baidu Qianfan
"""

@transaction.atomic()
def chat(self, *args, **kwargs) -> any:
self.created_at = int(timezone.now().timestamp() * 1000)
client = qianfan.ChatCompletion(ak=settings.QIANFAN_ACCESS_KEY, sk=settings.QIANFAN_SECRET_KEY)
response = client.do(
model=self.model,
messages=self.messages,
temperature=self.temperature,
top_p=self.top_p,
stream=True,
)
for chunk in response:
self.record(response=chunk)
yield chunk.body.get("result", "")
self.finished_at = int(timezone.now().timestamp() * 1000)
self.post_chat()

# pylint: disable=W0221,R1710
def record(self, response: QfResponse, **kwargs) -> None:
# check log exist
if self.log:
self.log.content += response.body.get("result", "")
usage = response.body.get("usage", {})
self.log.prompt_tokens = usage.get("prompt_tokens", 0)
self.log.completion_tokens = usage.get("completion_tokens", 0)
return
# create log
self.log = ChatLog.objects.create(
chat_id=response.body.get("id", ""),
user=self.user,
model=self.model,
messages=self.messages,
content="",
created_at=self.created_at,
)
return self.record(response=response)

def post_chat(self) -> None:
if not self.log:
return
# calculate price
price = OpenAIUnitPrice.get_price(self.model)
self.log.prompt_token_unit_price = price.prompt_token_unit_price
self.log.completion_token_unit_price = price.completion_token_unit_price
# save
self.log.finished_at = self.finished_at
self.log.save()
self.log.remove_content()
8 changes: 8 additions & 0 deletions apps/chat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class OpenAIModel(TextChoices):
GPT35_TURBO = "gpt-3.5-turbo", "GPT3.5 Turbo"
HUNYUAN = "hunyuan-plus", gettext_lazy("HunYuan Plus")
GEMINI = "gemini-pro", "Gemini Pro"
ERNIE_BOT_4_0 = "ERNIE-Bot-4", "ERNIE-Bot 4.0"
ERNIE_BOT_8K = "ERNIE-Bot-8k", "ERNIE-Bot 8K"
ERNIE_BOT = "ERNIE-Bot", "ERNIE-Bot"
ERNIE_BOT_TURBO = "ERNIE-Bot-turbo-AI", "ERNIE-Bot Turbo"

@classmethod
def get_name(cls, model: str) -> str:
Expand Down Expand Up @@ -80,6 +84,10 @@ class OpenAIUnitPrice:
OpenAIModel.GPT35_TURBO.value: OpenAIUnitPriceItem(0.001, 0.002),
OpenAIModel.HUNYUAN.value: OpenAIUnitPriceItem(round(0.10 / 7, 4), round(0.10 / 7, 4)),
OpenAIModel.GEMINI.value: OpenAIUnitPriceItem(0, 0),
OpenAIModel.ERNIE_BOT_4_0.value: OpenAIUnitPriceItem(round(0.12 / 7, 4), round(0.12 / 7, 4)),
OpenAIModel.ERNIE_BOT_8K.value: OpenAIUnitPriceItem(round(0.024 / 7, 4), round(0.048 / 7, 4)),
OpenAIModel.ERNIE_BOT.value: OpenAIUnitPriceItem(round(0.012 / 7, 4), round(0.012 / 7, 4)),
OpenAIModel.ERNIE_BOT_TURBO.value: OpenAIUnitPriceItem(round(0.008 / 7, 4), round(0.008 / 7, 4)),
}

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion apps/chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rest_framework.decorators import action
from rest_framework.response import Response

from apps.chat.client import GeminiClient, HunYuanClient, OpenAIClient
from apps.chat.client import GeminiClient, HunYuanClient, OpenAIClient, QianfanClient
from apps.chat.constants import OpenAIModel
from apps.chat.exceptions import VerifyFailed
from apps.chat.models import ChatLog, ModelPermission
Expand Down Expand Up @@ -48,6 +48,8 @@ def create(self, request, *args, **kwargs):
streaming_content = HunYuanClient(request=request, **request_data).chat()
elif request_data["model"] == OpenAIModel.GEMINI:
streaming_content = GeminiClient(request=request, **request_data).chat()
elif request_data["model"].startswith(OpenAIModel.ERNIE_BOT):
streaming_content = QianfanClient(request=request, **request_data).chat()
else:
streaming_content = OpenAIClient(request=request, **request_data).chat()

Expand Down
4 changes: 4 additions & 0 deletions entry/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,9 @@
"QCLOUD_HUNYUAN_API_DOMAIN", "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
)

# Baidu Qianfan
QIANFAN_ACCESS_KEY = os.getenv("QIANFAN_ACCESS_KEY", "")
QIANFAN_SECRET_KEY = os.getenv("QIANFAN_SECRET_KEY", "")

# Log
RECORD_CHAT_CONTENT = strtobool(os.getenv("RECORD_CHAT_CONTENT", "False"))
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ tiktoken==0.4.0
google-generativeai==0.3.2
google_api_core==2.15.0
google-ai-generativelanguage==0.4.0

# Baidu
qianfan==0.3.0

0 comments on commit 732deab

Please sign in to comment.