From 1aceef6c798c8521a592a3860184b7c0ea537ffe Mon Sep 17 00:00:00 2001 From: orenzhang Date: Wed, 8 Jan 2025 17:29:27 +0800 Subject: [PATCH] feat(asgi): replace async to sync --- .cruft.json | 2 +- apps/chat/client/base.py | 42 +++++++++++++--------------- apps/chat/client/midjourney.py | 25 ++++++++--------- apps/chat/consumers.py | 17 +++++------- apps/chat/consumers_async.py | 39 +++++++++++++------------- apps/chat/permissions.py | 10 ++----- apps/chat/serializers.py | 28 ++++++++----------- apps/chat/tasks.py | 3 +- apps/chat/views.py | 51 ++++++++++++---------------------- apps/cos/client.py | 8 +++--- apps/cos/serializers.py | 8 ++---- apps/cos/views.py | 6 ++-- apps/home/views.py | 8 ++++-- apps/wallet/serializers.py | 9 +++--- apps/wallet/views.py | 35 +++++++++-------------- bin/run.sh | 4 +-- requirements.txt | 4 +-- utils/consumers.py | 12 ++++---- utils/wxpay/api.py | 22 +++++++-------- utils/wxpay/utils.py | 10 +++---- 20 files changed, 149 insertions(+), 194 deletions(-) diff --git a/.cruft.json b/.cruft.json index 9cd4ea8..a6550fd 100644 --- a/.cruft.json +++ b/.cruft.json @@ -1,6 +1,6 @@ { "template": "https://github.com/OVINC-CN/DevTemplateDjango.git", - "commit": "1e0028fba6cb111a2ef69038ca07ac74fc95a219", + "commit": "c57436b9cfec6091334edb5b2115ea498911d51c", "checkout": "main", "context": { "cookiecutter": { diff --git a/apps/chat/client/base.py b/apps/chat/client/base.py index 40ef65b..1d76cc8 100644 --- a/apps/chat/client/base.py +++ b/apps/chat/client/base.py @@ -2,13 +2,12 @@ import base64 import datetime -from channels.db import database_sync_to_async from django.conf import settings from django.contrib.auth import get_user_model from django.shortcuts import get_object_or_404 from django.utils import timezone from django.utils.translation import gettext -from httpx import AsyncClient, Client +from httpx import Client from openai import OpenAI from openai.types import CompletionUsage from opentelemetry import trace @@ -49,7 +48,7 @@ def __init__(self, user: str, model: str, messages: list[Message]): ) self.tracer = trace.get_tracer(self.__class__.__name__) - async def chat(self, *args, **kwargs) -> any: + def chat(self, *args, **kwargs) -> any: """ Chat """ @@ -70,30 +69,29 @@ async def chat(self, *args, **kwargs) -> any: audit_content = str(content) # call audit api client = COSClient() - await client.text_audit(user=self.user, content=audit_content, data_id=self.log.id) + client.text_audit(user=self.user, content=audit_content, data_id=self.log.id) for image in audit_image: - await client.image_audit(user=self.user, image_url=image, data_id=self.log.id) + client.image_audit(user=self.user, image_url=image, data_id=self.log.id) except Exception as e: - await self.record() + self.record() raise e with self.start_span(SpanType.CHAT, SpanKind.SERVER): try: - async for text in self._chat(*args, **kwargs): - yield text + yield from self._chat(*args, **kwargs) except Exception as e: - await self.record() + self.record() raise e @abc.abstractmethod - async def _chat(self, *args, **kwargs) -> any: + def _chat(self, *args, **kwargs) -> any: """ Chat """ raise NotImplementedError() - async def record(self, prompt_tokens: int = 0, completion_tokens: int = 0, vision_count: int = 0) -> None: + def record(self, prompt_tokens: int = 0, completion_tokens: int = 0, vision_count: int = 0) -> None: if not self.log: return # calculate tokens @@ -107,11 +105,11 @@ async def record(self, prompt_tokens: int = 0, completion_tokens: int = 0, visio self.log.request_unit_price = self.model_inst.request_price # save self.log.finished_at = int(timezone.now().timestamp() * 1000) - await database_sync_to_async(self.log.save)() + self.log.save() # calculate usage from apps.chat.tasks import calculate_usage_limit # pylint: disable=C0415 - await database_sync_to_async(calculate_usage_limit)(log_id=self.log.id) # pylint: disable=E1120 + calculate_usage_limit(log_id=self.log.id) # pylint: disable=E1120 def start_span(self, name: str, kind: SpanKind, **kwargs) -> Span: span: Span = self.tracer.start_as_current_span(name=name, kind=kind, **kwargs) @@ -161,8 +159,8 @@ def extra_body(self) -> dict | None: def extra_chat_params(self) -> dict[str, any]: return {} - async def _chat(self, *args, **kwargs) -> any: - image_count = await self.format_message() + def _chat(self, *args, **kwargs) -> any: + image_count = self.format_message() client = OpenAI(api_key=self.api_key, base_url=self.base_url, http_client=self.http_client) try: with self.start_span(SpanType.API, SpanKind.CLIENT): @@ -195,9 +193,9 @@ async def _chat(self, *args, **kwargs) -> any: prompt_tokens, completion_tokens = self.get_tokens(chunk.usage) if chunk.id: self.log.chat_id = chunk.id - await self.record(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, vision_count=image_count) + self.record(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, vision_count=image_count) - async def format_message(self) -> int: + def format_message(self) -> int: image_count = 0 for message in self.messages: message: Message @@ -207,19 +205,19 @@ async def format_message(self) -> int: content: MessageContent if content.type != MessageContentType.IMAGE_URL or not content.image_url: continue - content.image_url.url = await self.convert_url_to_base64(content.image_url.url) + content.image_url.url = self.convert_url_to_base64(content.image_url.url) image_count += 1 return image_count - async def convert_url_to_base64(self, url: str) -> str: - client = AsyncClient(http2=True, timeout=settings.LOAD_IMAGE_TIMEOUT) + def convert_url_to_base64(self, url: str) -> str: + client = Client(http2=True, timeout=settings.LOAD_IMAGE_TIMEOUT) try: - response = await client.get(url) + response = client.get(url) if response.status_code == 200: return f"data:image/webp;base64,{base64.b64encode(response.content).decode()}" raise FileExtractFailed(gettext("Parse Image To Base64 Failed")) finally: - await client.aclose() + client.close() def get_tokens(self, usage: CompletionUsage) -> (int, int): return ( diff --git a/apps/chat/client/midjourney.py b/apps/chat/client/midjourney.py index e7599af..ae26703 100644 --- a/apps/chat/client/midjourney.py +++ b/apps/chat/client/midjourney.py @@ -1,8 +1,7 @@ -import asyncio import time import uuid -from httpx import AsyncClient +from httpx import Client from opentelemetry.trace import SpanKind from ovinc_client.core.logger import logger from rest_framework import status @@ -20,8 +19,8 @@ class MidjourneyClient(BaseClient): Midjourney Client """ - async def _chat(self, *args, **kwargs) -> any: - client = AsyncClient( + def _chat(self, *args, **kwargs) -> any: + client = Client( http2=True, headers={"Authorization": f"Bearer {self.model_settings.get("api_key")}"}, base_url=self.model_settings.get("base_url"), @@ -32,34 +31,34 @@ async def _chat(self, *args, **kwargs) -> any: try: with self.start_span(SpanType.API, SpanKind.CLIENT): # submit job - response = await client.post( + response = client.post( url=self.model_settings.get("imaging_path"), json={"prompt": self.messages[-1].content} ) result_id = response.json()["result"] # wait for result start_time = time.time() while time.time() - start_time < self.model_settings.get("wait_timeout", 600): - result = await client.get(url=self.model_settings.get("result_path").format(id=result_id)) + result = client.get(url=self.model_settings.get("result_path").format(id=result_id)) result_data = result.json() # if not finished, continue loop if result_data["status"] not in [MidjourneyResult.FAILURE, MidjourneyResult.SUCCESS]: yield "" - await asyncio.sleep(self.model_settings.get("no_result_sleep", 5)) + time.sleep(self.model_settings.get("no_result_sleep", 5)) continue # if failed if result_data["status"] == MidjourneyResult.FAILURE: yield format_error(GenerateFailed(result_data.get("failReason") or None)) - await self.record() + self.record() break with self.start_span(SpanType.CHUNK, SpanKind.SERVER): # record - await self.record(completion_tokens=1) + self.record(completion_tokens=1) # use first success picture message_url = result_data["imageUrl"] - image_resp = await client.get(message_url) + image_resp = client.get(message_url) if image_resp.status_code != status.HTTP_200_OK: raise LoadImageFailed() - url = await COSClient().put_object( + url = COSClient().put_object( file=image_resp.content, file_name=f"{uuid.uuid4().hex}.{image_resp.headers['content-type'].split('/')[-1]}", ) @@ -68,6 +67,6 @@ async def _chat(self, *args, **kwargs) -> any: except Exception as err: # pylint: disable=W0718 logger.exception("[GenerateContentFailed] %s", err) yield format_error(err) - await self.record() + self.record() finally: - await client.aclose() + client.close() diff --git a/apps/chat/consumers.py b/apps/chat/consumers.py index a16835a..c97d75e 100644 --- a/apps/chat/consumers.py +++ b/apps/chat/consumers.py @@ -6,13 +6,13 @@ from apps.chat.exceptions import VerifyFailed from apps.chat.serializers import OpenAIChatRequestSerializer from apps.chat.tasks import async_reply -from utils.consumers import AsyncWebsocketConsumer +from utils.consumers import WebsocketConsumer USER_MODEL: User = get_user_model() -class ChatConsumer(AsyncWebsocketConsumer): - async def receive(self, text_data=None, *args, **kwargs): +class ChatConsumer(WebsocketConsumer): + def receive(self, text_data=None, *args, **kwargs): # load input try: data = json.loads(text_data) @@ -27,11 +27,8 @@ async def receive(self, text_data=None, *args, **kwargs): # async chat async_reply.apply_async(kwargs={"channel_name": self.channel_name, "key": data["key"]}) - async def chat_send(self, event: dict): - await self.send(text_data=event["text_data"]) + def chat_send(self, event: dict): + self.send(text_data=event["text_data"]) - async def chat_close(self, event: dict): - await self.close() - - async def disconnect(self, code): - await super().disconnect(code) + def chat_close(self, event: dict): + self.close() diff --git a/apps/chat/consumers_async.py b/apps/chat/consumers_async.py index b6a41ea..9c98893 100644 --- a/apps/chat/consumers_async.py +++ b/apps/chat/consumers_async.py @@ -1,9 +1,9 @@ -import asyncio import json +import time from typing import Type +from asgiref.sync import async_to_sync from autobahn.exception import Disconnected -from channels.db import database_sync_to_async from channels.exceptions import ChannelFull from channels.layers import get_channel_layer from channels_redis.core import RedisChannelLayer @@ -30,37 +30,37 @@ def __init__(self, channel_name: str, key: str): self.channel_name = channel_name self.key = key - async def chat(self) -> None: + def chat(self) -> None: try: - is_closed = await self.do_chat() + is_closed = self.do_chat() if is_closed: return - await self.send(text_data=json.dumps({"is_finished": True}, ensure_ascii=False)) + self.send(text_data=json.dumps({"is_finished": True}, ensure_ascii=False)) except Exception as err: # pylint: disable=W0718 logger.exception("[ChatError] %s", err) - await self.send(text_data=json.dumps({"data": format_error(err), "is_finished": True}, ensure_ascii=False)) - await self.close() + self.send(text_data=json.dumps({"data": format_error(err), "is_finished": True}, ensure_ascii=False)) + self.close() - async def send(self, text_data: str): - await self.channel_layer.send( + def send(self, text_data: str): + async_to_sync(self.channel_layer.send)( channel=self.channel_name, message={"type": "chat.send", "text_data": text_data}, ) - async def close(self): - await self.channel_layer.send( + def close(self): + async_to_sync(self.channel_layer.send)( channel=self.channel_name, message={ "type": "chat.close", }, ) - async def do_chat(self) -> bool: + def do_chat(self) -> bool: # cache request_data = self.load_data_from_cache(self.key) # model - model = await database_sync_to_async(self.get_model_inst)(request_data.model) + model = self.get_model_inst(request_data.model) # get client client = self.get_model_client(model) @@ -71,7 +71,7 @@ async def do_chat(self) -> bool: return True # init client - client = await database_sync_to_async(client)( + client = client( user=request_data.user, model=request_data.model, messages=request_data.messages, @@ -79,18 +79,17 @@ async def do_chat(self) -> bool: # response is_closed = False - async for data in client.chat(): + for data in client.chat(): if is_closed: continue retry_times = 0 while retry_times <= settings.CHANNEL_RETRY_TIMES: try: - await self.send( + self.send( text_data=json.dumps( {"data": data, "is_finished": False, "log_id": client.log.id}, ensure_ascii=False ) ) - await asyncio.sleep(0) break except Disconnected: logger.warning("[SendMessageFailed-Disconnected] Channel: %s", self.channel_name) @@ -104,7 +103,7 @@ async def do_chat(self) -> bool: logger.warning( "[SendMessageFailed-ChannelFull] Channel: %s; Retry: %d;", self.channel_name, retry_times ) - await asyncio.sleep(settings.CHANNEL_RETRY_SLEEP) + time.sleep(settings.CHANNEL_RETRY_SLEEP) retry_times += 1 return is_closed @@ -139,8 +138,8 @@ def __init__(self, key: str): super().__init__("", key) self.message = "" - async def send(self, text_data: str): + def send(self, text_data: str): self.message += json.loads(text_data).get("data", "") - async def close(self): + def close(self): return diff --git a/apps/chat/permissions.py b/apps/chat/permissions.py index 85f1120..1801123 100644 --- a/apps/chat/permissions.py +++ b/apps/chat/permissions.py @@ -1,4 +1,3 @@ -from channels.db import database_sync_to_async from rest_framework.permissions import BasePermission from apps.chat.exceptions import NoModelPermission @@ -13,18 +12,15 @@ class AIModelPermission(BasePermission): """ # pylint: disable=W0236 - async def has_permission(self, request, view): - allowed = await database_sync_to_async(AIModel.check_user_permission)( - request.user, model=str(request.data.get("model", "")) - ) + def has_permission(self, request, view): + allowed = AIModel.check_user_permission(request.user, model=str(request.data.get("model", ""))) if not allowed: raise NoModelPermission() - balance = await self.load_balance(request=request) + balance = self.load_balance(request=request) if balance > 0: return True raise NoBalanceException() - @database_sync_to_async def load_balance(self, request) -> float: try: return Wallet.objects.get(user=request.user).balance diff --git a/apps/chat/serializers.py b/apps/chat/serializers.py index 2ef244c..4c0e95d 100644 --- a/apps/chat/serializers.py +++ b/apps/chat/serializers.py @@ -1,7 +1,6 @@ import datetime import pytz -from adrf.serializers import ModelSerializer, Serializer from django.conf import settings from django.utils import timezone from django.utils.translation import gettext_lazy @@ -11,7 +10,7 @@ from apps.chat.models import ChatLog, ChatMessageChangeLog, SystemPreset -class OpenAIMessageSerializer(Serializer): +class OpenAIMessageSerializer(serializers.Serializer): """ OpenAI Message """ @@ -21,7 +20,7 @@ class OpenAIMessageSerializer(Serializer): content = serializers.CharField(label=gettext_lazy("Content")) -class OpenAIRequestSerializer(Serializer): +class OpenAIRequestSerializer(serializers.Serializer): """ OpenAI Request """ @@ -32,7 +31,7 @@ class OpenAIRequestSerializer(Serializer): ) -class CheckModelPermissionSerializer(Serializer): +class CheckModelPermissionSerializer(serializers.Serializer): """ Model Permission """ @@ -40,7 +39,7 @@ class CheckModelPermissionSerializer(Serializer): model = serializers.CharField(label=gettext_lazy("Model")) -class OpenAIChatRequestSerializer(Serializer): +class OpenAIChatRequestSerializer(serializers.Serializer): """ OpenAI Chat """ @@ -48,7 +47,7 @@ class OpenAIChatRequestSerializer(Serializer): key = serializers.CharField() -class SystemPresetSerializer(ModelSerializer): +class SystemPresetSerializer(serializers.ModelSerializer): """ System Preset """ @@ -58,18 +57,13 @@ class Meta: exclude = ["user", "created_at", "updated_at"] -class SerializerMethodField(serializers.SerializerMethodField): - async def ato_representation(self, value): - return super().to_representation(value) - - -class ChatLogSerializer(ModelSerializer): +class ChatLogSerializer(serializers.ModelSerializer): """ Chat Log """ - model_name = SerializerMethodField() - created_at = SerializerMethodField() + model_name = serializers.SerializerMethodField() + created_at = serializers.SerializerMethodField() class Meta: model = ChatLog @@ -94,7 +88,7 @@ def get_created_at(self, obj: ChatLog) -> str: return _datetime.strftime("%y/%m/%d %H:%M:%S") -class MessageChangeLogSerializer(ModelSerializer): +class MessageChangeLogSerializer(serializers.ModelSerializer): """ Message Change Log """ @@ -104,7 +98,7 @@ class Meta: fields = ["message_id", "action", "content"] -class ListMessageChangeLogSerializer(Serializer): +class ListMessageChangeLogSerializer(serializers.Serializer): """ List Message Change Log """ @@ -118,7 +112,7 @@ def validate_start_time(self, start_time: int) -> datetime.datetime: raise serializers.ValidationError(gettext_lazy("Invalid Start Time")) from err -class CreateMessageChangeLogSerializer(Serializer): +class CreateMessageChangeLogSerializer(serializers.Serializer): """ Create Message Change Log """ diff --git a/apps/chat/tasks.py b/apps/chat/tasks.py index 3225a1e..630c092 100644 --- a/apps/chat/tasks.py +++ b/apps/chat/tasks.py @@ -1,6 +1,5 @@ import datetime -from asgiref.sync import async_to_sync from django.conf import settings from django.db import transaction from django.db.models import F @@ -71,7 +70,7 @@ def async_reply(self, channel_name: str, key: str): """ celery_logger.info("[AsyncReply] Start %s %s %s", self.request.id, channel_name, key) - async_to_sync(AsyncConsumer(channel_name=channel_name, key=key).chat)() + AsyncConsumer(channel_name=channel_name, key=key).chat() celery_logger.info("[AsyncReply] End %s %s %s", self.request.id, channel_name, key) diff --git a/apps/chat/views.py b/apps/chat/views.py index e111250..9295491 100644 --- a/apps/chat/views.py +++ b/apps/chat/views.py @@ -1,7 +1,5 @@ import datetime -from typing import List -from channels.db import database_sync_to_async from django.conf import settings from django.core.cache import cache from django.db.models import OuterRef, Q, Subquery @@ -47,7 +45,7 @@ class ChatViewSet(MainViewSet): queryset = ChatLog.objects.all() @action(methods=["POST"], detail=False, permission_classes=[AIModelPermission]) - async def pre_check(self, request, *args, **kwargs): + def pre_check(self, request, *args, **kwargs): """ pre-check before chat """ @@ -58,9 +56,7 @@ async def pre_check(self, request, *args, **kwargs): request_data = ChatRequest(user=request.user.username, **request_serializer.validated_data) # check model - model: AIModel = await database_sync_to_async(get_object_or_404)( - AIModel, model=request_data.model, is_enabled=True - ) + model: AIModel = get_object_or_404(AIModel, model=request_data.model, is_enabled=True) # format message for message in request_data.messages: @@ -85,7 +81,7 @@ async def pre_check(self, request, *args, **kwargs): return Response(data={"key": cache_key}) @action(methods=["GET"], detail=False, authentication_classes=[SessionAuthenticate]) - async def logs(self, request, *args, **kwargs): + def logs(self, request, *args, **kwargs): """ chat logs """ @@ -104,33 +100,26 @@ async def logs(self, request, *args, **kwargs): ) page = NumPagination() - paged_queryset = await database_sync_to_async(page.paginate_queryset)( - queryset=queryset, request=request, view=self - ) + paged_queryset = page.paginate_queryset(queryset=queryset, request=request, view=self) - model_map = await self.load_model_map() + model_map = {model.model: model.name for model in AIModel.objects.all()} serializer = ChatLogSerializer(instance=paged_queryset, many=True, context={"model_map": model_map}) - return page.get_paginated_response(data=await serializer.adata) - - @database_sync_to_async - def load_model_map(self) -> dict: - models = AIModel.objects.all() - return {model.model: model.name for model in models} + return page.get_paginated_response(data=serializer.data) @action(methods=["POST"], detail=False, permission_classes=[AIModelPermission]) - async def json(self, request, *args, **kwargs): + def json(self, request, *args, **kwargs): """ JSON Mode """ # pre check - pre_response = await self.pre_check(request, *args, **kwargs) + pre_response = self.pre_check(request, *args, **kwargs) data = pre_response.data # chat consumer = JSONModeConsumer(data["key"]) - await consumer.chat() + consumer.chat() # response return Response(data={"data": consumer.message}) @@ -142,7 +131,7 @@ class AIModelViewSet(ListMixin, MainViewSet): Model """ - async def list(self, request, *args, **kwargs): + def list(self, request, *args, **kwargs): """ List Models """ @@ -163,15 +152,11 @@ async def list(self, request, *args, **kwargs): "is_vision": model.is_vision, }, } - for model in await self.list_models(request) + for model in AIModel.list_user_models(request.user) ] data.sort(key=lambda model: model["name"]) return Response(data=data) - @database_sync_to_async - def list_models(self, request) -> List[AIModel]: - return list(AIModel.list_user_models(request.user)) - class SystemPresetViewSet(ListMixin, MainViewSet): """ @@ -180,13 +165,13 @@ class SystemPresetViewSet(ListMixin, MainViewSet): queryset = SystemPreset.objects.all() - async def list(self, request, *args, **kwargs): + def list(self, request, *args, **kwargs): """ List System Presets """ queryset = SystemPreset.get_queryset().filter(Q(Q(is_public=True) | Q(user=request.user))).order_by("name") - return Response(await SystemPresetSerializer(instance=queryset, many=True).adata) + return Response(SystemPresetSerializer(instance=queryset, many=True).data) class ChatMessageChangeLogView(ListMixin, CreateMixin, MainViewSet): @@ -196,7 +181,7 @@ class ChatMessageChangeLogView(ListMixin, CreateMixin, MainViewSet): queryset = ChatMessageChangeLog.objects.all() - async def list(self, request: Request, *args, **kwargs) -> Response: + def list(self, request: Request, *args, **kwargs) -> Response: """ load messages """ @@ -219,13 +204,13 @@ async def list(self, request: Request, *args, **kwargs) -> Response: logs = logs.filter(created_at__gt=req_data["start_time"]) # page - queryset = await database_sync_to_async(self.paginate_queryset)(logs) + queryset = self.paginate_queryset(logs) # response resp_slz = MessageChangeLogSerializer(instance=queryset, many=True) - return self.get_paginated_response(await resp_slz.adata) + return self.get_paginated_response(resp_slz.data) - async def create(self, request: Request, *args, **kwargs) -> Response: + def create(self, request: Request, *args, **kwargs) -> Response: """ save message """ @@ -236,7 +221,7 @@ async def create(self, request: Request, *args, **kwargs) -> Response: req_data = req_slz.validated_data # save to db - await database_sync_to_async(ChatMessageChangeLog.objects.create)( + ChatMessageChangeLog.objects.create( user=request.user, message_id=req_data["message_id"], action=req_data["action"], diff --git a/apps/cos/client.py b/apps/cos/client.py index f4354c3..92327cd 100644 --- a/apps/cos/client.py +++ b/apps/cos/client.py @@ -87,7 +87,7 @@ def build_key(self, file_name: str) -> str: return key return self.build_key(file_name=file_name) - async def generate_cos_upload_credential(self, filename: str) -> COSCredential: + def generate_cos_upload_credential(self, filename: str) -> COSCredential: key = self.build_key(file_name=filename) tencent_cloud_api_domain = settings.QCLOUD_API_DOMAIN_TMPL.format("sts") config = { @@ -130,7 +130,7 @@ async def generate_cos_upload_credential(self, filename: str) -> COSCredential: logger.exception("[TempKeyGenerateFailed] %s", err) raise TempKeyGenerateFailed() from err - async def put_object(self, file: bytes | BytesIO, file_name: str) -> str: + def put_object(self, file: bytes | BytesIO, file_name: str) -> str: """ Upload File To COS """ @@ -148,7 +148,7 @@ async def put_object(self, file: bytes | BytesIO, file_name: str) -> str: logger.info("[UploadFileSuccess] %s %s", key, result) return f"{settings.QCLOUD_COS_URL}/{key}" - async def text_audit(self, user: USER_MODEL, content: str, data_id: str = None) -> None: + def text_audit(self, user: USER_MODEL, content: str, data_id: str = None) -> None: """ Text Audit """ @@ -174,7 +174,7 @@ async def text_audit(self, user: USER_MODEL, content: str, data_id: str = None) logger.warning("[TextAuditFailed] %s %s", data_id, response.model_dump_json()) raise SensitiveData(gettext("%s Sensitive") % response.JobsDetail.Label) - async def image_audit(self, user: USER_MODEL, image_url: str, data_id: str = None) -> None: + def image_audit(self, user: USER_MODEL, image_url: str, data_id: str = None) -> None: """ Image Audit """ diff --git a/apps/cos/serializers.py b/apps/cos/serializers.py index 3cba789..4fd7a54 100644 --- a/apps/cos/serializers.py +++ b/apps/cos/serializers.py @@ -1,13 +1,11 @@ -from adrf.serializers import Serializer from django.utils.translation import gettext, gettext_lazy -from ovinc_client.core.async_tools import SyncRunner from ovinc_client.tcaptcha.utils import TCaptchaVerify from rest_framework import serializers from core.exceptions import TCaptchaVerifyFailed -class GenerateTempSecretSerializer(Serializer): +class GenerateTempSecretSerializer(serializers.Serializer): """ Temp Secret """ @@ -17,9 +15,7 @@ class GenerateTempSecretSerializer(Serializer): def validate(self, attrs: dict) -> dict: data = super().validate(attrs) - if not SyncRunner().run( - TCaptchaVerify(user_ip=self.context.get("user_ip"), **data.pop("tcaptcha", {})).verify() - ): + if not TCaptchaVerify(user_ip=self.context.get("user_ip"), **data.pop("tcaptcha", {})).verify(): raise TCaptchaVerifyFailed() return data diff --git a/apps/cos/views.py b/apps/cos/views.py index 89a5641..f836f13 100644 --- a/apps/cos/views.py +++ b/apps/cos/views.py @@ -15,7 +15,7 @@ class COSViewSet(ListMixin, MainViewSet): COS """ - async def list(self, request, *args, **kwargs): + def list(self, request, *args, **kwargs): """ Load Configs """ @@ -28,7 +28,7 @@ async def list(self, request, *args, **kwargs): ) @action(methods=["POST"], detail=False) - async def temp_secret(self, request: Request, *args, **kwargs): + def temp_secret(self, request: Request, *args, **kwargs): """ Generate New Temp Secret for COS """ @@ -42,7 +42,7 @@ async def temp_secret(self, request: Request, *args, **kwargs): request_data = serializer.validated_data # generate - data = await COSClient().generate_cos_upload_credential(filename=request_data["filename"]) + data = COSClient().generate_cos_upload_credential(filename=request_data["filename"]) # response return Response(data=data.model_dump()) diff --git a/apps/home/views.py b/apps/home/views.py index 1c52821..34dec4c 100644 --- a/apps/home/views.py +++ b/apps/home/views.py @@ -19,7 +19,7 @@ class HomeView(MainViewSet): queryset = USER_MODEL.get_queryset() authentication_classes = [SessionAuthenticate] - async def list(self, request, *args, **kwargs): + def list(self, request, *args, **kwargs): msg = f"[{request.method}] Connect Success" return Response({"resp": msg, "user": request.user.username}) @@ -31,7 +31,7 @@ class I18nViewSet(MainViewSet): authentication_classes = [SessionAuthenticate] - async def create(self, request, *args, **kwargs): + def create(self, request, *args, **kwargs): """ Change Language """ @@ -49,5 +49,9 @@ async def create(self, request, *args, **kwargs): lang_code, max_age=settings.SESSION_COOKIE_AGE, domain=settings.SESSION_COOKIE_DOMAIN, + path=settings.SESSION_COOKIE_PATH, + secure=settings.SESSION_COOKIE_SECURE or None, + httponly=settings.SESSION_COOKIE_HTTPONLY or None, + samesite=settings.SESSION_COOKIE_SAMESITE, ) return response diff --git a/apps/wallet/serializers.py b/apps/wallet/serializers.py index e9c3063..e799e09 100644 --- a/apps/wallet/serializers.py +++ b/apps/wallet/serializers.py @@ -1,11 +1,10 @@ -from adrf.serializers import ModelSerializer, Serializer from django.utils.translation import gettext_lazy from rest_framework import serializers from apps.wallet.models import BillingHistory, Wallet -class WalletSerializer(ModelSerializer): +class WalletSerializer(serializers.ModelSerializer): """ Wallet Serializer class """ @@ -15,7 +14,7 @@ class Meta: fields = ["balance"] -class PreChargeSerializer(Serializer): +class PreChargeSerializer(serializers.Serializer): """ PreCharge Serializer class """ @@ -23,7 +22,7 @@ class PreChargeSerializer(Serializer): amount = serializers.IntegerField(label=gettext_lazy("Amount")) -class NotifySerializer(Serializer): +class NotifySerializer(serializers.Serializer): """ Notify Serializer class """ @@ -36,7 +35,7 @@ class NotifySerializer(Serializer): summary = serializers.CharField(label=gettext_lazy("Summary")) -class BillingHistorySerializer(ModelSerializer): +class BillingHistorySerializer(serializers.ModelSerializer): """ Billing History """ diff --git a/apps/wallet/views.py b/apps/wallet/views.py index 5b17cb7..9aad078 100644 --- a/apps/wallet/views.py +++ b/apps/wallet/views.py @@ -4,7 +4,6 @@ import json import qrcode -from channels.db import database_sync_to_async from django.conf import settings from django.http import HttpResponse from django.shortcuts import get_object_or_404 @@ -36,7 +35,7 @@ class WalletViewSet(MainViewSet): queryset = Wallet.objects.all() @action(methods=["GET"], detail=False, authentication_classes=[SessionAuthenticate]) - async def config(self, request, *args, **kwargs): + def config(self, request, *args, **kwargs): """ wallet config """ @@ -49,16 +48,16 @@ async def config(self, request, *args, **kwargs): ) @action(methods=["GET"], detail=False) - async def mine(self, request, *args, **kwargs): + def mine(self, request, *args, **kwargs): """ load user wallet """ - inst, _ = await database_sync_to_async(Wallet.objects.get_or_create)(user=request.user) + inst, _ = Wallet.objects.get_or_create(user=request.user) return Response(data={"balance": float(inst.balance)}) @action(methods=["POST"], detail=False) - async def pre_charge(self, request, *args, **kwargs): + def pre_charge(self, request, *args, **kwargs): """ build charge image """ @@ -69,14 +68,12 @@ async def pre_charge(self, request, *args, **kwargs): request_data = request_serializer.validated_data # build billing - billing: BillingHistory = await database_sync_to_async(BillingHistory.objects.create)( - user=request.user, amount=request_data["amount"] - ) + billing: BillingHistory = BillingHistory.objects.create(user=request.user, amount=request_data["amount"]) # create wxpay charge expire_time = timezone.now() + datetime.timedelta(seconds=settings.WXPAY_ORDER_TIMEOUT) formatted_expire_time = expire_time.strftime(settings.WXPAY_TIME_FORMAT) - prepay_data = await NaivePrePay().request( + prepay_data = NaivePrePay().request( data={ "appid": settings.WXPAY_APP_ID, "mchid": settings.WXPAY_MCHID, @@ -96,7 +93,7 @@ async def pre_charge(self, request, *args, **kwargs): return Response(data=base64.b64encode(buffered.getvalue()).decode("utf-8")) @action(methods=["POST"], detail=False, authentication_classes=[SessionAuthenticate]) - async def wxpay_notify(self, request, *args, **kwargs): + def wxpay_notify(self, request, *args, **kwargs): """ wxpay callback """ @@ -109,7 +106,7 @@ async def wxpay_notify(self, request, *args, **kwargs): request_data = request_serializer.validated_data # verify header - await WXPaySignatureTool.verify(headers=request.headers, content=raw_content) + WXPaySignatureTool.verify(headers=request.headers, content=raw_content) # decrypt data decrypt_data: bytes = WXPaySignatureTool.decrypt( @@ -120,21 +117,17 @@ async def wxpay_notify(self, request, *args, **kwargs): data = json.loads(decrypt_data.decode()) # load billing - billing: BillingHistory = await database_sync_to_async(get_object_or_404)( - BillingHistory, id=data.get("out_trade_no") - ) + billing: BillingHistory = get_object_or_404(BillingHistory, id=data.get("out_trade_no")) billing.callback_data = data billing.is_success = data["trade_state"] == TradeStatus.SUCCESS billing.callback_at = timezone.now() billing.state = data["trade_state"] - await database_sync_to_async(billing.save_to_wallet)( - update_fields=["is_success", "callback_at", "state", "callback_data"] - ) + billing.save_to_wallet(update_fields=["is_success", "callback_at", "state", "callback_data"]) return HttpResponse(status=status.HTTP_200_OK) @action(methods=["GET"], detail=False) - async def billing_history(self, request, *args, **kwargs): + def billing_history(self, request, *args, **kwargs): """ Billing History """ @@ -144,9 +137,7 @@ async def billing_history(self, request, *args, **kwargs): # page paginator = NumPagination() - page_queryset = await database_sync_to_async(paginator.paginate_queryset)( - queryset=queryset, request=request, view=self - ) + page_queryset = paginator.paginate_queryset(queryset=queryset, request=request, view=self) serializer = BillingHistorySerializer(instance=page_queryset, many=True) - return paginator.get_paginated_response(data=await serializer.adata) + return paginator.get_paginated_response(data=serializer.data) diff --git a/bin/run.sh b/bin/run.sh index f9758e4..0080b8d 100755 --- a/bin/run.sh +++ b/bin/run.sh @@ -2,6 +2,6 @@ python manage.py collectstatic --noinput python manage.py migrate --noinput -nohup python manage.py celery worker -c 4 -l INFO >/dev/stdout 2>&1 & +nohup python manage.py celery worker -c ${WORKER_COUNT:-1} -l INFO >/dev/stdout 2>&1 & nohup python manage.py celery beat -l INFO >/dev/stdout 2>&1 & -gunicorn --bind "[::]:8020" -w $WEB_PROCESSES --threads $WEB_THREADS -k uvicorn_worker.UvicornWorker --proxy-protocol --proxy-allow-from "*" --forwarded-allow-ips "*" entry.asgi:application +gunicorn --bind "[::]:8020" -w ${WEB_PROCESSES:-1} --threads ${WEB_THREADS:-10} -k uvicorn_worker.UvicornWorker --proxy-protocol --proxy-allow-from "*" --forwarded-allow-ips "*" entry.asgi:application diff --git a/requirements.txt b/requirements.txt index cef9f11..722b96b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ # ovinc -ovinc-client==0.3.16 +ovinc-client==0.4.0b5 # Celery celery==5.4.0 -# Async +# ASGI asgiref==3.8.1 channels[daphne]==4.2.0 channels_redis==4.2.1 diff --git a/utils/consumers.py b/utils/consumers.py index 3f94652..f84930e 100644 --- a/utils/consumers.py +++ b/utils/consumers.py @@ -1,4 +1,4 @@ -from channels.generic.websocket import AsyncWebsocketConsumer as _AsyncWebsocketConsumer +from channels.generic.websocket import WebsocketConsumer as _WebsocketConsumer from django.conf import settings from django.core.cache import cache @@ -6,12 +6,12 @@ from utils.connections import connections_handler -class AsyncWebsocketConsumer(_AsyncWebsocketConsumer): - async def connect(self): - await super().connect() +class WebsocketConsumer(_WebsocketConsumer): + def connect(self): + super().connect() connections_handler.add_connection(self.channel_name, self.scope["client"][0]) - async def disconnect(self, code): + def disconnect(self, code): cache.set(key=WS_CLOSED_KEY.format(self.channel_name), value=True, timeout=settings.CHANNEL_CLOSE_KEY_TIMEOUT) - await super().disconnect(code) + super().disconnect(code) connections_handler.remove_connection(self.channel_name, self.scope["client"][0]) diff --git a/utils/wxpay/api.py b/utils/wxpay/api.py index ca064e7..b72949b 100644 --- a/utils/wxpay/api.py +++ b/utils/wxpay/api.py @@ -31,14 +31,14 @@ def request_path(self) -> str: def url_keys(self) -> List[str]: return [] - async def request(self, url_params: dict = None, data: dict = None) -> dict: + def request(self, url_params: dict = None, data: dict = None) -> dict: # build params - url = await self.build_url(url_params=url_params) - headers = await self.build_headers(url=url, data=data) + url = self.build_url(url_params=url_params) + headers = self.build_headers(url=url, data=data) # call api - client = httpx.AsyncClient(http2=True, headers=headers) + client = httpx.Client(http2=True, headers=headers) try: - response = await client.request(method=self.request_method, url=url, json=data) + response = client.request(method=self.request_method, url=url, json=data) logger.info( "[WxPayAPIResult] Method: %s; Path: %s; Status: %s", self.request_method, @@ -51,7 +51,7 @@ async def request(self, url_params: dict = None, data: dict = None) -> dict: ) raise WxPayAPIException() from err finally: - await client.aclose() + client.close() # parse response if response.status_code >= status.HTTP_400_BAD_REQUEST: logger.exception( @@ -66,19 +66,17 @@ async def request(self, url_params: dict = None, data: dict = None) -> dict: # verify if not self.verify_response: return response.json() - await WXPaySignatureTool.verify(headers=response.headers, content=response.content) + WXPaySignatureTool.verify(headers=response.headers, content=response.content) return response.json() - async def build_url(self, url_params: dict) -> str: + def build_url(self, url_params: dict) -> str: url = f"{settings.WXPAY_API_BASE_URL}{self.request_path}" if self.url_keys: url = url.format(**{key: url_params[key] for key in self.url_keys}) return url - async def build_headers(self, url: str, data: dict = None) -> Dict[str, str]: - signature = await WXPaySignatureTool.generate( - request_method=self.request_method, request_url=url, request_body=data - ) + def build_headers(self, url: str, data: dict = None) -> Dict[str, str]: + signature = WXPaySignatureTool.generate(request_method=self.request_method, request_url=url, request_body=data) return {"Authorization": signature} diff --git a/utils/wxpay/utils.py b/utils/wxpay/utils.py index 79d2f6a..6b7a00a 100644 --- a/utils/wxpay/utils.py +++ b/utils/wxpay/utils.py @@ -40,7 +40,7 @@ class WXPaySignatureTool: """ @classmethod - async def generate(cls, request_method: str, request_url: str, request_body: dict) -> str: + def generate(cls, request_method: str, request_url: str, request_body: dict) -> str: """ Generate signature for WXPay API """ @@ -78,7 +78,7 @@ async def generate(cls, request_method: str, request_url: str, request_body: dic ) @classmethod - async def verify(cls, headers: dict, content: bytes) -> None: + def verify(cls, headers: dict, content: bytes) -> None: """ Verify Request is from WXPay """ @@ -89,7 +89,7 @@ async def verify(cls, headers: dict, content: bytes) -> None: content=content.decode(), ) signature = base64.b64decode(headers.get("wechatpay-signature", "").encode()) - wxpay_cert = await cls.load_wxpay_cert(serial_no=headers.get("wechatpay-serial", "")) + wxpay_cert = cls.load_wxpay_cert(serial_no=headers.get("wechatpay-serial", "")) try: wxpay_cert.public_key.verify( signature=signature, data=raw_info.encode(), padding=PKCS1v15(), algorithm=SHA256() @@ -98,7 +98,7 @@ async def verify(cls, headers: dict, content: bytes) -> None: raise WxPayInsecureResponse() from err @classmethod - async def load_wxpay_cert(cls, serial_no: str) -> WXPayCert: + def load_wxpay_cert(cls, serial_no: str) -> WXPayCert: """ Load WXPay Cert """ @@ -113,7 +113,7 @@ async def load_wxpay_cert(cls, serial_no: str) -> WXPayCert: from utils.wxpay.api import GetCerts # load cert - cert_data: dict = await GetCerts().request() + cert_data: dict = GetCerts().request() for cert in cert_data["data"]: # check serial id if not cert["serial_no"] == serial_no: