Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dall-e vision support #27 #28

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions apps/chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import List, Union

import google.generativeai as genai
import openai
import qianfan
import requests
import tiktoken
Expand All @@ -16,7 +15,9 @@
from django.db import transaction
from django.utils import timezone
from google.generativeai.types import GenerateContentResponse
from openai.openai_object import OpenAIObject
from openai import AzureOpenAI
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.images_response import ImagesResponse
from qianfan import QfMessages, QfResponse
from requests import Response
from rest_framework.request import Request
Expand Down Expand Up @@ -79,27 +80,30 @@ class OpenAIClient(BaseClient):
@transaction.atomic()
def chat(self, *args, **kwargs) -> any:
    """
    Stream a chat completion from Azure OpenAI.

    Yields each content delta string as it arrives, records every chunk
    into the chat log via ``record``, and finalizes the log through
    ``post_chat`` once the stream is exhausted.
    """
    self.created_at = int(timezone.now().timestamp() * 1000)
    client = AzureOpenAI(
        api_key=settings.OPENAI_API_KEY,
        api_version="2023-05-15",
        azure_endpoint=settings.OPENAI_API_BASE,
    )
    # Azure deployment names cannot contain dots
    # (e.g. "gpt-3.5-turbo" maps to the "gpt-35-turbo" deployment)
    deployment = self.model.replace(".", "")
    response = client.chat.completions.create(
        model=deployment,
        messages=self.messages,
        temperature=self.temperature,
        top_p=self.top_p,
        stream=True,
    )
    # pylint: disable=E1133
    for chunk in response:
        # Azure can emit stream chunks with an empty ``choices`` list
        # (e.g. prompt content-filter annotation chunks); skip them
        # instead of raising IndexError on ``choices[0]``.
        if not chunk.choices:
            continue
        self.record(response=chunk)
        yield chunk.choices[0].delta.content or ""
    self.finished_at = int(timezone.now().timestamp() * 1000)
    self.post_chat()

# pylint: disable=W0221,R1710
def record(self, response: OpenAIObject, **kwargs) -> None:
def record(self, response: ChatCompletionChunk, **kwargs) -> None:
# check log exist
if self.log:
self.log.content += response.choices[0].delta.get("content", "")
self.log.content += response.choices[0].delta.content or ""
return
# create log
self.log = ChatLog.objects.create(
Expand Down Expand Up @@ -128,6 +132,45 @@ def post_chat(self) -> None:
self.log.remove_content()


class OpenAIVisionClient(BaseClient):
    """
    OpenAI Vision Client

    Generates a single image from the latest user message via the
    Azure OpenAI images API and returns it as a markdown image tag.
    """

    @transaction.atomic()
    def chat(self, *args, **kwargs) -> any:
        self.created_at = int(timezone.now().timestamp() * 1000)
        prompt = self.messages[-1]["content"]
        # Azure deployment names do not allow dots in the model name
        deployment_name = self.model.replace(".", "")
        api_client = AzureOpenAI(
            api_key=settings.OPENAI_API_KEY,
            api_version="2023-12-01-preview",
            azure_endpoint=settings.OPENAI_API_BASE,
        )
        result = api_client.images.generate(
            model=deployment_name,
            prompt=prompt,
            n=1,
            size=self.model_inst.vision_size,
            quality=self.model_inst.vision_quality,
            style=self.model_inst.vision_style,
        )
        self.record(response=result)
        # Markdown image syntax so the frontend renders the picture inline
        return f"![{self.messages[-1]['content']}]({result.data[0].url})"

    # pylint: disable=W0221,R1710
    def record(self, response: ImagesResponse, **kwargs) -> None:
        """Persist a ChatLog entry for the generated image, then strip its content."""
        finished_time = int(timezone.now().timestamp() * 1000)
        self.log = ChatLog.objects.create(
            user=self.user,
            model=self.model,
            messages=self.messages,
            content=response.data[0].url,
            # image generation is billed per image, counted as one completion token
            completion_tokens=1,
            completion_token_unit_price=self.model_inst.completion_price,
            created_at=self.created_at,
            finished_at=finished_time,
        )
        self.log.remove_content()


class HunYuanClient(BaseClient):
"""
Hun Yuan
Expand Down
27 changes: 26 additions & 1 deletion apps/chat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

HUNYUAN_DATA_PATTERN = re.compile(rb"data:\s\{.*\}\n\n")


TOKEN_ENCODING = tiktoken.encoding_for_model("gpt-3.5-turbo")


Expand Down Expand Up @@ -50,3 +49,29 @@ class AIModelProvider(TextChoices):
GOOGLE = "google", gettext_lazy("Google")
BAIDU = "baidu", gettext_lazy("Baidu")
TENCENT = "tencent", gettext_lazy("Tencent")


class VisionSize(TextChoices):
    """
    Vision Size

    Output resolutions offered for image generation; the raw value is
    passed to the OpenAI images API ``size`` parameter (see client.py).
    """

    S1024 = "1024x1024", gettext_lazy("1024x1024")


class VisionQuality(TextChoices):
    """
    Vision Quality

    Image quality levels; the raw value is passed to the OpenAI images
    API ``quality`` parameter (see client.py).
    """

    STANDARD = "standard", gettext_lazy("Standard")
    HD = "hd", gettext_lazy("HD")


class VisionStyle(TextChoices):
    """
    Vision Style

    Rendering styles; the raw value is passed to the OpenAI images API
    ``style`` parameter (see client.py).
    """

    VIVID = "vivid", gettext_lazy("Vivid")
    NATURAL = "natural", gettext_lazy("Natural")
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# pylint: disable=C0103
# Generated by Django 4.2.3 on 2024-02-27 03:09

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add DALL-E vision configuration fields to the AIModel model."""

    dependencies = [
        ("chat", "0005_aimodel"),
    ]

    operations = [
        migrations.AddField(
            model_name="aimodel",
            name="is_vision",
            field=models.BooleanField(default=False, verbose_name="Is Vision"),
        ),
        migrations.AddField(
            model_name="aimodel",
            name="vision_quality",
            field=models.CharField(
                blank=True,
                choices=[("standard", "Standard"), ("hd", "HD")],
                max_length=64,
                null=True,
                verbose_name="Vision Quality",
            ),
        ),
        migrations.AddField(
            model_name="aimodel",
            name="vision_size",
            # choices must mirror constants.VisionSize, which only declares
            # 1024x1024; a mismatch makes makemigrations detect phantom changes.
            field=models.CharField(
                blank=True,
                choices=[("1024x1024", "1024x1024")],
                max_length=64,
                null=True,
                verbose_name="Vision Size",
            ),
        ),
        migrations.AddField(
            model_name="aimodel",
            name="vision_style",
            # choices must mirror constants.VisionStyle: the API value is
            # "natural" (not "nature"), matching VisionStyle.NATURAL.
            field=models.CharField(
                blank=True,
                choices=[("vivid", "Vivid"), ("natural", "Natural")],
                max_length=64,
                null=True,
                verbose_name="Vision Style",
            ),
        ),
    ]
25 changes: 25 additions & 0 deletions apps/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
PRICE_DIGIT_NUMS,
AIModelProvider,
OpenAIRole,
VisionQuality,
VisionSize,
VisionStyle,
)

