diff --git a/apps/chat/admin.py b/apps/chat/admin.py index b7b95df..620f455 100644 --- a/apps/chat/admin.py +++ b/apps/chat/admin.py @@ -1,19 +1,29 @@ import datetime +from typing import Union from django.contrib import admin from django.utils import timezone from django.utils.translation import gettext_lazy from rest_framework.settings import api_settings -from apps.chat.models import ChatLog, ModelPermission +from apps.chat.models import AIModel, ChatLog, ModelPermission + + +class ModelNameMixin: + @admin.display(description=gettext_lazy("Model Name")) + def model_name(self, inst: Union[ModelPermission, ChatLog]) -> str: + model_inst: AIModel = AIModel.objects.filter(model=inst.model, is_enabled=True).first() + if model_inst is None: + return "" + return model_inst.name @admin.register(ChatLog) -class ChatLogAdmin(admin.ModelAdmin): +class ChatLogAdmin(ModelNameMixin, admin.ModelAdmin): list_display = [ "id", "user", - "model", + "model_name", "prompt_tokens", "completion_tokens", "total_price", @@ -29,7 +39,7 @@ def total_price(self, log: ChatLog) -> str: log.prompt_tokens * log.prompt_token_unit_price / 1000 + log.completion_tokens * log.completion_token_unit_price / 1000 ) - return f"{price:.2f}" + return f"{price:.4f}" @admin.display(description=gettext_lazy("Duration(ms)")) def duration(self, log: ChatLog) -> int: @@ -47,7 +57,13 @@ def created_at_formatted(self, log: ChatLog) -> str: @admin.register(ModelPermission) -class ModelPermissionAdmin(admin.ModelAdmin): - list_display = ["id", "user", "model", "expired_at", "created_at"] +class ModelPermissionAdmin(ModelNameMixin, admin.ModelAdmin): + list_display = ["id", "user", "model_name", "expired_at", "created_at"] list_filter = ["model"] search_fields = ["user"] + + +@admin.register(AIModel) +class AIModelAdmin(admin.ModelAdmin): + list_display = ["id", "provider", "model", "name", "is_enabled", "prompt_price", "completion_price"] + list_filter = ["provider", "is_enabled"] diff --git a/apps/chat/client.py b/apps/chat/client.py index 96e58c1..3494aab 100644 --- a/apps/chat/client.py +++ b/apps/chat/client.py @@ -25,12 +25,10 @@ AI_API_REQUEST_TIMEOUT, HUNYUAN_DATA_PATTERN, GeminiRole, - OpenAIModel, OpenAIRole, - OpenAIUnitPrice, ) from apps.chat.exceptions import UnexpectedError -from apps.chat.models import ChatLog, HunYuanChuck, Message +from apps.chat.models import AIModel, ChatLog, HunYuanChuck, Message USER_MODEL = get_user_model() @@ -49,6 +47,7 @@ def __init__( self.request: Request = request self.user: USER_MODEL = request.user self.model: str = model + self.model_inst: AIModel = AIModel.objects.get(model=model, is_enabled=True) self.messages: Union[List[Message], QfMessages] = messages self.temperature: float = temperature self.top_p: float = top_p @@ -121,28 +120,13 @@ def post_chat(self) -> None: self.log.prompt_tokens = len(encoding.encode("".join([message["content"] for message in self.log.messages]))) self.log.completion_tokens = len(encoding.encode(self.log.content)) # calculate price - price = OpenAIUnitPrice.get_price(self.model) - self.log.prompt_token_unit_price = price.prompt_token_unit_price - self.log.completion_token_unit_price = price.completion_token_unit_price + self.log.prompt_token_unit_price = self.model_inst.prompt_price + self.log.completion_token_unit_price = self.model_inst.completion_price # save self.log.finished_at = self.finished_at self.log.save() self.log.remove_content() - @classmethod - def list_models(cls) -> List[dict]: - all_models = openai.Model.list( - api_base=settings.OPENAI_API_BASE, api_key=settings.OPENAI_API_KEY - ).to_dict_recursive()["data"] - supported_models = [ - {"id": model["id"], "name": str(OpenAIModel.get_name(model["id"]))} - for model in all_models - if model["id"] in OpenAIModel.values - ] - supported_models.append({"id": OpenAIModel.HUNYUAN.value, "name": str(OpenAIModel.HUNYUAN.label)}) - supported_models.sort(key=lambda model: model["id"]) - return supported_models - class HunYuanClient(BaseClient): """ @@ -185,9 +169,8 @@ def record(self, response: HunYuanChuck) -> None: self.log.content += response.choices[0].delta.content self.log.prompt_tokens = response.usage.prompt_tokens self.log.completion_tokens = response.usage.completion_tokens - price = OpenAIUnitPrice.get_price(self.model) - self.log.prompt_token_unit_price = price.prompt_token_unit_price - self.log.completion_token_unit_price = price.completion_token_unit_price + self.log.prompt_token_unit_price = self.model_inst.prompt_price + self.log.completion_token_unit_price = self.model_inst.completion_price return # create log self.log = ChatLog.objects.create( @@ -290,9 +273,8 @@ def post_chat(self) -> None: self.log.prompt_tokens = len("".join([message["content"] for message in self.log.messages])) self.log.completion_tokens = len(self.log.content) # calculate price - price = OpenAIUnitPrice.get_price(self.model) - self.log.prompt_token_unit_price = price.prompt_token_unit_price - self.log.completion_token_unit_price = price.completion_token_unit_price + self.log.prompt_token_unit_price = self.model_inst.prompt_price + self.log.completion_token_unit_price = self.model_inst.completion_price # save self.log.finished_at = self.finished_at self.log.save() @@ -345,9 +327,8 @@ def post_chat(self) -> None: if not self.log: return # calculate price - price = OpenAIUnitPrice.get_price(self.model) - self.log.prompt_token_unit_price = price.prompt_token_unit_price - self.log.completion_token_unit_price = price.completion_token_unit_price + self.log.prompt_token_unit_price = self.model_inst.prompt_price + self.log.completion_token_unit_price = self.model_inst.completion_price # save self.log.finished_at = self.finished_at self.log.save() diff --git a/apps/chat/constants.py b/apps/chat/constants.py index af13a8d..51bdb54 100644 --- a/apps/chat/constants.py +++ b/apps/chat/constants.py @@ -1,6 +1,6 @@ import re -from dataclasses import dataclass +import tiktoken from django.utils.translation import gettext_lazy from ovinc_client.core.models import TextChoices @@ -19,6 +19,9 @@ HUNYUAN_DATA_PATTERN = re.compile(rb"data:\s\{.*\}\n\n") +TOKEN_ENCODING = tiktoken.encoding_for_model("gpt-3.5-turbo") + + class OpenAIRole(TextChoices): """ OpenAI Chat Role @@ -38,58 +41,12 @@ class GeminiRole(TextChoices): MODEL = "model", gettext_lazy("Model") -class OpenAIModel(TextChoices): - """ - OpenAI Model - """ - - GPT4 = "gpt-4", "GPT4" - GPT4_32K = "gpt-4-32k", "GPT4 (32K)" - GPT4_TURBO = "gpt-4-1106-preview", "GPT4 Turbo" - GPT35_TURBO = "gpt-3.5-turbo", "GPT3.5 Turbo" - HUNYUAN = "hunyuan-plus", gettext_lazy("HunYuan Plus") - GEMINI = "gemini-pro", "Gemini Pro" - ERNIE_BOT_4_0 = "ERNIE-Bot-4", "ERNIE-Bot 4.0" - ERNIE_BOT_8K = "ERNIE-Bot-8k", "ERNIE-Bot 8K" - ERNIE_BOT = "ERNIE-Bot", "ERNIE-Bot" - ERNIE_BOT_TURBO = "ERNIE-Bot-turbo-AI", "ERNIE-Bot Turbo" - - @classmethod - def get_name(cls, model: str) -> str: - for value, label in cls.choices: - if value == model: - return str(label) - return model - - -@dataclass -class OpenAIUnitPriceItem: - """ - OpenAI Unit Price Item - """ - - prompt_token_unit_price: float - completion_token_unit_price: float - - -class OpenAIUnitPrice: +class AIModelProvider(TextChoices): """ - OpenAI Unit Price Per Thousand Tokens ($) + AI Model Provider """ - price_map = { - OpenAIModel.GPT4.value: OpenAIUnitPriceItem(0.03, 0.06), - OpenAIModel.GPT4_32K.value: OpenAIUnitPriceItem(0.06, 0.12), - OpenAIModel.GPT4_TURBO.value: OpenAIUnitPriceItem(0.01, 0.03), - OpenAIModel.GPT35_TURBO.value: OpenAIUnitPriceItem(0.001, 0.002), - OpenAIModel.HUNYUAN.value: OpenAIUnitPriceItem(round(0.10 / 7, 4), round(0.10 / 7, 4)), - OpenAIModel.GEMINI.value: OpenAIUnitPriceItem(0, 0), - OpenAIModel.ERNIE_BOT_4_0.value: OpenAIUnitPriceItem(round(0.12 / 7, 4), round(0.12 / 7, 4)), - OpenAIModel.ERNIE_BOT_8K.value: OpenAIUnitPriceItem(round(0.024 / 7, 4), round(0.048 / 7, 4)), - OpenAIModel.ERNIE_BOT.value: OpenAIUnitPriceItem(round(0.012 / 7, 4), round(0.012 / 7, 4)), - OpenAIModel.ERNIE_BOT_TURBO.value: OpenAIUnitPriceItem(round(0.008 / 7, 4), round(0.008 / 7, 4)), - } - - @classmethod - def get_price(cls, model: str) -> OpenAIUnitPriceItem: - return cls.price_map.get(model) + OPENAI = "openai", gettext_lazy("Open AI") + GOOGLE = "google", gettext_lazy("Google") + BAIDU = "baidu", gettext_lazy("Baidu") + TENCENT = "tencent", gettext_lazy("Tencent") diff --git a/apps/chat/exceptions.py b/apps/chat/exceptions.py index a963664..1e931bb 100644 --- a/apps/chat/exceptions.py +++ b/apps/chat/exceptions.py @@ -16,3 +16,8 @@ class UnexpectedError(APIException): class VerifyFailed(APIException): status_code = status.HTTP_400_BAD_REQUEST default_detail = gettext_lazy("Pre Check Verify Failed") + + +class UnexpectedProvider(APIException): + status_code = status.HTTP_400_BAD_REQUEST + default_detail = gettext_lazy("Unexpected Provider") diff --git a/apps/chat/migrations/0004_alter_chatlog_model_alter_modelpermission_model_and_more.py b/apps/chat/migrations/0004_alter_chatlog_model_alter_modelpermission_model_and_more.py new file mode 100644 index 0000000..619bb8b --- /dev/null +++ b/apps/chat/migrations/0004_alter_chatlog_model_alter_modelpermission_model_and_more.py @@ -0,0 +1,23 @@ +# pylint: disable=R0801,C0103 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("chat", "0003_remove_content"), + ] + + operations = [ + migrations.AlterField( + model_name="chatlog", + name="model", + field=models.CharField(blank=True, db_index=True, max_length=64, null=True, verbose_name="Model"), + ), + migrations.AlterField( + model_name="modelpermission", + name="model", + field=models.CharField(blank=True, db_index=True, max_length=64, null=True, verbose_name="Model"), + ), + ] diff --git a/apps/chat/migrations/0005_aimodel.py b/apps/chat/migrations/0005_aimodel.py new file mode 100644 index 0000000..403a027 --- /dev/null +++ b/apps/chat/migrations/0005_aimodel.py @@ -0,0 +1,58 @@ +# pylint: disable=R0801,C0103 + +import ovinc_client.core.models +import ovinc_client.core.utils +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("chat", "0004_alter_chatlog_model_alter_modelpermission_model_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="AIModel", + fields=[ + ( + "id", + ovinc_client.core.models.UniqIDField( + default=ovinc_client.core.utils.uniq_id_without_time, + max_length=32, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "provider", + models.CharField( + choices=[ + ("openai", "Open AI"), + ("google", "Google"), + ("baidu", "Baidu"), + ("tencent", "Tencent"), + ], + db_index=True, + max_length=64, + verbose_name="Provider", + ), + ), + ("model", models.CharField(db_index=True, max_length=64, verbose_name="Model")), + ("name", models.CharField(max_length=64, verbose_name="Model Name")), + ("is_enabled", models.BooleanField(db_index=True, default=True, verbose_name="Enabled")), + ("prompt_price", models.DecimalField(decimal_places=10, max_digits=20, verbose_name="Prompt Price")), + ( + "completion_price", + models.DecimalField(decimal_places=10, max_digits=20, verbose_name="Completion Price"), + ), + ], + options={ + "verbose_name": "AI Model", + "verbose_name_plural": "AI Model", + "ordering": ["provider", "name"], + "unique_together": {("provider", "model")}, + }, + ), + ] diff --git a/apps/chat/models.py b/apps/chat/models.py index 4abca53..d053841 100644 --- a/apps/chat/models.py +++ b/apps/chat/models.py @@ -13,7 +13,7 @@ from apps.chat.constants import ( PRICE_DECIMAL_NUMS, PRICE_DIGIT_NUMS, - OpenAIModel, + AIModelProvider, OpenAIRole, ) @@ -32,7 +32,6 @@ class ChatLog(BaseModel): user = ForeignKey(gettext_lazy("User"), to="account.User", on_delete=models.PROTECT) model = models.CharField( gettext_lazy("Model"), - choices=OpenAIModel.choices, max_length=MEDIUM_CHAR_LENGTH, null=True, blank=True, @@ -89,7 +88,6 @@ class ModelPermission(BaseModel): user = ForeignKey(gettext_lazy("User"), to="account.User", on_delete=models.PROTECT) model = models.CharField( gettext_lazy("Model"), - choices=OpenAIModel.choices, max_length=MEDIUM_CHAR_LENGTH, null=True, blank=True, @@ -106,6 +104,9 @@ class Meta: @classmethod def authed_models(cls, user: USER_MODEL, model: str = None) -> QuerySet: + # load enabled models + queryset = AIModel.objects.filter(is_enabled=True) + # build filter q = Q(user=user) # pylint: disable=C0103 if model: q &= Q( # pylint: disable=C0103 @@ -113,7 +114,10 @@ def authed_models(cls, user: USER_MODEL, model: str = None) -> QuerySet: ) else: q &= Q(Q(expired_at__gt=timezone.now()) | Q(expired_at__isnull=True)) # pylint: disable=C0103 - return cls.objects.filter(q) + # load permission + authed_models = cls.objects.filter(q).values("model") + # load authed models + return queryset.filter(model__in=authed_models) @dataclass @@ -160,3 +164,29 @@ def create(cls, data: dict) -> "HunYuanChuck": for choice in data.get("choices", []) ] return chuck + + +class AIModel(BaseModel): + """ + AI Model + """ + + id = UniqIDField(gettext_lazy("ID"), primary_key=True) + provider = models.CharField( + gettext_lazy("Provider"), max_length=MEDIUM_CHAR_LENGTH, choices=AIModelProvider.choices, db_index=True + ) + model = models.CharField(gettext_lazy("Model"), max_length=MEDIUM_CHAR_LENGTH, db_index=True) + name = models.CharField(gettext_lazy("Model Name"), max_length=MEDIUM_CHAR_LENGTH) + is_enabled = models.BooleanField(gettext_lazy("Enabled"), default=True, db_index=True) + prompt_price = models.DecimalField( + gettext_lazy("Prompt Price"), max_digits=PRICE_DIGIT_NUMS, decimal_places=PRICE_DECIMAL_NUMS + ) + completion_price = models.DecimalField( + gettext_lazy("Completion Price"), max_digits=PRICE_DIGIT_NUMS, decimal_places=PRICE_DECIMAL_NUMS + ) + + class Meta: + verbose_name = gettext_lazy("AI Model") + verbose_name_plural = verbose_name + ordering = ["provider", "name"] + unique_together = [["provider", "model"]] diff --git a/apps/chat/serializers.py b/apps/chat/serializers.py index 4a72c57..6a23fe3 100644 --- a/apps/chat/serializers.py +++ b/apps/chat/serializers.py @@ -1,6 +1,5 @@ from typing import List -import tiktoken from django.conf import settings from django.utils.translation import gettext, gettext_lazy from rest_framework import serializers @@ -10,13 +9,12 @@ TEMPERATURE_DEFAULT, TEMPERATURE_MAX, TEMPERATURE_MIN, + TOKEN_ENCODING, TOP_P_DEFAULT, TOP_P_MIN, - OpenAIModel, OpenAIRole, ) - -TOKEN_ENCODING = tiktoken.encoding_for_model(OpenAIModel.GPT35_TURBO) +from apps.chat.models import AIModel class OpenAIMessageSerializer(serializers.Serializer): @@ -28,12 +26,20 @@ class OpenAIMessageSerializer(serializers.Serializer): content = serializers.CharField(label=gettext_lazy("Content")) -class OpenAIRequestSerializer(serializers.Serializer): +class ModelSerializerMixin: + def validate_model(self, model: str) -> str: + if AIModel.objects.filter(model=model, is_enabled=True).exists(): + return model + # pylint: disable=E1101 + raise AIModel.DoesNotExist() + + +class OpenAIRequestSerializer(ModelSerializerMixin, serializers.Serializer): """ OpenAI Request """ - model = serializers.ChoiceField(label=gettext_lazy("Model"), choices=OpenAIModel.choices) + model = serializers.CharField(label=gettext_lazy("Model")) messages = serializers.ListField( label=gettext_lazy("Messages"), child=OpenAIMessageSerializer(), min_length=MESSAGE_MIN_LENGTH ) @@ -56,12 +62,12 @@ def validate_messages(self, messages: List[dict]) -> List[dict]: return messages -class CheckModelPermissionSerializer(serializers.Serializer): +class CheckModelPermissionSerializer(ModelSerializerMixin, serializers.Serializer): """ Model Permission """ - model = serializers.ChoiceField(label=gettext_lazy("Model"), choices=OpenAIModel.choices) + model = serializers.CharField(label=gettext_lazy("Model")) class OpenAIChatRequestSerializer(serializers.Serializer): diff --git a/apps/chat/views.py b/apps/chat/views.py index 5a718fb..ca1703d 100644 --- a/apps/chat/views.py +++ b/apps/chat/views.py @@ -2,15 +2,16 @@ from django.conf import settings from django.core.cache import cache from django.http import StreamingHttpResponse +from django.shortcuts import get_object_or_404 from ovinc_client.core.utils import uniq_id from ovinc_client.core.viewsets import CreateMixin, ListMixin, MainViewSet from rest_framework.decorators import action from rest_framework.response import Response from apps.chat.client import GeminiClient, HunYuanClient, OpenAIClient, QianfanClient -from apps.chat.constants import OpenAIModel -from apps.chat.exceptions import VerifyFailed -from apps.chat.models import ChatLog, ModelPermission +from apps.chat.constants import AIModelProvider +from apps.chat.exceptions import UnexpectedProvider, VerifyFailed +from apps.chat.models import AIModel, ChatLog, ModelPermission from apps.chat.permissions import AIModelPermission from apps.chat.serializers import ( CheckModelPermissionSerializer, @@ -43,15 +44,21 @@ def create(self, request, *args, **kwargs): if not request_data: raise VerifyFailed() + # model + model: AIModel = get_object_or_404(AIModel, model=request_data["model"]) + # call api - if request_data["model"] == OpenAIModel.HUNYUAN: - streaming_content = HunYuanClient(request=request, **request_data).chat() - elif request_data["model"] == OpenAIModel.GEMINI: - streaming_content = GeminiClient(request=request, **request_data).chat() - elif request_data["model"].startswith(OpenAIModel.ERNIE_BOT): - streaming_content = QianfanClient(request=request, **request_data).chat() - else: - streaming_content = OpenAIClient(request=request, **request_data).chat() + match model.provider: + case AIModelProvider.TENCENT: + streaming_content = HunYuanClient(request=request, **request_data).chat() + case AIModelProvider.GOOGLE: + streaming_content = GeminiClient(request=request, **request_data).chat() + case AIModelProvider.BAIDU: + streaming_content = QianfanClient(request=request, **request_data).chat() + case AIModelProvider.OPENAI: + streaming_content = OpenAIClient(request=request, **request_data).chat() + case _: + raise UnexpectedProvider() # response return StreamingHttpResponse( @@ -92,10 +99,7 @@ def list(self, request, *args, **kwargs): List Models """ - data = [ - {"id": model.model, "name": OpenAIModel.get_name(model.model)} - for model in ModelPermission.authed_models(user=request.user) - ] + data = [{"id": model.model, "name": model.name} for model in ModelPermission.authed_models(user=request.user)] data.sort(key=lambda model: model["id"]) return Response(data=data)