Skip to content

Commit

Permalink
Merge pull request #25 from OVINC-CN/feat_model_config
Browse files Browse the repository at this point in the history
feat: use db to config models and prices #24
  • Loading branch information
OrenZhang authored Feb 5, 2024
2 parents 732deab + fbfc0a0 commit 35b0d9a
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 115 deletions.
28 changes: 22 additions & 6 deletions apps/chat/admin.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
import datetime
from typing import Union

from django.contrib import admin
from django.utils import timezone
from django.utils.translation import gettext_lazy
from rest_framework.settings import api_settings

from apps.chat.models import ChatLog, ModelPermission
from apps.chat.models import AIModel, ChatLog, ModelPermission


class ModelNameMixin:
    """
    Admin mixin that contributes a "model_name" list column.

    Resolves the row's raw model code (``inst.model``) to the display name of
    the matching *enabled* AIModel record; renders an empty cell when no
    enabled model matches.
    """

    @admin.display(description=gettext_lazy("Model Name"))
    def model_name(self, inst: Union[ModelPermission, ChatLog]) -> str:
        # NOTE(review): one DB lookup per rendered row — fine for small admin
        # pages, consider caching if these lists grow.
        matched = AIModel.objects.filter(model=inst.model, is_enabled=True).first()
        return matched.name if matched else ""


@admin.register(ChatLog)
class ChatLogAdmin(admin.ModelAdmin):
class ChatLogAdmin(ModelNameMixin, admin.ModelAdmin):
list_display = [
"id",
"user",
"model",
"model_name",
"prompt_tokens",
"completion_tokens",
"total_price",
Expand All @@ -29,7 +39,7 @@ def total_price(self, log: ChatLog) -> str:
log.prompt_tokens * log.prompt_token_unit_price / 1000
+ log.completion_tokens * log.completion_token_unit_price / 1000
)
return f"{price:.2f}"
return f"{price:.4f}"

@admin.display(description=gettext_lazy("Duration(ms)"))
def duration(self, log: ChatLog) -> int:
Expand All @@ -47,7 +57,13 @@ def created_at_formatted(self, log: ChatLog) -> str:


@admin.register(ModelPermission)
class ModelPermissionAdmin(ModelNameMixin, admin.ModelAdmin):
    """Admin list for per-user model grants with a resolved model display name."""

    # Fix: the pasted source carried two conflicting class headers (the old
    # ``admin.ModelAdmin`` version and this mixin-based one); only the mixin
    # version is kept. "model_name" is the computed column from ModelNameMixin.
    list_display = ["id", "user", "model_name", "expired_at", "created_at"]
    list_filter = ["model"]
    search_fields = ["user"]


@admin.register(AIModel)
class AIModelAdmin(admin.ModelAdmin):
    """Admin list for the AIModel catalog: provider, model code, name, enabled flag and token prices."""

    list_display = ["id", "provider", "model", "name", "is_enabled", "prompt_price", "completion_price"]
    list_filter = ["provider", "is_enabled"]
39 changes: 10 additions & 29 deletions apps/chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@
AI_API_REQUEST_TIMEOUT,
HUNYUAN_DATA_PATTERN,
GeminiRole,
OpenAIModel,
OpenAIRole,
OpenAIUnitPrice,
)
from apps.chat.exceptions import UnexpectedError
from apps.chat.models import ChatLog, HunYuanChuck, Message
from apps.chat.models import AIModel, ChatLog, HunYuanChuck, Message

USER_MODEL = get_user_model()

