From 1038ea86452d15700613f001cde711c22cab15c5 Mon Sep 17 00:00:00 2001 From: orenzhang Date: Tue, 27 Feb 2024 11:17:44 +0800 Subject: [PATCH] feat: dall-e vision support #27 --- apps/chat/client.py | 61 ++++++++++++++++--- apps/chat/constants.py | 27 +++++++- ..._vision_aimodel_vision_quality_and_more.py | 52 ++++++++++++++++ apps/chat/models.py | 25 ++++++++ apps/chat/views.py | 22 ++++++- requirements.txt | 3 +- 6 files changed, 176 insertions(+), 14 deletions(-) create mode 100644 apps/chat/migrations/0006_aimodel_is_vision_aimodel_vision_quality_and_more.py diff --git a/apps/chat/client.py b/apps/chat/client.py index 3494aab..9c89681 100644 --- a/apps/chat/client.py +++ b/apps/chat/client.py @@ -7,7 +7,6 @@ from typing import List, Union import google.generativeai as genai -import openai import qianfan import requests import tiktoken @@ -16,7 +15,9 @@ from django.db import transaction from django.utils import timezone from google.generativeai.types import GenerateContentResponse -from openai.openai_object import OpenAIObject +from openai import AzureOpenAI +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.images_response import ImagesResponse from qianfan import QfMessages, QfResponse from requests import Response from rest_framework.request import Request @@ -79,27 +80,30 @@ class OpenAIClient(BaseClient): @transaction.atomic() def chat(self, *args, **kwargs) -> any: self.created_at = int(timezone.now().timestamp() * 1000) - response = openai.ChatCompletion.create( - api_base=settings.OPENAI_API_BASE, + client = AzureOpenAI( api_key=settings.OPENAI_API_KEY, - model=self.model, + api_version="2023-05-15", + azure_endpoint=settings.OPENAI_API_BASE, + ) + response = client.chat.completions.create( + model=self.model.replace(".", ""), messages=self.messages, temperature=self.temperature, top_p=self.top_p, stream=True, - deployment_id=self.model.replace(".", ""), ) + # pylint: disable=E1133 for chunk in response: 
self.record(response=chunk) - yield chunk.choices[0].delta.get("content", "") + yield chunk.choices[0].delta.content or "" self.finished_at = int(timezone.now().timestamp() * 1000) self.post_chat() # pylint: disable=W0221,R1710 - def record(self, response: OpenAIObject, **kwargs) -> None: + def record(self, response: ChatCompletionChunk, **kwargs) -> None: # check log exist if self.log: - self.log.content += response.choices[0].delta.get("content", "") + self.log.content += response.choices[0].delta.content or "" return # create log self.log = ChatLog.objects.create( @@ -128,6 +132,45 @@ def post_chat(self) -> None: self.log.remove_content() +class OpenAIVisionClient(BaseClient): + """ + OpenAI Vision Client + """ + + @transaction.atomic() + def chat(self, *args, **kwargs) -> any: + self.created_at = int(timezone.now().timestamp() * 1000) + client = AzureOpenAI( + api_key=settings.OPENAI_API_KEY, + api_version="2023-12-01-preview", + azure_endpoint=settings.OPENAI_API_BASE, + ) + response = client.images.generate( + model=self.model.replace(".", ""), + prompt=self.messages[-1]["content"], + n=1, + size=self.model_inst.vision_size, + quality=self.model_inst.vision_quality, + style=self.model_inst.vision_style, + ) + self.record(response=response) + return f"![{self.messages[-1]['content']}]({response.data[0].url})" + + # pylint: disable=W0221,R1710 + def record(self, response: ImagesResponse, **kwargs) -> None: + self.log = ChatLog.objects.create( + user=self.user, + model=self.model, + messages=self.messages, + content=response.data[0].url, + completion_tokens=1, + completion_token_unit_price=self.model_inst.completion_price, + created_at=self.created_at, + finished_at=int(timezone.now().timestamp() * 1000), + ) + self.log.remove_content() + + class HunYuanClient(BaseClient): """ Hun Yuan diff --git a/apps/chat/constants.py b/apps/chat/constants.py index 51bdb54..dabf29d 100644 --- a/apps/chat/constants.py +++ b/apps/chat/constants.py @@ -18,7 +18,6 @@ 
HUNYUAN_DATA_PATTERN = re.compile(rb"data:\s\{.*\}\n\n") - TOKEN_ENCODING = tiktoken.encoding_for_model("gpt-3.5-turbo") @@ -50,3 +49,29 @@ class AIModelProvider(TextChoices): GOOGLE = "google", gettext_lazy("Google") BAIDU = "baidu", gettext_lazy("Baidu") TENCENT = "tencent", gettext_lazy("Tencent") + + +class VisionSize(TextChoices): + """ + Vision Size + """ + + S1024 = "1024x1024", gettext_lazy("1024x1024") + + +class VisionQuality(TextChoices): + """ + Vision Quality + """ + + STANDARD = "standard", gettext_lazy("Standard") + HD = "hd", gettext_lazy("HD") + + +class VisionStyle(TextChoices): + """ + Vision Style + """ + + VIVID = "vivid", gettext_lazy("Vivid") + NATURAL = "natural", gettext_lazy("Natural") diff --git a/apps/chat/migrations/0006_aimodel_is_vision_aimodel_vision_quality_and_more.py b/apps/chat/migrations/0006_aimodel_is_vision_aimodel_vision_quality_and_more.py new file mode 100644 index 0000000..a8e16bc --- /dev/null +++ b/apps/chat/migrations/0006_aimodel_is_vision_aimodel_vision_quality_and_more.py @@ -0,0 +1,52 @@ +# pylint: disable=C0103 +# Generated by Django 4.2.3 on 2024-02-27 03:09 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("chat", "0005_aimodel"), + ] + + operations = [ + migrations.AddField( + model_name="aimodel", + name="is_vision", + field=models.BooleanField(default=False, verbose_name="Is Vision"), + ), + migrations.AddField( + model_name="aimodel", + name="vision_quality", + field=models.CharField( + blank=True, + choices=[("standard", "Standard"), ("hd", "HD")], + max_length=64, + null=True, + verbose_name="Vision Quality", + ), + ), + migrations.AddField( + model_name="aimodel", + name="vision_size", + field=models.CharField( + blank=True, + choices=[("1024x1024", "1024x1024")], + max_length=64, + null=True, + verbose_name="Vision Size", + ), + ), + migrations.AddField( + model_name="aimodel", + 
name="vision_style", + field=models.CharField( + blank=True, + choices=[("vivid", "Vivid"), ("natural", "Natural")], + max_length=64, + null=True, + verbose_name="Vision Style", + ), + ), + ] diff --git a/apps/chat/models.py b/apps/chat/models.py index d053841..92cad0c 100644 --- a/apps/chat/models.py +++ b/apps/chat/models.py @@ -15,6 +15,9 @@ PRICE_DIGIT_NUMS, AIModelProvider, OpenAIRole, + VisionQuality, + VisionSize, + VisionStyle, ) USER_MODEL = get_user_model() @@ -184,6 +187,28 @@ class AIModel(BaseModel): completion_price = models.DecimalField( gettext_lazy("Completion Price"), max_digits=PRICE_DIGIT_NUMS, decimal_places=PRICE_DECIMAL_NUMS ) + is_vision = models.BooleanField(gettext_lazy("Is Vision"), default=False) + vision_size = models.CharField( + gettext_lazy("Vision Size"), + max_length=MEDIUM_CHAR_LENGTH, + choices=VisionSize.choices, + null=True, + blank=True, + ) + vision_quality = models.CharField( + gettext_lazy("Vision Quality"), + max_length=MEDIUM_CHAR_LENGTH, + choices=VisionQuality.choices, + null=True, + blank=True, + ) + vision_style = models.CharField( + gettext_lazy("Vision Style"), + max_length=MEDIUM_CHAR_LENGTH, + choices=VisionStyle.choices, + null=True, + blank=True, + ) class Meta: verbose_name = gettext_lazy("AI Model") diff --git a/apps/chat/views.py b/apps/chat/views.py index ca1703d..fc90b47 100644 --- a/apps/chat/views.py +++ b/apps/chat/views.py @@ -1,14 +1,20 @@ from corsheaders.middleware import ACCESS_CONTROL_ALLOW_ORIGIN from django.conf import settings from django.core.cache import cache -from django.http import StreamingHttpResponse +from django.http import HttpResponse, StreamingHttpResponse from django.shortcuts import get_object_or_404 from ovinc_client.core.utils import uniq_id from ovinc_client.core.viewsets import CreateMixin, ListMixin, MainViewSet from rest_framework.decorators import action from rest_framework.response import Response -from apps.chat.client import GeminiClient, HunYuanClient, OpenAIClient, 
QianfanClient +from apps.chat.client import ( + GeminiClient, + HunYuanClient, + OpenAIClient, + OpenAIVisionClient, + QianfanClient, +) from apps.chat.constants import AIModelProvider from apps.chat.exceptions import UnexpectedProvider, VerifyFailed from apps.chat.models import AIModel, ChatLog, ModelPermission @@ -56,11 +62,21 @@ def create(self, request, *args, **kwargs): case AIModelProvider.BAIDU: streaming_content = QianfanClient(request=request, **request_data).chat() case AIModelProvider.OPENAI: - streaming_content = OpenAIClient(request=request, **request_data).chat() + if model.is_vision: + content = OpenAIVisionClient(request=request, **request_data).chat() + else: + streaming_content = OpenAIClient(request=request, **request_data).chat() case _: raise UnexpectedProvider() # response + if model.is_vision: + return HttpResponse( + content=content.encode("utf-8"), + headers={ + "Trace-ID": getattr(request, "otel_trace_id", ""), + }, + ) return StreamingHttpResponse( streaming_content=streaming_content, headers={ diff --git a/requirements.txt b/requirements.txt index 2bd2b52..bf136e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,8 +17,9 @@ pycryptodome==3.19.1 pyinstrument==4.4.0 # OpenAI -openai==0.28.1 +openai==1.12.0 tiktoken==0.4.0 +pillow==10.2.0 # Gemini google-generativeai==0.3.2