Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: gemini pro api #20 #21

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ FROM python:3.10
RUN mkdir -p /usr/src/app/logs /usr/src/app/celery-logs
COPY . /usr/src/app
WORKDIR /usr/src/app
RUN pip3 install -U pip && pip3 install -r requirements.txt -i https://mirrors.cloud.tencent.com/pypi/simple
RUN pip3 install -U pip -i https://mirrors.cloud.tencent.com/pypi/simple && pip3 install -r requirements.txt -i https://mirrors.cloud.tencent.com/pypi/simple
RUN bin/proxy_gemini.sh
74 changes: 73 additions & 1 deletion apps/chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,25 @@
import json
from typing import Generator, List

import google.generativeai as genai
import openai
import requests
import tiktoken
from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import transaction
from django.utils import timezone
from google.generativeai.types import GenerateContentResponse
from openai.openai_object import OpenAIObject
from requests import Response
from rest_framework.request import Request

from apps.chat.constants import (
AI_API_REQUEST_TIMEOUT,
HUNYUAN_DATA_PATTERN,
GeminiRole,
OpenAIModel,
OpenAIRole,
OpenAIUnitPrice,
)
from apps.chat.exceptions import UnexpectedError
Expand Down Expand Up @@ -108,7 +112,7 @@ def record(self, response: OpenAIObject, **kwargs) -> None:
def post_chat(self) -> None:
if not self.log:
return
# calculate tokens
# calculate tokens
encoding = tiktoken.encoding_for_model(self.model)
self.log.prompt_tokens = len(encoding.encode("".join([message["content"] for message in self.log.messages])))
self.log.completion_tokens = len(encoding.encode(self.log.content))
Expand Down Expand Up @@ -221,3 +225,71 @@ def call_api(self) -> Response:
settings.QCLOUD_HUNYUAN_API_URL, json=data, headers=headers, stream=True, timeout=AI_API_REQUEST_TIMEOUT
)
return resp


class GeminiClient(BaseClient):
    """
    Client for Google's Gemini Pro chat model.

    Streams completions via the ``google.generativeai`` SDK and records
    usage into ``ChatLog`` like the other clients in this module.
    """

    # pylint: disable=R0913
    def __init__(self, request: Request, model: str, messages: List[Message], temperature: float, top_p: float):
        """
        :param request: incoming DRF request, forwarded to BaseClient
        :param model: model identifier recorded on the chat log
        :param messages: OpenAI-style message dicts ({"role": ..., "content": ...})
        :param temperature: sampling temperature
        :param top_p: nucleus-sampling parameter
        """
        super().__init__(request=request, model=model, messages=messages, temperature=temperature, top_p=top_p)
        genai.configure(api_key=settings.GEMINI_API_KEY)
        # The SDK model name is fixed to "gemini-pro" regardless of ``model``.
        self.genai_model = genai.GenerativeModel("gemini-pro")

    @transaction.atomic()
    def chat(self, *args, **kwargs) -> Generator[str, None, None]:
        """
        Stream a Gemini chat completion.

        Converts ``self.messages`` into Gemini's role/parts format, streams
        the response, records each chunk into the chat log, and yields each
        chunk's text as it arrives. Finalizes the log via ``post_chat``
        once the stream is exhausted.

        NOTE(review): ``transaction.atomic`` on a generator function only
        wraps creation of the generator object, not its consumption —
        confirm the transaction is actually intended to cover the stream.
        """
        self.created_at = int(timezone.now().timestamp() * 1000)
        response = self.genai_model.generate_content(
            contents=[
                {"role": self.get_role(message["role"]), "parts": [message["content"]]}
                for message in self.messages
            ],
            generation_config=genai.types.GenerationConfig(
                temperature=self.temperature,
                top_p=self.top_p,
            ),
            stream=True,
        )
        for chunk in response:
            self.record(response=chunk)
            yield chunk.text
        self.finished_at = int(timezone.now().timestamp() * 1000)
        self.post_chat()

    @classmethod
    def get_role(cls, role: str) -> str:
        """Map an OpenAI role to Gemini's ("assistant" -> "model", else "user")."""
        if role == OpenAIRole.ASSISTANT:
            return GeminiRole.MODEL
        return GeminiRole.USER

    # pylint: disable=W0221
    def record(self, response: GenerateContentResponse, **kwargs) -> None:
        """
        Append a streamed chunk's text to the chat log, creating the log
        row on the first chunk.
        """
        if not self.log:
            # First chunk: create the log with empty content so the append
            # below applies uniformly to every chunk (replaces the original
            # create-then-recurse pattern).
            self.log = ChatLog.objects.create(
                user=self.user,
                model=self.model,
                messages=self.messages,
                content="",
                created_at=self.created_at,
            )
        self.log.content += response.text

    def post_chat(self) -> None:
        """
        Finalize the chat log once streaming is done: store usage counts,
        unit prices, and the finish timestamp, then persist and strip the
        stored content.
        """
        if not self.log:
            return
        # Unlike the OpenAI client, usage is counted in characters, not
        # tiktoken tokens — presumably because Gemini is priced per
        # character; TODO confirm against the billing model.
        self.log.prompt_tokens = len("".join(message["content"] for message in self.log.messages))
        self.log.completion_tokens = len(self.log.content)
        # Unit price lookup (currently 0 for gemini-pro in OpenAIUnitPrice).
        price = OpenAIUnitPrice.get_price(self.model)
        self.log.prompt_token_unit_price = price.prompt_token_unit_price
        self.log.completion_token_unit_price = price.completion_token_unit_price
        # save
        self.log.finished_at = self.finished_at
        self.log.save()
        self.log.remove_content()
