Skip to content

Commit

Permalink
feat(asgi): replace async to sync
Browse files Browse the repository at this point in the history
  • Loading branch information
OrenZhang committed Jan 8, 2025
1 parent 3ff8627 commit 1aceef6
Show file tree
Hide file tree
Showing 20 changed files with 149 additions and 194 deletions.
2 changes: 1 addition & 1 deletion .cruft.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"template": "https://github.com/OVINC-CN/DevTemplateDjango.git",
"commit": "1e0028fba6cb111a2ef69038ca07ac74fc95a219",
"commit": "c57436b9cfec6091334edb5b2115ea498911d51c",
"checkout": "main",
"context": {
"cookiecutter": {
Expand Down
42 changes: 20 additions & 22 deletions apps/chat/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import base64
import datetime

from channels.db import database_sync_to_async
from django.conf import settings
from django.contrib.auth import get_user_model
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django.utils.translation import gettext
from httpx import AsyncClient, Client
from httpx import Client
from openai import OpenAI
from openai.types import CompletionUsage
from opentelemetry import trace
Expand Down Expand Up @@ -49,7 +48,7 @@ def __init__(self, user: str, model: str, messages: list[Message]):
)
self.tracer = trace.get_tracer(self.__class__.__name__)

async def chat(self, *args, **kwargs) -> any:
def chat(self, *args, **kwargs) -> any:
"""
Chat
"""
Expand All @@ -70,30 +69,29 @@ async def chat(self, *args, **kwargs) -> any:
audit_content = str(content)
# call audit api
client = COSClient()
await client.text_audit(user=self.user, content=audit_content, data_id=self.log.id)
client.text_audit(user=self.user, content=audit_content, data_id=self.log.id)
for image in audit_image:
await client.image_audit(user=self.user, image_url=image, data_id=self.log.id)
client.image_audit(user=self.user, image_url=image, data_id=self.log.id)
except Exception as e:
await self.record()
self.record()
raise e

with self.start_span(SpanType.CHAT, SpanKind.SERVER):
try:
async for text in self._chat(*args, **kwargs):
yield text
yield from self._chat(*args, **kwargs)
except Exception as e:
await self.record()
self.record()
raise e

@abc.abstractmethod
async def _chat(self, *args, **kwargs) -> any:
def _chat(self, *args, **kwargs) -> any:
"""
Chat
"""

raise NotImplementedError()

async def record(self, prompt_tokens: int = 0, completion_tokens: int = 0, vision_count: int = 0) -> None:
def record(self, prompt_tokens: int = 0, completion_tokens: int = 0, vision_count: int = 0) -> None:
if not self.log:
return
# calculate tokens
Expand All @@ -107,11 +105,11 @@ async def record(self, prompt_tokens: int = 0, completion_tokens: int = 0, visio
self.log.request_unit_price = self.model_inst.request_price
# save
self.log.finished_at = int(timezone.now().timestamp() * 1000)
await database_sync_to_async(self.log.save)()
self.log.save()
# calculate usage
from apps.chat.tasks import calculate_usage_limit # pylint: disable=C0415

await database_sync_to_async(calculate_usage_limit)(log_id=self.log.id) # pylint: disable=E1120
calculate_usage_limit(log_id=self.log.id) # pylint: disable=E1120

def start_span(self, name: str, kind: SpanKind, **kwargs) -> Span:
span: Span = self.tracer.start_as_current_span(name=name, kind=kind, **kwargs)
Expand Down Expand Up @@ -161,8 +159,8 @@ def extra_body(self) -> dict | None:
def extra_chat_params(self) -> dict[str, any]:
return {}

async def _chat(self, *args, **kwargs) -> any:
image_count = await self.format_message()
def _chat(self, *args, **kwargs) -> any:
image_count = self.format_message()
client = OpenAI(api_key=self.api_key, base_url=self.base_url, http_client=self.http_client)
try:
with self.start_span(SpanType.API, SpanKind.CLIENT):
Expand Down Expand Up @@ -195,9 +193,9 @@ async def _chat(self, *args, **kwargs) -> any:
prompt_tokens, completion_tokens = self.get_tokens(chunk.usage)
if chunk.id:
self.log.chat_id = chunk.id
await self.record(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, vision_count=image_count)
self.record(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, vision_count=image_count)

async def format_message(self) -> int:
def format_message(self) -> int:
image_count = 0
for message in self.messages:
message: Message
Expand All @@ -207,19 +205,19 @@ async def format_message(self) -> int:
content: MessageContent
if content.type != MessageContentType.IMAGE_URL or not content.image_url:
continue
content.image_url.url = await self.convert_url_to_base64(content.image_url.url)
content.image_url.url = self.convert_url_to_base64(content.image_url.url)
image_count += 1
return image_count

async def convert_url_to_base64(self, url: str) -> str:
client = AsyncClient(http2=True, timeout=settings.LOAD_IMAGE_TIMEOUT)
def convert_url_to_base64(self, url: str) -> str:
client = Client(http2=True, timeout=settings.LOAD_IMAGE_TIMEOUT)
try:
response = await client.get(url)
response = client.get(url)
if response.status_code == 200:
return f"data:image/webp;base64,{base64.b64encode(response.content).decode()}"
raise FileExtractFailed(gettext("Parse Image To Base64 Failed"))
finally:
await client.aclose()
client.close()

def get_tokens(self, usage: CompletionUsage) -> (int, int):
return (
Expand Down
25 changes: 12 additions & 13 deletions apps/chat/client/midjourney.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import asyncio
import time
import uuid

from httpx import AsyncClient
from httpx import Client
from opentelemetry.trace import SpanKind
from ovinc_client.core.logger import logger
from rest_framework import status
Expand All @@ -20,8 +19,8 @@ class MidjourneyClient(BaseClient):
Midjourney Client
"""

async def _chat(self, *args, **kwargs) -> any:
client = AsyncClient(
def _chat(self, *args, **kwargs) -> any:
client = Client(
http2=True,
headers={"Authorization": f"Bearer {self.model_settings.get("api_key")}"},
base_url=self.model_settings.get("base_url"),
Expand All @@ -32,34 +31,34 @@ async def _chat(self, *args, **kwargs) -> any:
try:
with self.start_span(SpanType.API, SpanKind.CLIENT):
# submit job
response = await client.post(
response = client.post(
url=self.model_settings.get("imaging_path"), json={"prompt": self.messages[-1].content}
)
result_id = response.json()["result"]
# wait for result
start_time = time.time()
while time.time() - start_time < self.model_settings.get("wait_timeout", 600):
result = await client.get(url=self.model_settings.get("result_path").format(id=result_id))
result = client.get(url=self.model_settings.get("result_path").format(id=result_id))
result_data = result.json()
# if not finished, continue loop
if result_data["status"] not in [MidjourneyResult.FAILURE, MidjourneyResult.SUCCESS]:
yield ""
await asyncio.sleep(self.model_settings.get("no_result_sleep", 5))
time.sleep(self.model_settings.get("no_result_sleep", 5))
continue
# if failed
if result_data["status"] == MidjourneyResult.FAILURE:
yield format_error(GenerateFailed(result_data.get("failReason") or None))
await self.record()
self.record()
break
with self.start_span(SpanType.CHUNK, SpanKind.SERVER):
# record
await self.record(completion_tokens=1)
self.record(completion_tokens=1)
# use first success picture
message_url = result_data["imageUrl"]
image_resp = await client.get(message_url)
image_resp = client.get(message_url)
if image_resp.status_code != status.HTTP_200_OK:
raise LoadImageFailed()
url = await COSClient().put_object(
url = COSClient().put_object(
file=image_resp.content,
file_name=f"{uuid.uuid4().hex}.{image_resp.headers['content-type'].split('/')[-1]}",
)
Expand All @@ -68,6 +67,6 @@ async def _chat(self, *args, **kwargs) -> any:
except Exception as err: # pylint: disable=W0718
logger.exception("[GenerateContentFailed] %s", err)
yield format_error(err)
await self.record()
self.record()
finally:
await client.aclose()
client.close()
17 changes: 7 additions & 10 deletions apps/chat/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from apps.chat.exceptions import VerifyFailed
from apps.chat.serializers import OpenAIChatRequestSerializer
from apps.chat.tasks import async_reply
from utils.consumers import AsyncWebsocketConsumer
from utils.consumers import WebsocketConsumer

USER_MODEL: User = get_user_model()


class ChatConsumer(AsyncWebsocketConsumer):
async def receive(self, text_data=None, *args, **kwargs):
class ChatConsumer(WebsocketConsumer):
def receive(self, text_data=None, *args, **kwargs):
# load input
try:
data = json.loads(text_data)
Expand All @@ -27,11 +27,8 @@ async def receive(self, text_data=None, *args, **kwargs):
# async chat
async_reply.apply_async(kwargs={"channel_name": self.channel_name, "key": data["key"]})

async def chat_send(self, event: dict):
await self.send(text_data=event["text_data"])
def chat_send(self, event: dict):
self.send(text_data=event["text_data"])

async def chat_close(self, event: dict):
await self.close()

async def disconnect(self, code):
await super().disconnect(code)
def chat_close(self, event: dict):
self.close()
39 changes: 19 additions & 20 deletions apps/chat/consumers_async.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import json
import time
from typing import Type

from asgiref.sync import async_to_sync
from autobahn.exception import Disconnected
from channels.db import database_sync_to_async
from channels.exceptions import ChannelFull
from channels.layers import get_channel_layer
from channels_redis.core import RedisChannelLayer
Expand All @@ -30,37 +30,37 @@ def __init__(self, channel_name: str, key: str):
self.channel_name = channel_name
self.key = key

async def chat(self) -> None:
def chat(self) -> None:
try:
is_closed = await self.do_chat()
is_closed = self.do_chat()
if is_closed:
return
await self.send(text_data=json.dumps({"is_finished": True}, ensure_ascii=False))
self.send(text_data=json.dumps({"is_finished": True}, ensure_ascii=False))
except Exception as err: # pylint: disable=W0718
logger.exception("[ChatError] %s", err)
await self.send(text_data=json.dumps({"data": format_error(err), "is_finished": True}, ensure_ascii=False))
await self.close()
self.send(text_data=json.dumps({"data": format_error(err), "is_finished": True}, ensure_ascii=False))
self.close()

async def send(self, text_data: str):
await self.channel_layer.send(
def send(self, text_data: str):
async_to_sync(self.channel_layer.send)(
channel=self.channel_name,
message={"type": "chat.send", "text_data": text_data},
)

async def close(self):
await self.channel_layer.send(
def close(self):
async_to_sync(self.channel_layer.send)(
channel=self.channel_name,
message={
"type": "chat.close",
},
)

async def do_chat(self) -> bool:
def do_chat(self) -> bool:
# cache
request_data = self.load_data_from_cache(self.key)

# model
model = await database_sync_to_async(self.get_model_inst)(request_data.model)
model = self.get_model_inst(request_data.model)

# get client
client = self.get_model_client(model)
Expand All @@ -71,26 +71,25 @@ async def do_chat(self) -> bool:
return True

# init client
client = await database_sync_to_async(client)(
client = client(
user=request_data.user,
model=request_data.model,
messages=request_data.messages,
)

# response
is_closed = False
async for data in client.chat():
for data in client.chat():
if is_closed:
continue
retry_times = 0
while retry_times <= settings.CHANNEL_RETRY_TIMES:
try:
await self.send(
self.send(
text_data=json.dumps(
{"data": data, "is_finished": False, "log_id": client.log.id}, ensure_ascii=False
)
)
await asyncio.sleep(0)
break
except Disconnected:
logger.warning("[SendMessageFailed-Disconnected] Channel: %s", self.channel_name)
Expand All @@ -104,7 +103,7 @@ async def do_chat(self) -> bool:
logger.warning(
"[SendMessageFailed-ChannelFull] Channel: %s; Retry: %d;", self.channel_name, retry_times
)
await asyncio.sleep(settings.CHANNEL_RETRY_SLEEP)
time.sleep(settings.CHANNEL_RETRY_SLEEP)
retry_times += 1

return is_closed
Expand Down Expand Up @@ -139,8 +138,8 @@ def __init__(self, key: str):
super().__init__("", key)
self.message = ""

async def send(self, text_data: str):
def send(self, text_data: str):
self.message += json.loads(text_data).get("data", "")

async def close(self):
def close(self):
return
10 changes: 3 additions & 7 deletions apps/chat/permissions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from channels.db import database_sync_to_async
from rest_framework.permissions import BasePermission

from apps.chat.exceptions import NoModelPermission
Expand All @@ -13,18 +12,15 @@ class AIModelPermission(BasePermission):
"""

# pylint: disable=W0236
async def has_permission(self, request, view):
allowed = await database_sync_to_async(AIModel.check_user_permission)(
request.user, model=str(request.data.get("model", ""))
)
def has_permission(self, request, view):
allowed = AIModel.check_user_permission(request.user, model=str(request.data.get("model", "")))
if not allowed:
raise NoModelPermission()
balance = await self.load_balance(request=request)
balance = self.load_balance(request=request)
if balance > 0:
return True
raise NoBalanceException()

@database_sync_to_async
def load_balance(self, request) -> float:
try:
return Wallet.objects.get(user=request.user).balance
Expand Down
Loading

0 comments on commit 1aceef6

Please sign in to comment.