diff --git a/.cruft.json b/.cruft.json index 4b1b95e..1130f43 100644 --- a/.cruft.json +++ b/.cruft.json @@ -1,6 +1,6 @@ { "template": "https://github.com/OVINC-CN/DevTemplateDjango.git", - "commit": "99ee8789edfdce279bb8598e17636f8fe21bbb00", + "commit": "2451aad2dcca897848dc8f52a69da38df5d9a5ed", "checkout": "main", "context": { "cookiecutter": { diff --git a/apps/cel/tasks/__init__.py b/apps/cel/tasks/__init__.py index 0eca494..608da2a 100644 --- a/apps/cel/tasks/__init__.py +++ b/apps/cel/tasks/__init__.py @@ -1,7 +1,8 @@ -from apps.cel.tasks.chat import check_usage_limit +from apps.cel.tasks.chat import async_reply, check_usage_limit from apps.cel.tasks.cos import extract_file __all__ = [ "check_usage_limit", + "async_reply", "extract_file", ] diff --git a/apps/cel/tasks/chat.py b/apps/cel/tasks/chat.py index 154c942..ccf7745 100644 --- a/apps/cel/tasks/chat.py +++ b/apps/cel/tasks/chat.py @@ -1,3 +1,4 @@ +from asgiref.sync import async_to_sync from django.db import transaction from django.db.models import F from django.shortcuts import get_object_or_404 @@ -5,6 +6,7 @@ from ovinc_client.core.logger import celery_logger from apps.cel import app +from apps.chat.consumers_async import AsyncConsumer from apps.chat.models import ChatLog from apps.wallet.models import Wallet @@ -56,3 +58,14 @@ def calculate_usage_limit(self, log_id: str): ) celery_logger.info("[CalculateUsageLimit] End %s", self.request.id) + + +@app.task(bind=True) +def async_reply(self, channel_name: str, key: str): + """ + Async Reply to User + """ + + celery_logger.info("[AsyncReply] Start %s %s %s", self.request.id, channel_name, key) + async_to_sync(AsyncConsumer(channel_name=channel_name, key=key).chat)() + celery_logger.info("[AsyncReply] End %s %s %s", self.request.id, channel_name, key) diff --git a/apps/chat/constants.py b/apps/chat/constants.py index eb8f1ec..7f3d7ea 100644 --- a/apps/chat/constants.py +++ b/apps/chat/constants.py @@ -15,6 +15,10 @@ PRICE_DIGIT_NUMS = 20 PRICE_DECIMAL_NUMS = 10 +WS_CLOSED_KEY = "ws:closed:{}" + +MESSAGE_CACHE_KEY = "message:{}" + if "celery" in sys.argv: TOKEN_ENCODING = "" else: diff --git a/apps/chat/consumers.py b/apps/chat/consumers.py index ff9c047..0ed306c 100644 --- a/apps/chat/consumers.py +++ b/apps/chat/consumers.py @@ -1,74 +1,19 @@ -import asyncio import json -from typing import Type -from channels.db import database_sync_to_async from django.contrib.auth import get_user_model -from django.core.cache import cache -from django.shortcuts import get_object_or_404 -from django_redis.client import DefaultClient from ovinc_client.account.models import User -from ovinc_client.core.logger import logger -from apps.chat.client import ( - BaiLianClient, - DoubaoClient, - GeminiClient, - HunYuanClient, - HunYuanVisionClient, - KimiClient, - MidjourneyClient, - OpenAIClient, - OpenAIVisionClient, - QianfanClient, -) -from apps.chat.client.base import BaseClient -from apps.chat.constants import AIModelProvider -from apps.chat.exceptions import UnexpectedProvider, VerifyFailed -from apps.chat.models import AIModel +from apps.cel.tasks import async_reply +from apps.chat.exceptions import VerifyFailed from apps.chat.serializers import OpenAIChatRequestSerializer from utils.consumers import AsyncWebsocketConsumer -cache: DefaultClient USER_MODEL: User = get_user_model() class ChatConsumer(AsyncWebsocketConsumer): async def receive(self, text_data=None, *args, **kwargs): - try: - await self.chat(text_data=text_data) - await self.send(text_data=json.dumps({"is_finished": True}, ensure_ascii=False)) - except Exception as err: # pylint: disable=W0718 - logger.exception("[ChatError] %s", err) - await self.send(text_data=json.dumps({"data": str(err), "is_finished": True}, ensure_ascii=False)) - await self.close() - - async def chat(self, text_data) -> None: - # validate - validated_data = self.validate_input(text_data=text_data) - key = validated_data["key"] - - # cache - request_data = self.load_data_from_cache(key) - - # model - model = await database_sync_to_async(self.get_model_inst)(request_data["model"]) - - # get client - client = self.get_model_client(model) - - # init client - client = await database_sync_to_async(client)(**request_data) - - # response - async for data in client.chat(): - await self.send( - text_data=json.dumps({"data": data, "is_finished": False, "log_id": client.log.id}, ensure_ascii=False) - ) - await asyncio.sleep(0) - - def validate_input(self, text_data: str) -> dict: - # json + # load input try: data = json.loads(text_data) except json.JSONDecodeError as err: @@ -77,40 +22,16 @@ def validate_input(self, text_data: str) -> dict: # validate request request_serializer = OpenAIChatRequestSerializer(data=data) request_serializer.is_valid(raise_exception=True) - return request_serializer.validated_data + data = request_serializer.validated_data - def load_data_from_cache(self, key: str) -> dict: - request_data = cache.get(key=key) - cache.delete(key=key) - if not request_data: - raise VerifyFailed() - return request_data + # async chat + async_reply.apply_async(kwargs={"channel_name": self.channel_name, "key": data["key"]}) - def get_model_inst(self, model: str) -> AIModel: - return get_object_or_404(AIModel, model=model) + async def chat_send(self, event: dict): + await self.send(text_data=event["text_data"]) + + async def chat_close(self, event: dict): + await self.close() - # pylint: disable=R0911 - def get_model_client(self, model: AIModel) -> Type[BaseClient]: - match model.provider: - case AIModelProvider.TENCENT: - if model.is_vision: - return HunYuanVisionClient - return HunYuanClient - case AIModelProvider.GOOGLE: - return GeminiClient - case AIModelProvider.BAIDU: - return QianfanClient - case AIModelProvider.OPENAI: - if model.is_vision: - return OpenAIVisionClient - return OpenAIClient - case AIModelProvider.ALIYUN: - return BaiLianClient - case AIModelProvider.MOONSHOT: - return KimiClient - case AIModelProvider.DOUBAO: - return DoubaoClient - case AIModelProvider.MIDJOURNEY: - return MidjourneyClient - case _: - raise UnexpectedProvider() + async def disconnect(self, code): + await super().disconnect(code) diff --git a/apps/chat/consumers_async.py b/apps/chat/consumers_async.py new file mode 100644 index 0000000..7e2443c --- /dev/null +++ b/apps/chat/consumers_async.py @@ -0,0 +1,147 @@ +import asyncio +import json +from typing import Type + +from autobahn.exception import Disconnected +from channels.db import database_sync_to_async +from channels.exceptions import ChannelFull +from channels.layers import get_channel_layer +from channels_redis.core import RedisChannelLayer +from django.conf import settings +from django.core.cache import cache +from django.shortcuts import get_object_or_404 +from ovinc_client.core.logger import logger + +from apps.chat.client import ( + BaiLianClient, + DoubaoClient, + GeminiClient, + HunYuanClient, + HunYuanVisionClient, + KimiClient, + MidjourneyClient, + OpenAIClient, + OpenAIVisionClient, + QianfanClient, +) +from apps.chat.client.base import BaseClient +from apps.chat.constants import WS_CLOSED_KEY, AIModelProvider +from apps.chat.exceptions import UnexpectedProvider, VerifyFailed +from apps.chat.models import AIModel + + +class AsyncConsumer: + """ + Async Consumer + """ + + def __init__(self, channel_name: str, key: str): + self.channel_layer: RedisChannelLayer = get_channel_layer() + self.channel_name = channel_name + self.key = key + + async def chat(self) -> None: + try: + is_closed = await self.do_chat() + if is_closed: + return + await self.send(text_data=json.dumps({"is_finished": True}, ensure_ascii=False)) + except Exception as err: # pylint: disable=W0718 + logger.exception("[ChatError] %s", err) + await self.send(text_data=json.dumps({"data": str(err), "is_finished": True}, ensure_ascii=False)) + await self.close() + + async def send(self, text_data: str): + await self.channel_layer.send( + channel=self.channel_name, + message={"type": "chat.send", "text_data": text_data}, + ) + + async def close(self): + await self.channel_layer.send( + channel=self.channel_name, + message={ + "type": "chat.close", + }, + ) + + async def do_chat(self) -> bool: + # cache + request_data = self.load_data_from_cache(self.key) + + # model + model = await database_sync_to_async(self.get_model_inst)(request_data["model"]) + + # get client + client = self.get_model_client(model) + + # init client + client = await database_sync_to_async(client)(**request_data) + + # response + is_closed = False + async for data in client.chat(): + if is_closed: + continue + retry_times = 0 + while retry_times <= settings.CHANNEL_RETRY_TIMES: + try: + await self.send( + text_data=json.dumps( + {"data": data, "is_finished": False, "log_id": client.log.id}, ensure_ascii=False + ) + ) + await asyncio.sleep(0) + break + except Disconnected: + logger.warning("[SendMessageFailed-Disconnected] Channel: %s", self.channel_name) + is_closed = True + break + except ChannelFull: + if cache.get(WS_CLOSED_KEY.format(self.channel_name)): + logger.warning("[SendMessageFailed-Disconnected] Channel: %s", self.channel_name) + is_closed = True + break + logger.warning( + "[SendMessageFailed-ChannelFull] Channel: %s; Retry: %d;", self.channel_name, retry_times + ) + await asyncio.sleep(settings.CHANNEL_RETRY_SLEEP) + retry_times += 1 + + return is_closed + + def load_data_from_cache(self, key: str) -> dict: + request_data = cache.get(key=key) + cache.delete(key=key) + if not request_data: + raise VerifyFailed() + return request_data + + def get_model_inst(self, model: str) -> AIModel: + return get_object_or_404(AIModel, model=model) + + # pylint: disable=R0911 + def get_model_client(self, model: AIModel) -> Type[BaseClient]: + match model.provider: + case AIModelProvider.TENCENT: + if model.is_vision: + return HunYuanVisionClient + return HunYuanClient + case AIModelProvider.GOOGLE: + return GeminiClient + case AIModelProvider.BAIDU: + return QianfanClient + case AIModelProvider.OPENAI: + if model.is_vision: + return OpenAIVisionClient + return OpenAIClient + case AIModelProvider.ALIYUN: + return BaiLianClient + case AIModelProvider.MOONSHOT: + return KimiClient + case AIModelProvider.DOUBAO: + return DoubaoClient + case AIModelProvider.MIDJOURNEY: + return MidjourneyClient + case _: + raise UnexpectedProvider() diff --git a/apps/chat/views.py b/apps/chat/views.py index 91d83ab..88d5fee 100644 --- a/apps/chat/views.py +++ b/apps/chat/views.py @@ -14,6 +14,7 @@ from rest_framework.decorators import action from rest_framework.response import Response +from apps.chat.constants import MESSAGE_CACHE_KEY from apps.chat.models import AIModel, ChatLog, SystemPreset from apps.chat.permissions import AIModelPermission from apps.chat.serializers import ( @@ -47,7 +48,7 @@ async def pre_check(self, request, *args, **kwargs): await database_sync_to_async(get_object_or_404)(AIModel, model=request_data["model"], is_enabled=True) # cache - cache_key = uniq_id() + cache_key = MESSAGE_CACHE_KEY.format(uniq_id()) cache.set( key=cache_key, value={**request_data, "user": request.user.username}, diff --git a/entry/settings.py b/entry/settings.py index fa02bfe..6aa8702 100644 --- a/entry/settings.py +++ b/entry/settings.py @@ -1,4 +1,5 @@ import os +import re from pathlib import Path from environ import environ @@ -119,9 +120,15 @@ "hosts": [ f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}", ], + "channel_capacity": { + re.compile(r".*"): int(os.getenv("CHANNEL_LAYER_DEFAULT_CAPACITY", "100")), + }, }, }, } +CHANNEL_RETRY_TIMES = int(os.getenv("CHANNEL_RETRY_TIMES", "1")) +CHANNEL_RETRY_SLEEP = int(os.getenv("CHANNEL_RETRY_SLEEP", "1")) # seconds +CHANNEL_CLOSE_KEY_TIMEOUT = int(os.getenv("CHANNEL_CLOSE_KEY_TIMEOUT", "60")) # Auth AUTH_PASSWORD_VALIDATORS = [ diff --git a/requirements.txt b/requirements.txt index 45b4f71..89e0c64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # ovinc -ovinc-client==0.3.8 +ovinc-client==0.3.9 # Celery celery==5.4.0 diff --git a/utils/consumers.py b/utils/consumers.py index 7fc03a9..3f94652 100644 --- a/utils/consumers.py +++ b/utils/consumers.py @@ -1,5 +1,8 @@ from channels.generic.websocket import AsyncWebsocketConsumer as _AsyncWebsocketConsumer +from django.conf import settings +from django.core.cache import cache +from apps.chat.constants import WS_CLOSED_KEY from utils.connections import connections_handler @@ -9,5 +12,6 @@ async def connect(self): connections_handler.add_connection(self.channel_name, self.scope["client"][0]) async def disconnect(self, code): + cache.set(key=WS_CLOSED_KEY.format(self.channel_name), value=True, timeout=settings.CHANNEL_CLOSE_KEY_TIMEOUT) await super().disconnect(code) connections_handler.remove_connection(self.channel_name, self.scope["client"][0])