11 changes: 11 additions & 0 deletions apps/chat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ class OpenAIRole(TextChoices):
ASSISTANT = "assistant", gettext_lazy("Assistant")


class GeminiRole(TextChoices):
    """
    Chat message roles accepted by the Gemini API.

    Gemini uses "model" where OpenAI uses "assistant"; GeminiClient.get_role
    performs that mapping.
    """

    USER = "user", gettext_lazy("User")
    MODEL = "model", gettext_lazy("Model")


class OpenAIModel(TextChoices):
"""
OpenAI Model
Expand All @@ -39,6 +48,7 @@ class OpenAIModel(TextChoices):
GPT4_TURBO = "gpt-4-1106-preview", "GPT4 Turbo"
GPT35_TURBO = "gpt-3.5-turbo", "GPT3.5 Turbo"
HUNYUAN = "hunyuan-plus", gettext_lazy("HunYuan Plus")
GEMINI = "gemini-pro", "Gemini Pro"

@classmethod
def get_name(cls, model: str) -> str:
Expand Down Expand Up @@ -69,6 +79,7 @@ class OpenAIUnitPrice:
OpenAIModel.GPT4_TURBO.value: OpenAIUnitPriceItem(0.01, 0.03),
OpenAIModel.GPT35_TURBO.value: OpenAIUnitPriceItem(0.001, 0.002),
OpenAIModel.HUNYUAN.value: OpenAIUnitPriceItem(round(0.10 / 7, 4), round(0.10 / 7, 4)),
OpenAIModel.GEMINI.value: OpenAIUnitPriceItem(0, 0),
}

@classmethod
Expand Down
4 changes: 3 additions & 1 deletion apps/chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rest_framework.decorators import action
from rest_framework.response import Response

from apps.chat.client import HunYuanClient, OpenAIClient
from apps.chat.client import GeminiClient, HunYuanClient, OpenAIClient
from apps.chat.constants import OpenAIModel
from apps.chat.exceptions import VerifyFailed
from apps.chat.models import ChatLog, ModelPermission
Expand Down Expand Up @@ -46,6 +46,8 @@ def create(self, request, *args, **kwargs):
# call api
if request_data["model"] == OpenAIModel.HUNYUAN:
streaming_content = HunYuanClient(request=request, **request_data).chat()
elif request_data["model"] == OpenAIModel.GEMINI:
streaming_content = GeminiClient(request=request, **request_data).chat()
else:
streaming_content = OpenAIClient(request=request, **request_data).chat()

Expand Down
13 changes: 13 additions & 0 deletions bin/proxy_gemini.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Patch the installed google-generativelanguage gRPC transports so Gemini
# API traffic uses the HTTP proxy configured in Django settings
# (OPENAI_HTTP_PROXY_URL). GNU/Linux sed variant; see
# bin/proxy_gemini_macos.sh for the BSD/macOS version.
#
# NOTE(review): the 'NNNa' insertion line numbers below are tied to the
# exact source layout of the pinned google-ai-generativelanguage release —
# re-check them whenever that package is upgraded.
SITE_PACKAGES=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")

# Insert "from django.conf import settings" after line 161 of the sync
# transport (keeps a .bak backup of the original file).
sed -i'.bak' '161a\
from django.conf import settings' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc.py"

# Add the proxy entry to the gRPC channel options after line 175.
sed -i'.bak2' '175a\
("grpc.http_proxy", settings.OPENAI_HTTP_PROXY_URL),' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc.py"

# Same two insertions for the asyncio transport.
sed -i'.bak' '206a\
from django.conf import settings' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc_asyncio.py"

sed -i'.bak2' '220a\
("grpc.http_proxy", settings.OPENAI_HTTP_PROXY_URL),' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc_asyncio.py"
17 changes: 17 additions & 0 deletions bin/proxy_gemini_macos.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Patch the installed google-generativelanguage gRPC transports so Gemini
# API traffic uses the HTTP proxy configured in Django settings
# (OPENAI_HTTP_PROXY_URL). BSD/macOS sed variant of bin/proxy_gemini.sh —
# BSD sed requires the trailing backslash-newline after appended text,
# which is the only difference from the Linux script.
#
# NOTE(review): the 'NNNa' insertion line numbers below are tied to the
# exact source layout of the pinned google-ai-generativelanguage release —
# re-check them whenever that package is upgraded.
SITE_PACKAGES=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")

# Insert "from django.conf import settings" after line 161 of the sync
# transport (keeps a .bak backup of the original file).
sed -i'.bak' '161a\
from django.conf import settings\
' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc.py"

# Add the proxy entry to the gRPC channel options after line 175.
sed -i'.bak2' '175a\
("grpc.http_proxy", settings.OPENAI_HTTP_PROXY_URL),\
' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc.py"

# Same two insertions for the asyncio transport.
sed -i'.bak' '206a\
from django.conf import settings\
' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc_asyncio.py"

sed -i'.bak2' '220a\
("grpc.http_proxy", settings.OPENAI_HTTP_PROXY_URL),\
' "$SITE_PACKAGES/google/ai/generativelanguage_v1beta/services/generative_service/transports/grpc_asyncio.py"
3 changes: 3 additions & 0 deletions entry/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@
OPENAI_MAX_ALLOWED_TOKENS = int(os.getenv("OPENAI_MAX_ALLOWED_TOKENS", "4000"))
OPENAI_PRE_CHECK_TIMEOUT = int(os.getenv("OPENAI_PRE_CHECK_TIMEOUT", "600"))

# Gemini
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

# QCLOUD
QCLOUD_APP_ID = int(os.getenv("QCLOUD_APP_ID", "0"))
QCLOUD_SECRET_ID = os.getenv("QCLOUD_SECRET_ID")
Expand Down
7 changes: 6 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@ uwsgi==2.0.22
arrow==1.2.3

# RSA
pycryptodome==3.15.0
pycryptodome==3.19.1

# Profile
pyinstrument==4.4.0

# OpenAI
openai==0.28.1
tiktoken==0.4.0

# Gemini
google-generativeai==0.3.2
google_api_core==2.15.0
google-ai-generativelanguage==0.4.0