diff --git a/Dockerfile b/Dockerfile
index 14f33d6..27e267b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,4 +2,5 @@ FROM python:3.10
 RUN mkdir -p /usr/src/app/logs /usr/src/app/celery-logs
 COPY . /usr/src/app
 WORKDIR /usr/src/app
-RUN pip3 install -U pip && pip3 install -r requirements.txt -i https://mirrors.cloud.tencent.com/pypi/simple
+RUN pip3 install -U pip -i https://mirrors.cloud.tencent.com/pypi/simple && pip3 install -r requirements.txt -i https://mirrors.cloud.tencent.com/pypi/simple
+RUN bin/proxy_gemini.sh
diff --git a/apps/chat/client.py b/apps/chat/client.py
index 40788e1..24d8a0d 100644
--- a/apps/chat/client.py
+++ b/apps/chat/client.py
@@ -6,6 +6,7 @@
 import json
-from typing import List
+from typing import Any, List
 
+import google.generativeai as genai
 import openai
 import requests
 import tiktoken
@@ -13,6 +14,7 @@
 from django.contrib.auth import get_user_model
 from django.db import transaction
 from django.utils import timezone
+from google.generativeai.types import GenerateContentResponse
 from openai.openai_object import OpenAIObject
 from requests import Response
 from rest_framework.request import Request
@@ -20,7 +22,9 @@
 from apps.chat.constants import (
     AI_API_REQUEST_TIMEOUT,
     HUNYUAN_DATA_PATTERN,
+    GeminiRole,
     OpenAIModel,
+    OpenAIRole,
     OpenAIUnitPrice,
 )
 from apps.chat.exceptions import UnexpectedError
@@ -108,7 +112,7 @@ def record(self, response: OpenAIObject, **kwargs) -> None:
     def post_chat(self) -> None:
         if not self.log:
             return
-        # calculate tokens
+        # calculate tokens
         encoding = tiktoken.encoding_for_model(self.model)
         self.log.prompt_tokens = len(encoding.encode("".join([message["content"] for message in self.log.messages])))
         self.log.completion_tokens = len(encoding.encode(self.log.content))
@@ -221,3 +225,71 @@ def call_api(self) -> Response:
             settings.QCLOUD_HUNYUAN_API_URL, json=data, headers=headers, stream=True, timeout=AI_API_REQUEST_TIMEOUT
         )
         return resp
+
+
+class GeminiClient(BaseClient):
+    """
+    Gemini Pro
+    """
+
+    # pylint: disable=R0913
+    def __init__(self, request: Request, model: str, messages: List[Message], temperature: float, top_p: float):
+        super().__init__(request=request, model=model, messages=messages, temperature=temperature, top_p=top_p)
+        genai.configure(api_key=settings.GEMINI_API_KEY)
+        self.genai_model = genai.GenerativeModel("gemini-pro")
+
+    @transaction.atomic()
+    def chat(self, *args, **kwargs) -> Any:
+        self.created_at = int(timezone.now().timestamp() * 1000)
+        response = self.genai_model.generate_content(
+            contents=[
+                {"role": self.get_role(message["role"]), "parts": [message["content"]]} for message in self.messages
+            ],
+            generation_config=genai.types.GenerationConfig(
+                temperature=self.temperature,
+                top_p=self.top_p,
+            ),
+            stream=True,
+        )
+        for chunk in response:
+            self.record(response=chunk)
+            yield chunk.text
+        self.finished_at = int(timezone.now().timestamp() * 1000)
+        self.post_chat()
+
+    @classmethod
+    def get_role(cls, role: str) -> str:
+        if role == OpenAIRole.ASSISTANT:
+            return GeminiRole.MODEL
+        return GeminiRole.USER
+
+    # pylint: disable=W0221,R1710
+    def record(self, response: GenerateContentResponse, **kwargs) -> None:
+        # check whether a log exists; if so, append this chunk to it
+        if self.log:
+            self.log.content += response.text
+            return
+        # create log
+        self.log = ChatLog.objects.create(
+            user=self.user,
+            model=self.model,
+            messages=self.messages,
+            content="",
+            created_at=self.created_at,
+        )
+        return self.record(response=response)
+
+    def post_chat(self) -> None:
+        if not self.log:
+            return
+        # calculate characters
+        self.log.prompt_tokens = len("".join([message["content"] for message in self.log.messages]))
+        self.log.completion_tokens = len(self.log.content)
+        # calculate price
+        price = OpenAIUnitPrice.get_price(self.model)
+        self.log.prompt_token_unit_price = price.prompt_token_unit_price
+        self.log.completion_token_unit_price = price.completion_token_unit_price
+        # save
+        self.log.finished_at = self.finished_at
+        self.log.save()
+        self.log.remove_content()
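Note: GeminiClient.get_role exists because Gemini names the assistant turn "model" rather than "assistant", so OpenAI-style history must be relabelled before sending. A minimal standalone sketch of that conversion, built only from calls that appear in the diff above (the history list and placeholder API key are illustrative, not part of this change):

    import google.generativeai as genai

    genai.configure(api_key="<GEMINI_API_KEY>")
    model = genai.GenerativeModel("gemini-pro")

    # OpenAI-style history: Gemini accepts only the roles "user" and "model"
    history = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi! How can I help?"},
        {"role": "user", "content": "Summarize our chat so far."},
    ]
    contents = [
        {"role": "model" if m["role"] == "assistant" else "user", "parts": [m["content"]]}
        for m in history
    ]

    # with stream=True each chunk's .text is a partial reply, which is how
    # GeminiClient.record() accumulates ChatLog.content chunk by chunk
    for chunk in model.generate_content(contents=contents, stream=True):
        print(chunk.text, end="", flush=True)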
len("".join([message["content"] for message in self.log.messages])) + self.log.completion_tokens = len(self.log.content) + # calculate price + price = OpenAIUnitPrice.get_price(self.model) + self.log.prompt_token_unit_price = price.prompt_token_unit_price + self.log.completion_token_unit_price = price.completion_token_unit_price + # save + self.log.finished_at = self.finished_at + self.log.save() + self.log.remove_content() diff --git a/apps/chat/constants.py b/apps/chat/constants.py index 204e6ec..b1d89a0 100644 --- a/apps/chat/constants.py +++ b/apps/chat/constants.py @@ -29,6 +29,15 @@ class OpenAIRole(TextChoices): ASSISTANT = "assistant", gettext_lazy("Assistant") +class GeminiRole(TextChoices): + """ + Gemini Chat Role + """ + + USER = "user", gettext_lazy("User") + MODEL = "model", gettext_lazy("Model") + + class OpenAIModel(TextChoices): """ OpenAI Model @@ -39,6 +48,7 @@ class OpenAIModel(TextChoices): GPT4_TURBO = "gpt-4-1106-preview", "GPT4 Turbo" GPT35_TURBO = "gpt-3.5-turbo", "GPT3.5 Turbo" HUNYUAN = "hunyuan-plus", gettext_lazy("HunYuan Plus") + GEMINI = "gemini-pro", "Gemini Pro" @classmethod def get_name(cls, model: str) -> str: @@ -69,6 +79,7 @@ class OpenAIUnitPrice: OpenAIModel.GPT4_TURBO.value: OpenAIUnitPriceItem(0.01, 0.03), OpenAIModel.GPT35_TURBO.value: OpenAIUnitPriceItem(0.001, 0.002), OpenAIModel.HUNYUAN.value: OpenAIUnitPriceItem(round(0.10 / 7, 4), round(0.10 / 7, 4)), + OpenAIModel.GEMINI.value: OpenAIUnitPriceItem(0, 0), } @classmethod diff --git a/apps/chat/views.py b/apps/chat/views.py index 88df524..bcd8704 100644 --- a/apps/chat/views.py +++ b/apps/chat/views.py @@ -7,7 +7,7 @@ from rest_framework.decorators import action from rest_framework.response import Response -from apps.chat.client import HunYuanClient, OpenAIClient +from apps.chat.client import GeminiClient, HunYuanClient, OpenAIClient from apps.chat.constants import OpenAIModel from apps.chat.exceptions import VerifyFailed from apps.chat.models import ChatLog, ModelPermission @@ -46,6 +46,8 @@ def create(self, request, *args, **kwargs): # call api if request_data["model"] == OpenAIModel.HUNYUAN: streaming_content = HunYuanClient(request=request, **request_data).chat() + elif request_data["model"] == OpenAIModel.GEMINI: + streaming_content = GeminiClient(request=request, **request_data).chat() else: streaming_content = OpenAIClient(request=request, **request_data).chat() diff --git a/bin/proxy_gemini.sh b/bin/proxy_gemini.sh new file mode 100755 index 0000000..12e4526 --- /dev/null +++ b/bin/proxy_gemini.sh @@ -0,0 +1,13 @@ +SITE_PACKAGES=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") + +sed -i'.bak' '161a\ + from django.conf import settings' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc.py" + +sed -i'.bak2' '175a\ + ("grpc.http_proxy", settings.OPENAI_HTTP_PROXY_URL),' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc.py" + +sed -i'.bak' '206a\ + from django.conf import settings' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc_asyncio.py" + +sed -i'.bak2' '220a\ + ("grpc.http_proxy", settings.OPENAI_HTTP_PROXY_URL),' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc_asyncio.py" diff --git a/bin/proxy_gemini_macos.sh b/bin/proxy_gemini_macos.sh new file mode 100755 index 0000000..3fe4c4b --- /dev/null +++ b/bin/proxy_gemini_macos.sh @@ -0,0 +1,17 @@ 
diff --git a/entry/settings.py b/entry/settings.py
index 3d2f841..ee28051 100644
--- a/entry/settings.py
+++ b/entry/settings.py
@@ -199,6 +199,9 @@
 OPENAI_MAX_ALLOWED_TOKENS = int(os.getenv("OPENAI_MAX_ALLOWED_TOKENS", "4000"))
 OPENAI_PRE_CHECK_TIMEOUT = int(os.getenv("OPENAI_PRE_CHECK_TIMEOUT", "600"))
 
+# Gemini
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
+
 # QCLOUD
 QCLOUD_APP_ID = int(os.getenv("QCLOUD_APP_ID", "0"))
 QCLOUD_SECRET_ID = os.getenv("QCLOUD_SECRET_ID")
diff --git a/requirements.txt b/requirements.txt
index 3f61963..585eba9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,7 +11,7 @@ uwsgi==2.0.22
 arrow==1.2.3
 
 # RSA
-pycryptodome==3.15.0
+pycryptodome==3.19.1
 
 # Profile
 pyinstrument==4.4.0
@@ -19,3 +19,8 @@ pyinstrument==4.4.0
 # OpenAI
 openai==0.28.1
 tiktoken==0.4.0
+
+# Gemini
+google-generativeai==0.3.2
+google_api_core==2.15.0
+google-ai-generativelanguage==0.4.0
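Note: with GEMINI_API_KEY exported, the new branch in ChatViewSet.create can be smoke-tested end to end. A hypothetical check; the route, auth, and exact payload shape are assumptions inferred from the view and client signatures, not stated in this diff:

    import requests

    resp = requests.post(
        "http://localhost:8000/api/chat/",  # assumed route for the chat viewset
        json={
            "model": "gemini-pro",  # OpenAIModel.GEMINI, dispatched to GeminiClient
            "messages": [{"role": "user", "content": "ping"}],
            "temperature": 0.7,
            "top_p": 1.0,
        },
        stream=True,  # the view streams the generated text back
        timeout=60,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)

Since OpenAIUnitPrice registers gemini-pro as OpenAIUnitPriceItem(0, 0) and GeminiClient.post_chat counts characters rather than tokens, the resulting ChatLog rows record character counts at zero cost.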