Expand All @@ -49,6 +47,7 @@ def __init__(
self.request: Request = request
self.user: USER_MODEL = request.user
self.model: str = model
self.model_inst: AIModel = AIModel.objects.get(model=model, is_enabled=True)
self.messages: Union[List[Message], QfMessages] = messages
self.temperature: float = temperature
self.top_p: float = top_p
Expand Down Expand Up @@ -121,28 +120,13 @@ def post_chat(self) -> None:
self.log.prompt_tokens = len(encoding.encode("".join([message["content"] for message in self.log.messages])))
self.log.completion_tokens = len(encoding.encode(self.log.content))
# calculate price
price = OpenAIUnitPrice.get_price(self.model)
self.log.prompt_token_unit_price = price.prompt_token_unit_price
self.log.completion_token_unit_price = price.completion_token_unit_price
self.log.prompt_token_unit_price = self.model_inst.prompt_price
self.log.completion_token_unit_price = self.model_inst.completion_price
# save
self.log.finished_at = self.finished_at
self.log.save()
self.log.remove_content()

@classmethod
def list_models(cls) -> List[dict]:
all_models = openai.Model.list(
api_base=settings.OPENAI_API_BASE, api_key=settings.OPENAI_API_KEY
).to_dict_recursive()["data"]
supported_models = [
{"id": model["id"], "name": str(OpenAIModel.get_name(model["id"]))}
for model in all_models
if model["id"] in OpenAIModel.values
]
supported_models.append({"id": OpenAIModel.HUNYUAN.value, "name": str(OpenAIModel.HUNYUAN.label)})
supported_models.sort(key=lambda model: model["id"])
return supported_models


class HunYuanClient(BaseClient):
"""
Expand Down Expand Up @@ -185,9 +169,8 @@ def record(self, response: HunYuanChuck) -> None:
self.log.content += response.choices[0].delta.content
self.log.prompt_tokens = response.usage.prompt_tokens
self.log.completion_tokens = response.usage.completion_tokens
price = OpenAIUnitPrice.get_price(self.model)
self.log.prompt_token_unit_price = price.prompt_token_unit_price
self.log.completion_token_unit_price = price.completion_token_unit_price
self.log.prompt_token_unit_price = self.model_inst.prompt_price
self.log.completion_token_unit_price = self.model_inst.completion_price
return
# create log
self.log = ChatLog.objects.create(
Expand Down Expand Up @@ -290,9 +273,8 @@ def post_chat(self) -> None:
self.log.prompt_tokens = len("".join([message["content"] for message in self.log.messages]))
self.log.completion_tokens = len(self.log.content)
# calculate price
price = OpenAIUnitPrice.get_price(self.model)
self.log.prompt_token_unit_price = price.prompt_token_unit_price
self.log.completion_token_unit_price = price.completion_token_unit_price
self.log.prompt_token_unit_price = self.model_inst.prompt_price
self.log.completion_token_unit_price = self.model_inst.completion_price
# save
self.log.finished_at = self.finished_at
self.log.save()
Expand Down Expand Up @@ -345,9 +327,8 @@ def post_chat(self) -> None:
if not self.log:
return
# calculate price
price = OpenAIUnitPrice.get_price(self.model)
self.log.prompt_token_unit_price = price.prompt_token_unit_price
self.log.completion_token_unit_price = price.completion_token_unit_price
self.log.prompt_token_unit_price = self.model_inst.prompt_price
self.log.completion_token_unit_price = self.model_inst.completion_price
# save
self.log.finished_at = self.finished_at
self.log.save()
Expand Down
63 changes: 10 additions & 53 deletions apps/chat/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from dataclasses import dataclass

import tiktoken
from django.utils.translation import gettext_lazy
from ovinc_client.core.models import TextChoices

Expand All @@ -19,6 +19,9 @@
HUNYUAN_DATA_PATTERN = re.compile(rb"data:\s\{.*\}\n\n")


TOKEN_ENCODING = tiktoken.encoding_for_model("gpt-3.5-turbo")


class OpenAIRole(TextChoices):
"""
OpenAI Chat Role
Expand All @@ -38,58 +41,12 @@ class GeminiRole(TextChoices):
MODEL = "model", gettext_lazy("Model")


class OpenAIModel(TextChoices):
"""
OpenAI Model
"""

GPT4 = "gpt-4", "GPT4"
GPT4_32K = "gpt-4-32k", "GPT4 (32K)"
GPT4_TURBO = "gpt-4-1106-preview", "GPT4 Turbo"
GPT35_TURBO = "gpt-3.5-turbo", "GPT3.5 Turbo"
HUNYUAN = "hunyuan-plus", gettext_lazy("HunYuan Plus")
GEMINI = "gemini-pro", "Gemini Pro"
ERNIE_BOT_4_0 = "ERNIE-Bot-4", "ERNIE-Bot 4.0"
ERNIE_BOT_8K = "ERNIE-Bot-8k", "ERNIE-Bot 8K"
ERNIE_BOT = "ERNIE-Bot", "ERNIE-Bot"
ERNIE_BOT_TURBO = "ERNIE-Bot-turbo-AI", "ERNIE-Bot Turbo"

@classmethod
def get_name(cls, model: str) -> str:
for value, label in cls.choices:
if value == model:
return str(label)
return model


@dataclass
class OpenAIUnitPriceItem:
"""
OpenAI Unit Price Item
"""

prompt_token_unit_price: float
completion_token_unit_price: float


class OpenAIUnitPrice:
class AIModelProvider(TextChoices):
"""
OpenAI Unit Price Per Thousand Tokens ($)
AI Model Provider
"""

price_map = {
OpenAIModel.GPT4.value: OpenAIUnitPriceItem(0.03, 0.06),
OpenAIModel.GPT4_32K.value: OpenAIUnitPriceItem(0.06, 0.12),
OpenAIModel.GPT4_TURBO.value: OpenAIUnitPriceItem(0.01, 0.03),
OpenAIModel.GPT35_TURBO.value: OpenAIUnitPriceItem(0.001, 0.002),
OpenAIModel.HUNYUAN.value: OpenAIUnitPriceItem(round(0.10 / 7, 4), round(0.10 / 7, 4)),
OpenAIModel.GEMINI.value: OpenAIUnitPriceItem(0, 0),
OpenAIModel.ERNIE_BOT_4_0.value: OpenAIUnitPriceItem(round(0.12 / 7, 4), round(0.12 / 7, 4)),
OpenAIModel.ERNIE_BOT_8K.value: OpenAIUnitPriceItem(round(0.024 / 7, 4), round(0.048 / 7, 4)),
OpenAIModel.ERNIE_BOT.value: OpenAIUnitPriceItem(round(0.012 / 7, 4), round(0.012 / 7, 4)),
OpenAIModel.ERNIE_BOT_TURBO.value: OpenAIUnitPriceItem(round(0.008 / 7, 4), round(0.008 / 7, 4)),
}

@classmethod
def get_price(cls, model: str) -> OpenAIUnitPriceItem:
return cls.price_map.get(model)
OPENAI = "openai", gettext_lazy("Open AI")
GOOGLE = "google", gettext_lazy("Google")
BAIDU = "baidu", gettext_lazy("Baidu")
TENCENT = "tencent", gettext_lazy("Tencent")
5 changes: 5 additions & 0 deletions apps/chat/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ class UnexpectedError(APIException):
class VerifyFailed(APIException):
    """Raised when the pre-chat verification check does not pass (HTTP 400)."""

    status_code = status.HTTP_400_BAD_REQUEST
    default_detail = gettext_lazy("Pre Check Verify Failed")


class UnexpectedProvider(APIException):
    """Raised when a requested AI model's provider is not recognized (HTTP 400)."""

    status_code = status.HTTP_400_BAD_REQUEST
    default_detail = gettext_lazy("Unexpected Provider")
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# pylint: disable=R0801,C0103

from django.db import migrations, models


class Migration(migrations.Migration):
    """
    Relax ChatLog.model and ModelPermission.model to free-form indexed
    CharFields, so valid model codes come from the database-backed AIModel
    catalog instead of hard-coded field choices.
    """

    dependencies = [
        ("chat", "0003_remove_content"),
    ]

    operations = [
        # ChatLog.model: drop static choices, keep index/nullability
        migrations.AlterField(
            model_name="chatlog",
            name="model",
            field=models.CharField(blank=True, db_index=True, max_length=64, null=True, verbose_name="Model"),
        ),
        # same relaxation for ModelPermission.model
        migrations.AlterField(
            model_name="modelpermission",
            name="model",
            field=models.CharField(blank=True, db_index=True, max_length=64, null=True, verbose_name="Model"),
        ),
    ]
58 changes: 58 additions & 0 deletions apps/chat/migrations/0005_aimodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# pylint: disable=R0801,C0103

import ovinc_client.core.models
import ovinc_client.core.utils
from django.db import migrations, models


class Migration(migrations.Migration):
    """
    Create the AIModel table: one row per configurable chat model, holding
    its provider, API model code, display name, enabled switch and token
    unit prices (moves pricing/model config from constants into the DB).
    """

    dependencies = [
        ("chat", "0004_alter_chatlog_model_alter_modelpermission_model_and_more"),
    ]

    operations = [
        migrations.CreateModel(
            name="AIModel",
            fields=[
                (
                    "id",
                    # random unique string primary key (no timestamp component)
                    ovinc_client.core.models.UniqIDField(
                        default=ovinc_client.core.utils.uniq_id_without_time,
                        max_length=32,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                (
                    "provider",
                    # closed set of vendors; mirrors AIModelProvider choices
                    models.CharField(
                        choices=[
                            ("openai", "Open AI"),
                            ("google", "Google"),
                            ("baidu", "Baidu"),
                            ("tencent", "Tencent"),
                        ],
                        db_index=True,
                        max_length=64,
                        verbose_name="Provider",
                    ),
                ),
                ("model", models.CharField(db_index=True, max_length=64, verbose_name="Model")),
                ("name", models.CharField(max_length=64, verbose_name="Model Name")),
                ("is_enabled", models.BooleanField(db_index=True, default=True, verbose_name="Enabled")),
                ("prompt_price", models.DecimalField(decimal_places=10, max_digits=20, verbose_name="Prompt Price")),
                (
                    "completion_price",
                    models.DecimalField(decimal_places=10, max_digits=20, verbose_name="Completion Price"),
                ),
            ],
            options={
                "verbose_name": "AI Model",
                "verbose_name_plural": "AI Model",
                "ordering": ["provider", "name"],
                # one catalog entry per (provider, model code) pair
                "unique_together": {("provider", "model")},
            },
        ),
    ]
38 changes: 34 additions & 4 deletions apps/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from apps.chat.constants import (
PRICE_DECIMAL_NUMS,
PRICE_DIGIT_NUMS,
OpenAIModel,
AIModelProvider,
OpenAIRole,
)

Expand All @@ -32,7 +32,6 @@ class ChatLog(BaseModel):
user = ForeignKey(gettext_lazy("User"), to="account.User", on_delete=models.PROTECT)
model = models.CharField(
gettext_lazy("Model"),
choices=OpenAIModel.choices,
max_length=MEDIUM_CHAR_LENGTH,
null=True,
blank=True,
Expand Down Expand Up @@ -89,7 +88,6 @@ class ModelPermission(BaseModel):
user = ForeignKey(gettext_lazy("User"), to="account.User", on_delete=models.PROTECT)
model = models.CharField(
gettext_lazy("Model"),
choices=OpenAIModel.choices,
max_length=MEDIUM_CHAR_LENGTH,
null=True,
blank=True,
Expand All @@ -106,14 +104,20 @@ class Meta:

@classmethod
def authed_models(cls, user: USER_MODEL, model: str = None) -> QuerySet:
# load enabled models
queryset = AIModel.objects.filter(is_enabled=True)
# build filter
q = Q(user=user) # pylint: disable=C0103
if model:
q &= Q( # pylint: disable=C0103
Q(model=str(model), expired_at__gt=timezone.now()) | Q(model=str(model), expired_at__isnull=True)
)
else:
q &= Q(Q(expired_at__gt=timezone.now()) | Q(expired_at__isnull=True)) # pylint: disable=C0103
return cls.objects.filter(q)
# load permission
authed_models = cls.objects.filter(q).values("model")
# load authed models
return queryset.filter(model__in=authed_models)


@dataclass
Expand Down Expand Up @@ -160,3 +164,29 @@ def create(cls, data: dict) -> "HunYuanChuck":
for choice in data.get("choices", [])
]
return chuck


class AIModel(BaseModel):
    """
    AI Model

    Catalog entry for one chat model: the provider that serves it, the model
    code sent to the provider's API, a human-readable display name, an
    enabled switch, and the token unit prices copied into chat logs for
    billing.
    """

    id = UniqIDField(gettext_lazy("ID"), primary_key=True)
    # vendor serving this model; constrained to AIModelProvider choices
    provider = models.CharField(
        gettext_lazy("Provider"), max_length=MEDIUM_CHAR_LENGTH, choices=AIModelProvider.choices, db_index=True
    )
    # model code used when calling the provider API; unique per provider (see Meta)
    model = models.CharField(gettext_lazy("Model"), max_length=MEDIUM_CHAR_LENGTH, db_index=True)
    # display name shown in admin columns (see ModelNameMixin)
    name = models.CharField(gettext_lazy("Model Name"), max_length=MEDIUM_CHAR_LENGTH)
    # disabled models are excluded from name resolution and permission lookups
    is_enabled = models.BooleanField(gettext_lazy("Enabled"), default=True, db_index=True)
    # prompt-token unit price; consumed as price-per-1K-tokens by the admin's
    # total_price calculation — confirm unit convention with the billing code
    prompt_price = models.DecimalField(
        gettext_lazy("Prompt Price"), max_digits=PRICE_DIGIT_NUMS, decimal_places=PRICE_DECIMAL_NUMS
    )
    # completion-token unit price; same per-1K-token convention as prompt_price
    completion_price = models.DecimalField(
        gettext_lazy("Completion Price"), max_digits=PRICE_DIGIT_NUMS, decimal_places=PRICE_DECIMAL_NUMS
    )

    class Meta:
        verbose_name = gettext_lazy("AI Model")
        verbose_name_plural = verbose_name
        ordering = ["provider", "name"]
        unique_together = [["provider", "model"]]
Loading

0 comments on commit 35b0d9a

Please sign in to comment.