Skip to content

Commit

Permalink
feat: dall-e vision support #27
Browse files Browse the repository at this point in the history
  • Loading branch information
OrenZhang committed Feb 27, 2024
1 parent 8448cef commit 1038ea8
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 14 deletions.
61 changes: 52 additions & 9 deletions apps/chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import List, Union

import google.generativeai as genai
import openai
import qianfan
import requests
import tiktoken
Expand All @@ -16,7 +15,9 @@
from django.db import transaction
from django.utils import timezone
from google.generativeai.types import GenerateContentResponse
from openai.openai_object import OpenAIObject
from openai import AzureOpenAI
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.images_response import ImagesResponse
from qianfan import QfMessages, QfResponse
from requests import Response
from rest_framework.request import Request
Expand Down Expand Up @@ -79,27 +80,30 @@ class OpenAIClient(BaseClient):
@transaction.atomic()
def chat(self, *args, **kwargs) -> any:
    """
    Stream a chat completion from Azure OpenAI.

    Yields the text delta of each streamed chunk ("" when a chunk carries
    no text), recording every content chunk into the chat log, then
    finalizes the log via post_chat().
    """
    self.created_at = int(timezone.now().timestamp() * 1000)
    client = AzureOpenAI(
        api_key=settings.OPENAI_API_KEY,
        api_version="2023-05-15",
        azure_endpoint=settings.OPENAI_API_BASE,
    )
    response = client.chat.completions.create(
        # Azure deployment names cannot contain dots, e.g. "gpt-3.5" -> "gpt-35"
        model=self.model.replace(".", ""),
        messages=self.messages,
        temperature=self.temperature,
        top_p=self.top_p,
        stream=True,
    )
    # pylint: disable=E1133
    for chunk in response:
        # Azure can emit bookkeeping chunks (e.g. content-filter results)
        # whose choices list is empty; indexing them would raise IndexError.
        if not chunk.choices:
            continue
        self.record(response=chunk)
        yield chunk.choices[0].delta.content or ""
    self.finished_at = int(timezone.now().timestamp() * 1000)
    self.post_chat()

# pylint: disable=W0221,R1710
def record(self, response: OpenAIObject, **kwargs) -> None:
def record(self, response: ChatCompletionChunk, **kwargs) -> None:
# check log exist
if self.log:
self.log.content += response.choices[0].delta.get("content", "")
self.log.content += response.choices[0].delta.content or ""
return
# create log
self.log = ChatLog.objects.create(
Expand Down Expand Up @@ -128,6 +132,45 @@ def post_chat(self) -> None:
self.log.remove_content()


class OpenAIVisionClient(BaseClient):
    """
    OpenAI Vision Client

    Generates a single DALL-E image through the Azure OpenAI images
    endpoint and returns it as a markdown image link.
    """

    @transaction.atomic()
    def chat(self, *args, **kwargs) -> any:
        self.created_at = int(timezone.now().timestamp() * 1000)
        # the newest user message is used as the image prompt
        prompt = self.messages[-1]["content"]
        api_client = AzureOpenAI(
            api_key=settings.OPENAI_API_KEY,
            api_version="2023-12-01-preview",
            azure_endpoint=settings.OPENAI_API_BASE,
        )
        result = api_client.images.generate(
            # Azure deployment names cannot contain dots
            model=self.model.replace(".", ""),
            prompt=prompt,
            n=1,
            size=self.model_inst.vision_size,
            quality=self.model_inst.vision_quality,
            style=self.model_inst.vision_style,
        )
        self.record(response=result)
        return f"![{prompt}]({result.data[0].url})"

    # pylint: disable=W0221,R1710
    def record(self, response: ImagesResponse, **kwargs) -> None:
        """Persist a ChatLog entry holding the generated image URL."""
        finished = int(timezone.now().timestamp() * 1000)
        self.log = ChatLog.objects.create(
            user=self.user,
            model=self.model,
            messages=self.messages,
            content=response.data[0].url,
            # one image is billed as a single completion unit
            completion_tokens=1,
            completion_token_unit_price=self.model_inst.completion_price,
            created_at=self.created_at,
            finished_at=finished,
        )
        self.log.remove_content()


class HunYuanClient(BaseClient):
"""
Hun Yuan
Expand Down
27 changes: 26 additions & 1 deletion apps/chat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

HUNYUAN_DATA_PATTERN = re.compile(rb"data:\s\{.*\}\n\n")


TOKEN_ENCODING = tiktoken.encoding_for_model("gpt-3.5-turbo")


Expand Down Expand Up @@ -50,3 +49,29 @@ class AIModelProvider(TextChoices):
GOOGLE = "google", gettext_lazy("Google")
BAIDU = "baidu", gettext_lazy("Baidu")
TENCENT = "tencent", gettext_lazy("Tencent")


class VisionSize(TextChoices):
    """
    Supported output resolutions for image generation; the value is sent
    as the ``size`` argument of the OpenAI images API.

    NOTE(review): migration 0006 also lists 256x256 and 512x512 choices —
    confirm whether those legacy sizes should be declared here too.
    """

    S1024 = "1024x1024", gettext_lazy("1024x1024")


class VisionQuality(TextChoices):
    """
    Image quality options; the value is sent as the ``quality`` argument
    of the OpenAI images API.
    """

    STANDARD = "standard", gettext_lazy("Standard")
    HD = "hd", gettext_lazy("HD")


class VisionStyle(TextChoices):
    """
    Image style options; the value is sent as the ``style`` argument of
    the OpenAI images API.

    NOTE(review): migration 0006 records the choice as ("nature", "Nature")
    instead of "natural" — the stored choices appear stale; verify.
    """

    VIVID = "vivid", gettext_lazy("Vivid")
    NATURAL = "natural", gettext_lazy("Natural")
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# pylint: disable=C0103
# Generated by Django 4.2.3 on 2024-02-27 03:09

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add DALL-E vision settings to AIModel: is_vision, quality, size, style."""

    dependencies = [
        ("chat", "0005_aimodel"),
    ]

    operations = [
        migrations.AddField(
            model_name="aimodel",
            name="is_vision",
            field=models.BooleanField(default=False, verbose_name="Is Vision"),
        ),
        migrations.AddField(
            model_name="aimodel",
            name="vision_quality",
            field=models.CharField(
                blank=True,
                choices=[("standard", "Standard"), ("hd", "HD")],
                max_length=64,
                null=True,
                verbose_name="Vision Quality",
            ),
        ),
        migrations.AddField(
            model_name="aimodel",
            name="vision_size",
            field=models.CharField(
                blank=True,
                # synced with constants.VisionSize, which only declares 1024x1024
                choices=[("1024x1024", "1024x1024")],
                max_length=64,
                null=True,
                verbose_name="Vision Size",
            ),
        ),
        migrations.AddField(
            model_name="aimodel",
            name="vision_style",
            field=models.CharField(
                blank=True,
                # was ("nature", "Nature") — stale; constants.VisionStyle and the
                # OpenAI images API use "natural"
                choices=[("vivid", "Vivid"), ("natural", "Natural")],
                max_length=64,
                null=True,
                verbose_name="Vision Style",
            ),
        ),
    ]
25 changes: 25 additions & 0 deletions apps/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
PRICE_DIGIT_NUMS,
AIModelProvider,
OpenAIRole,
VisionQuality,
VisionSize,
VisionStyle,
)

