
Commit

Merge pull request #75 from OVINC-CN/feat_async_reply
feat: async reply
OrenZhang authored Jul 31, 2024
2 parents 94fc28f + 60dc0a6 commit 1202be7
Showing 10 changed files with 194 additions and 96 deletions.
2 changes: 1 addition & 1 deletion .cruft.json
@@ -1,6 +1,6 @@
{
"template": "https://github.com/OVINC-CN/DevTemplateDjango.git",
"commit": "99ee8789edfdce279bb8598e17636f8fe21bbb00",
"commit": "2451aad2dcca897848dc8f52a69da38df5d9a5ed",
"checkout": "main",
"context": {
"cookiecutter": {
3 changes: 2 additions & 1 deletion apps/cel/tasks/__init__.py
@@ -1,7 +1,8 @@
from apps.cel.tasks.chat import check_usage_limit
from apps.cel.tasks.chat import async_reply, check_usage_limit
from apps.cel.tasks.cos import extract_file

__all__ = [
"check_usage_limit",
"async_reply",
"extract_file",
]
13 changes: 13 additions & 0 deletions apps/cel/tasks/chat.py
@@ -1,10 +1,12 @@
from asgiref.sync import async_to_sync
from django.db import transaction
from django.db.models import F
from django.shortcuts import get_object_or_404
from ovinc_client.core.lock import task_lock
from ovinc_client.core.logger import celery_logger

from apps.cel import app
from apps.chat.consumers_async import AsyncConsumer
from apps.chat.models import ChatLog
from apps.wallet.models import Wallet

@@ -56,3 +58,14 @@ def calculate_usage_limit(self, log_id: str):
)

celery_logger.info("[CalculateUsageLimit] End %s", self.request.id)


@app.task(bind=True)
def async_reply(self, channel_name: str, key: str):
"""
Async Reply to User
"""

celery_logger.info("[AsyncReply] Start %s %s %s", self.request.id, channel_name, key)
async_to_sync(AsyncConsumer(channel_name=channel_name, key=key).chat)()
celery_logger.info("[AsyncReply] End %s %s %s", self.request.id, channel_name, key)
4 changes: 4 additions & 0 deletions apps/chat/constants.py
@@ -15,6 +15,10 @@
PRICE_DIGIT_NUMS = 20
PRICE_DECIMAL_NUMS = 10

WS_CLOSED_KEY = "ws:closed:{}"

MESSAGE_CACHE_KEY = "message:{}"

if "celery" in sys.argv:
TOKEN_ENCODING = ""
else:
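Both new constants are plain format-string templates: the websocket channel name or the pre-check message id is substituted in before the value is used as a cache key. An illustrative sketch (key values and timeout are hypothetical):

from django.core.cache import cache

WS_CLOSED_KEY = "ws:closed:{}"
MESSAGE_CACHE_KEY = "message:{}"

channel_name = "specific.abc123!def456"  # hypothetical channel name
message_id = "0f6c2e9b"                  # hypothetical uniq_id() value

# utils/consumers.py marks a socket as closed on disconnect
cache.set(WS_CLOSED_KEY.format(channel_name), True, timeout=60)

# apps/chat/views.py stashes the validated chat request under a message key
cache.set(MESSAGE_CACHE_KEY.format(message_id), {"model": "some-model"}, timeout=60)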
105 changes: 13 additions & 92 deletions apps/chat/consumers.py
@@ -1,74 +1,19 @@
import asyncio
import json
from typing import Type

from channels.db import database_sync_to_async
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.shortcuts import get_object_or_404
from django_redis.client import DefaultClient
from ovinc_client.account.models import User
from ovinc_client.core.logger import logger

from apps.chat.client import (
BaiLianClient,
DoubaoClient,
GeminiClient,
HunYuanClient,
HunYuanVisionClient,
KimiClient,
MidjourneyClient,
OpenAIClient,
OpenAIVisionClient,
QianfanClient,
)
from apps.chat.client.base import BaseClient
from apps.chat.constants import AIModelProvider
from apps.chat.exceptions import UnexpectedProvider, VerifyFailed
from apps.chat.models import AIModel
from apps.cel.tasks import async_reply
from apps.chat.exceptions import VerifyFailed
from apps.chat.serializers import OpenAIChatRequestSerializer
from utils.consumers import AsyncWebsocketConsumer

cache: DefaultClient
USER_MODEL: User = get_user_model()


class ChatConsumer(AsyncWebsocketConsumer):
async def receive(self, text_data=None, *args, **kwargs):
try:
await self.chat(text_data=text_data)
await self.send(text_data=json.dumps({"is_finished": True}, ensure_ascii=False))
except Exception as err: # pylint: disable=W0718
logger.exception("[ChatError] %s", err)
await self.send(text_data=json.dumps({"data": str(err), "is_finished": True}, ensure_ascii=False))
await self.close()

async def chat(self, text_data) -> None:
# validate
validated_data = self.validate_input(text_data=text_data)
key = validated_data["key"]

# cache
request_data = self.load_data_from_cache(key)

# model
model = await database_sync_to_async(self.get_model_inst)(request_data["model"])

# get client
client = self.get_model_client(model)

# init client
client = await database_sync_to_async(client)(**request_data)

# response
async for data in client.chat():
await self.send(
text_data=json.dumps({"data": data, "is_finished": False, "log_id": client.log.id}, ensure_ascii=False)
)
await asyncio.sleep(0)

def validate_input(self, text_data: str) -> dict:
# json
# load input
try:
data = json.loads(text_data)
except json.JSONDecodeError as err:
@@ -77,40 +22,16 @@ def validate_input(self, text_data: str) -> dict:
# validate request
request_serializer = OpenAIChatRequestSerializer(data=data)
request_serializer.is_valid(raise_exception=True)
return request_serializer.validated_data
data = request_serializer.validated_data

def load_data_from_cache(self, key: str) -> dict:
request_data = cache.get(key=key)
cache.delete(key=key)
if not request_data:
raise VerifyFailed()
return request_data
# async chat
async_reply.apply_async(kwargs={"channel_name": self.channel_name, "key": data["key"]})

def get_model_inst(self, model: str) -> AIModel:
return get_object_or_404(AIModel, model=model)
async def chat_send(self, event: dict):
await self.send(text_data=event["text_data"])

async def chat_close(self, event: dict):
await self.close()

# pylint: disable=R0911
def get_model_client(self, model: AIModel) -> Type[BaseClient]:
match model.provider:
case AIModelProvider.TENCENT:
if model.is_vision:
return HunYuanVisionClient
return HunYuanClient
case AIModelProvider.GOOGLE:
return GeminiClient
case AIModelProvider.BAIDU:
return QianfanClient
case AIModelProvider.OPENAI:
if model.is_vision:
return OpenAIVisionClient
return OpenAIClient
case AIModelProvider.ALIYUN:
return BaiLianClient
case AIModelProvider.MOONSHOT:
return KimiClient
case AIModelProvider.DOUBAO:
return DoubaoClient
case AIModelProvider.MIDJOURNEY:
return MidjourneyClient
case _:
raise UnexpectedProvider()
async def disconnect(self, code):
await super().disconnect(code)
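After this change the consumer no longer streams the reply itself: receive validates the payload, hands the work off to the async_reply Celery task, and from then on only relays channel-layer events. Channels dispatches messages of type "chat.send" and "chat.close" to the chat_send and chat_close handlers above (dots map to underscores). A minimal sketch of that round trip, assuming a configured channel layer:

from channels.generic.websocket import AsyncWebsocketConsumer
from channels.layers import get_channel_layer


class RelayConsumer(AsyncWebsocketConsumer):
    # Channels routes {"type": "chat.send", ...} sent to self.channel_name here
    async def chat_send(self, event: dict):
        await self.send(text_data=event["text_data"])

    async def chat_close(self, event: dict):
        await self.close()


# elsewhere (e.g. inside the Celery worker), push a frame to one specific socket:
async def push(channel_name: str, payload: str) -> None:
    layer = get_channel_layer()
    await layer.send(channel_name, {"type": "chat.send", "text_data": payload})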
147 changes: 147 additions & 0 deletions apps/chat/consumers_async.py
@@ -0,0 +1,147 @@
import asyncio
import json
from typing import Type

from autobahn.exception import Disconnected
from channels.db import database_sync_to_async
from channels.exceptions import ChannelFull
from channels.layers import get_channel_layer
from channels_redis.core import RedisChannelLayer
from django.conf import settings
from django.core.cache import cache
from django.shortcuts import get_object_or_404
from ovinc_client.core.logger import logger

from apps.chat.client import (
BaiLianClient,
DoubaoClient,
GeminiClient,
HunYuanClient,
HunYuanVisionClient,
KimiClient,
MidjourneyClient,
OpenAIClient,
OpenAIVisionClient,
QianfanClient,
)
from apps.chat.client.base import BaseClient
from apps.chat.constants import WS_CLOSED_KEY, AIModelProvider
from apps.chat.exceptions import UnexpectedProvider, VerifyFailed
from apps.chat.models import AIModel


class AsyncConsumer:
"""
Async Consumer
"""

def __init__(self, channel_name: str, key: str):
self.channel_layer: RedisChannelLayer = get_channel_layer()
self.channel_name = channel_name
self.key = key

async def chat(self) -> None:
try:
is_closed = await self.do_chat()
if is_closed:
return
await self.send(text_data=json.dumps({"is_finished": True}, ensure_ascii=False))
except Exception as err: # pylint: disable=W0718
logger.exception("[ChatError] %s", err)
await self.send(text_data=json.dumps({"data": str(err), "is_finished": True}, ensure_ascii=False))
await self.close()

async def send(self, text_data: str):
await self.channel_layer.send(
channel=self.channel_name,
message={"type": "chat.send", "text_data": text_data},
)

async def close(self):
await self.channel_layer.send(
channel=self.channel_name,
message={
"type": "chat.close",
},
)

async def do_chat(self) -> bool:
# cache
request_data = self.load_data_from_cache(self.key)

# model
model = await database_sync_to_async(self.get_model_inst)(request_data["model"])

# get client
client = self.get_model_client(model)

# init client
client = await database_sync_to_async(client)(**request_data)

# response
is_closed = False
async for data in client.chat():
if is_closed:
continue
retry_times = 0
while retry_times <= settings.CHANNEL_RETRY_TIMES:
try:
await self.send(
text_data=json.dumps(
{"data": data, "is_finished": False, "log_id": client.log.id}, ensure_ascii=False
)
)
await asyncio.sleep(0)
break
except Disconnected:
logger.warning("[SendMessageFailed-Disconnected] Channel: %s", self.channel_name)
is_closed = True
break
except ChannelFull:
if cache.get(WS_CLOSED_KEY.format(self.channel_name)):
logger.warning("[SendMessageFailed-Disconnected] Channel: %s", self.channel_name)
is_closed = True
break
logger.warning(
"[SendMessageFailed-ChannelFull] Channel: %s; Retry: %d;", self.channel_name, retry_times
)
await asyncio.sleep(settings.CHANNEL_RETRY_SLEEP)
retry_times += 1

return is_closed

def load_data_from_cache(self, key: str) -> dict:
request_data = cache.get(key=key)
cache.delete(key=key)
if not request_data:
raise VerifyFailed()
return request_data

def get_model_inst(self, model: str) -> AIModel:
return get_object_or_404(AIModel, model=model)

# pylint: disable=R0911
def get_model_client(self, model: AIModel) -> Type[BaseClient]:
match model.provider:
case AIModelProvider.TENCENT:
if model.is_vision:
return HunYuanVisionClient
return HunYuanClient
case AIModelProvider.GOOGLE:
return GeminiClient
case AIModelProvider.BAIDU:
return QianfanClient
case AIModelProvider.OPENAI:
if model.is_vision:
return OpenAIVisionClient
return OpenAIClient
case AIModelProvider.ALIYUN:
return BaiLianClient
case AIModelProvider.MOONSHOT:
return KimiClient
case AIModelProvider.DOUBAO:
return DoubaoClient
case AIModelProvider.MIDJOURNEY:
return MidjourneyClient
case _:
raise UnexpectedProvider()
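The send loop above retries when the channel layer reports ChannelFull and gives up once the cache flag says the peer is gone. The same guard reduced to a standalone sketch; the constants mirror the CHANNEL_RETRY_TIMES / CHANNEL_RETRY_SLEEP settings added below, and is_closed is a hypothetical stand-in for the WS_CLOSED_KEY cache check:

import asyncio

from channels.exceptions import ChannelFull

CHANNEL_RETRY_TIMES = 1  # mirrors settings.CHANNEL_RETRY_TIMES
CHANNEL_RETRY_SLEEP = 1  # seconds, mirrors settings.CHANNEL_RETRY_SLEEP


async def send_with_retry(layer, channel_name: str, message: dict, is_closed) -> bool:
    """Return True if the message was delivered, False if the client is gone."""
    retry_times = 0
    while retry_times <= CHANNEL_RETRY_TIMES:
        try:
            await layer.send(channel_name, message)
            return True
        except ChannelFull:
            if is_closed(channel_name):
                # the websocket already disconnected; stop pushing
                return False
            await asyncio.sleep(CHANNEL_RETRY_SLEEP)
            retry_times += 1
    return False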
3 changes: 2 additions & 1 deletion apps/chat/views.py
@@ -14,6 +14,7 @@
from rest_framework.decorators import action
from rest_framework.response import Response

from apps.chat.constants import MESSAGE_CACHE_KEY
from apps.chat.models import AIModel, ChatLog, SystemPreset
from apps.chat.permissions import AIModelPermission
from apps.chat.serializers import (
@@ -47,7 +48,7 @@ async def pre_check(self, request, *args, **kwargs):
await database_sync_to_async(get_object_or_404)(AIModel, model=request_data["model"], is_enabled=True)

# cache
cache_key = uniq_id()
cache_key = MESSAGE_CACHE_KEY.format(uniq_id())
cache.set(
key=cache_key,
value={**request_data, "user": request.user.username},
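pre_check now namespaces the one-time cache entry with MESSAGE_CACHE_KEY; the generated key is presumably handed back to the client (the response body lies outside this hunk), sent over the websocket, and finally resolved by AsyncConsumer.load_data_from_cache in the worker. A rough sketch of that handoff (uuid4 stands in for the project's uniq_id helper, values are illustrative):

import uuid

from django.core.cache import cache

MESSAGE_CACHE_KEY = "message:{}"

# HTTP pre-check: stash the validated request under a one-time key
cache_key = MESSAGE_CACHE_KEY.format(uuid.uuid4().hex)
cache.set(key=cache_key, value={"model": "some-model", "user": "alice"}, timeout=60)

# later, in the Celery worker, load_data_from_cache does the reverse
request_data = cache.get(key=cache_key)
cache.delete(key=cache_key)  # one-time use; a missing value raises VerifyFailed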
7 changes: 7 additions & 0 deletions entry/settings.py
@@ -1,4 +1,5 @@
import os
import re
from pathlib import Path

from environ import environ
@@ -119,9 +120,15 @@
"hosts": [
f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}",
],
"channel_capacity": {
re.compile(r".*"): int(os.getenv("CHANNEL_LAYER_DEFAULT_CAPACITY", "100")),
},
},
},
}
CHANNEL_RETRY_TIMES = int(os.getenv("CHANNEL_RETRY_TIMES", "1"))
CHANNEL_RETRY_SLEEP = int(os.getenv("CHANNEL_RETRY_SLEEP", "1")) # seconds
CHANNEL_CLOSE_KEY_TIMEOUT = int(os.getenv("CHANNEL_CLOSE_KEY_TIMEOUT", "60"))

# Auth
AUTH_PASSWORD_VALIDATORS = [
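All four new knobs are read from the environment with defaults, so backpressure behaviour can be tuned per deployment without code changes. A hedged example of overriding them before Django loads; the values are illustrative only:

import os

os.environ.setdefault("CHANNEL_LAYER_DEFAULT_CAPACITY", "200")  # per-channel buffer size (default 100)
os.environ.setdefault("CHANNEL_RETRY_TIMES", "3")               # resend attempts on ChannelFull (default 1)
os.environ.setdefault("CHANNEL_RETRY_SLEEP", "2")               # seconds between attempts (default 1)
os.environ.setdefault("CHANNEL_CLOSE_KEY_TIMEOUT", "120")       # lifetime of the ws:closed flag (default 60)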
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
# ovinc
ovinc-client==0.3.8
ovinc-client==0.3.9

# Celery
celery==5.4.0
4 changes: 4 additions & 0 deletions utils/consumers.py
@@ -1,5 +1,8 @@
from channels.generic.websocket import AsyncWebsocketConsumer as _AsyncWebsocketConsumer
from django.conf import settings
from django.core.cache import cache

from apps.chat.constants import WS_CLOSED_KEY
from utils.connections import connections_handler


@@ -9,5 +12,6 @@ async def connect(self):
connections_handler.add_connection(self.channel_name, self.scope["client"][0])

async def disconnect(self, code):
cache.set(key=WS_CLOSED_KEY.format(self.channel_name), value=True, timeout=settings.CHANNEL_CLOSE_KEY_TIMEOUT)
await super().disconnect(code)
connections_handler.remove_connection(self.channel_name, self.scope["client"][0])
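Because the Celery worker has no handle on the socket itself, disconnect leaves a short-lived flag in the cache; consumers_async.py above polls the same key when a send fails. A small sketch of that check, assuming the same constant:

from django.core.cache import cache

WS_CLOSED_KEY = "ws:closed:{}"


def client_has_gone(channel_name: str) -> bool:
    # True once AsyncWebsocketConsumer.disconnect has flagged this channel as closed
    return bool(cache.get(WS_CLOSED_KEY.format(channel_name)))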