USER_MODEL = get_user_model()
Expand Down Expand Up @@ -184,6 +187,28 @@ class AIModel(BaseModel):
completion_price = models.DecimalField(
gettext_lazy("Completion Price"), max_digits=PRICE_DIGIT_NUMS, decimal_places=PRICE_DECIMAL_NUMS
)
# Whether this model generates images (DALL-E) instead of chat completions;
# when True, the vision_* fields below configure the image request.
is_vision = models.BooleanField(gettext_lazy("Is Vision"), default=False)
# Output resolution forwarded to the OpenAI images API ``size`` parameter;
# nullable/blank because chat-only models do not use it.
vision_size = models.CharField(
    gettext_lazy("Vision Size"),
    max_length=MEDIUM_CHAR_LENGTH,
    choices=VisionSize.choices,
    null=True,
    blank=True,
)
# Image quality ("standard" / "hd") forwarded to the images API.
vision_quality = models.CharField(
    gettext_lazy("Vision Quality"),
    max_length=MEDIUM_CHAR_LENGTH,
    choices=VisionQuality.choices,
    null=True,
    blank=True,
)
# Rendering style ("vivid" / "natural") forwarded to the images API.
vision_style = models.CharField(
    gettext_lazy("Vision Style"),
    max_length=MEDIUM_CHAR_LENGTH,
    choices=VisionStyle.choices,
    null=True,
    blank=True,
)

class Meta:
verbose_name = gettext_lazy("AI Model")
Expand Down
22 changes: 19 additions & 3 deletions apps/chat/views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from corsheaders.middleware import ACCESS_CONTROL_ALLOW_ORIGIN
from django.conf import settings
from django.core.cache import cache
from django.http import StreamingHttpResponse
from django.http import HttpResponse, StreamingHttpResponse
from django.shortcuts import get_object_or_404
from ovinc_client.core.utils import uniq_id
from ovinc_client.core.viewsets import CreateMixin, ListMixin, MainViewSet
from rest_framework.decorators import action
from rest_framework.response import Response

from apps.chat.client import GeminiClient, HunYuanClient, OpenAIClient, QianfanClient
from apps.chat.client import (
GeminiClient,
HunYuanClient,
OpenAIClient,
OpenAIVisionClient,
QianfanClient,
)
from apps.chat.constants import AIModelProvider
from apps.chat.exceptions import UnexpectedProvider, VerifyFailed
from apps.chat.models import AIModel, ChatLog, ModelPermission
Expand Down Expand Up @@ -56,11 +62,21 @@ def create(self, request, *args, **kwargs):
case AIModelProvider.BAIDU:
streaming_content = QianfanClient(request=request, **request_data).chat()
case AIModelProvider.OPENAI:
streaming_content = OpenAIClient(request=request, **request_data).chat()
if model.is_vision:
content = OpenAIVisionClient(request=request, **request_data).chat()
else:
streaming_content = OpenAIClient(request=request, **request_data).chat()
case _:
raise UnexpectedProvider()

# response
if model.is_vision:
return HttpResponse(
content=content.encode("utf-8"),
headers={
"Trace-ID": getattr(request, "otel_trace_id", ""),
},
)
return StreamingHttpResponse(
streaming_content=streaming_content,
headers={
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ pycryptodome==3.19.1
pyinstrument==4.4.0

# OpenAI
openai==0.28.1
openai==1.12.0
tiktoken==0.4.0
pillow==10.2.0

# Gemini
google-generativeai==0.3.2
Expand Down
Loading