USER_MODEL = get_user_model()
Expand Down Expand Up @@ -184,6 +187,28 @@ class AIModel(BaseModel):
completion_price = models.DecimalField(
gettext_lazy("Completion Price"), max_digits=PRICE_DIGIT_NUMS, decimal_places=PRICE_DECIMAL_NUMS
)
# Whether this model generates images instead of chat completions; when
# true, apps.chat.views routes requests to OpenAIVisionClient.
is_vision = models.BooleanField(gettext_lazy("Is Vision"), default=False)
# Output resolution passed to images.generate as `size`; nullable because
# chat-only models carry no vision settings.
vision_size = models.CharField(
gettext_lazy("Vision Size"),
max_length=MEDIUM_CHAR_LENGTH,
choices=VisionSize.choices,
null=True,
blank=True,
)
# Image quality passed to images.generate as `quality` (standard / hd).
vision_quality = models.CharField(
gettext_lazy("Vision Quality"),
max_length=MEDIUM_CHAR_LENGTH,
choices=VisionQuality.choices,
null=True,
blank=True,
)
# Image style passed to images.generate as `style` (vivid / natural).
vision_style = models.CharField(
gettext_lazy("Vision Style"),
max_length=MEDIUM_CHAR_LENGTH,
choices=VisionStyle.choices,
null=True,
blank=True,
)

class Meta:
verbose_name = gettext_lazy("AI Model")
Expand Down
22 changes: 19 additions & 3 deletions apps/chat/views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from corsheaders.middleware import ACCESS_CONTROL_ALLOW_ORIGIN
from django.conf import settings
from django.core.cache import cache
from django.http import StreamingHttpResponse
from django.http import HttpResponse, StreamingHttpResponse
from django.shortcuts import get_object_or_404
from ovinc_client.core.utils import uniq_id
from ovinc_client.core.viewsets import CreateMixin, ListMixin, MainViewSet
from rest_framework.decorators import action
from rest_framework.response import Response

from apps.chat.client import GeminiClient, HunYuanClient, OpenAIClient, QianfanClient
from apps.chat.client import (
GeminiClient,
HunYuanClient,
OpenAIClient,
OpenAIVisionClient,
QianfanClient,
)
from apps.chat.constants import AIModelProvider
from apps.chat.exceptions import UnexpectedProvider, VerifyFailed
from apps.chat.models import AIModel, ChatLog, ModelPermission
Expand Down Expand Up @@ -56,11 +62,21 @@ def create(self, request, *args, **kwargs):
case AIModelProvider.BAIDU:
streaming_content = QianfanClient(request=request, **request_data).chat()
case AIModelProvider.OPENAI:
streaming_content = OpenAIClient(request=request, **request_data).chat()
if model.is_vision:
content = OpenAIVisionClient(request=request, **request_data).chat()
else:
streaming_content = OpenAIClient(request=request, **request_data).chat()
case _:
raise UnexpectedProvider()

# response
if model.is_vision:
return HttpResponse(
content=content.encode("utf-8"),
headers={
"Trace-ID": getattr(request, "otel_trace_id", ""),
},
)
return StreamingHttpResponse(
streaming_content=streaming_content,
headers={
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ pycryptodome==3.19.1
pyinstrument==4.4.0

# OpenAI
openai==0.28.1
openai==1.12.0
tiktoken==0.4.0
pillow==10.2.0

# Gemini
google-generativeai==0.3.2
Expand Down

0 comments on commit 1038ea8

Please sign in to comment.