Skip to content

Commit

Permalink
feat: support to use midjourney
Browse files Browse the repository at this point in the history
  • Loading branch information
OrenZhang committed Jul 30, 2024
1 parent 05d44bb commit cf15856
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 4 deletions.
2 changes: 2 additions & 0 deletions apps/chat/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from apps.chat.client.gemini import GeminiClient
from apps.chat.client.hunyuan import HunYuanClient, HunYuanVisionClient
from apps.chat.client.kimi import KimiClient
from apps.chat.client.midjourney import MidjourneyClient
from apps.chat.client.openai import OpenAIClient, OpenAIVisionClient
from apps.chat.client.qianfan import QianfanClient

Expand All @@ -16,4 +17,5 @@
"BaiLianClient",
"KimiClient",
"DoubaoClient",
"MidjourneyClient",
)
75 changes: 75 additions & 0 deletions apps/chat/client/midjourney.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import asyncio
import time
import uuid

from channels.db import database_sync_to_async
from django.conf import settings
from django.utils import timezone
from httpx import AsyncClient
from ovinc_client.core.logger import logger
from rest_framework import status

from apps.chat.client.base import BaseClient
from apps.chat.constants import MidjourneyResult
from apps.chat.exceptions import GenerateFailed, LoadImageFailed
from apps.cos.client import COSClient


class MidjourneyClient(BaseClient):
"""
Midjourney Client
"""

async def chat(self, *args, **kwargs) -> any:
client = AsyncClient(
http2=True,
headers={"Authorization": f"Bearer {settings.MIDJOURNEY_API_KEY}"},
base_url=settings.MIDJOURNEY_API_BASE_URL,
proxy=settings.OPENAI_HTTP_PROXY_URL or None,
)
# call midjourney api
try:
# submit job
response = await client.post(
url=settings.MIDJOURNEY_IMAGINE_API_PATH, json={"prompt": self.messages[-1]["content"]}
)
result_id = response.json()["result"]
# wait for result
start_time = time.time()
while time.time() - start_time < settings.MIDJOURNEY_IMAGE_JOB_TIMEOUT:
result = await client.get(url=settings.MIDJOURNEY_TASK_RESULT_API_PATH.format(id=result_id))
result_data = result.json()
# if not finished, continue loop
if result_data["status"] not in [MidjourneyResult.FAILURE, MidjourneyResult.SUCCESS]:
yield ""
await asyncio.sleep(settings.MIDJOURNEY_IMAGE_JOB_INTERVAL)
continue
# if failed
if result_data["status"] == MidjourneyResult.FAILURE:
yield str(result_data.get("failReason") or GenerateFailed())
break
# record
await self.record()
# use first success picture
message_url = result_data["imageUrl"]
image_resp = await client.get(message_url)
if image_resp.status_code != status.HTTP_200_OK:
raise LoadImageFailed()
url = await COSClient().put_object(
file=image_resp.content,
file_name=f"{uuid.uuid4().hex}.{image_resp.headers['content-type'].split('/')[-1]}",
)
yield f"![output]({url}?{settings.QCLOUD_COS_IMAGE_STYLE})"
break
except Exception as err: # pylint: disable=W0718
logger.exception("[GenerateContentFailed] %s", err)
yield str(GenerateFailed())
finally:
await client.aclose()

# pylint: disable=W0221,R1710,W0236
async def record(self) -> None:
self.log.completion_tokens = 1
self.log.completion_token_unit_price = self.model_inst.completion_price
self.log.finished_at = int(timezone.now().timestamp() * 1000)
await database_sync_to_async(self.log.save)()
14 changes: 14 additions & 0 deletions apps/chat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class AIModelProvider(TextChoices):
ALIYUN = "aliyun", gettext_lazy("Aliyun")
MOONSHOT = "moonshot", gettext_lazy("Moonshot")
DOUBAO = "doubao", gettext_lazy("Doubao")
MIDJOURNEY = "midjourney", gettext_lazy("Midjourney")


class VisionSize(TextChoices):
Expand Down Expand Up @@ -130,3 +131,16 @@ class ToolType(TextChoices):
"""

FUNCTION = "function", gettext_lazy("Function")


class MidjourneyResult(TextChoices):
"""
Midjourney Result
"""

NOT_START = "NOT_START", gettext_lazy("Not Start")
SUBMITTED = "SUBMITTED", gettext_lazy("Submitted")
MODAL = "MODAL", gettext_lazy("Modal")
IN_PROGRESS = "IN_PROGRESS", gettext_lazy("In Progress")
FAILURE = "FAILURE", gettext_lazy("Failure")
SUCCESS = "SUCCESS", gettext_lazy("Success")
3 changes: 3 additions & 0 deletions apps/chat/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
HunYuanClient,
HunYuanVisionClient,
KimiClient,
MidjourneyClient,
OpenAIClient,
OpenAIVisionClient,
QianfanClient,
Expand Down Expand Up @@ -109,5 +110,7 @@ def get_model_client(self, model: AIModel) -> Type[BaseClient]:
return KimiClient
case AIModelProvider.DOUBAO:
return DoubaoClient
case AIModelProvider.MIDJOURNEY:
return MidjourneyClient
case _:
raise UnexpectedProvider()
8 changes: 8 additions & 0 deletions entry/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,11 @@
# Doubao
DOUBAO_API_BASE_URL = os.getenv("DOUBAO_API_BASE_URL", "")
DOUBAO_API_KEY = os.getenv("DOUBAO_API_KEY", "")

# Midjourney
MIDJOURNEY_API_BASE_URL = os.getenv("MIDJOURNEY_API_BASE_URL", "")
MIDJOURNEY_API_KEY = os.getenv("MIDJOURNEY_API_KEY", "")
MIDJOURNEY_IMAGINE_API_PATH = os.getenv("MIDJOURNEY_IMAGINE_API_PATH", "/fast/mj/submit/imagine")
MIDJOURNEY_TASK_RESULT_API_PATH = os.getenv("MIDJOURNEY_TASK_RESULT_API_PATH", "/mj/task/{id}/fetch")
MIDJOURNEY_IMAGE_JOB_INTERVAL = int(os.getenv("MIDJOURNEY_IMAGE_JOB_INTERVAL", "5"))
MIDJOURNEY_IMAGE_JOB_TIMEOUT = int(os.getenv("MIDJOURNEY_IMAGE_JOB_TIMEOUT", "600"))
Binary file modified locale/zh_Hans/LC_MESSAGES/django.mo
Binary file not shown.
29 changes: 25 additions & 4 deletions locale/zh_Hans/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-07-02 15:40+0800\n"
"POT-Creation-Date: 2024-07-30 16:39+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <[email protected]>\n"
Expand Down Expand Up @@ -73,6 +73,12 @@ msgstr "Aliyun"
msgid "Moonshot"
msgstr "Moonshot"

msgid "Doubao"
msgstr "Doubao"

msgid "Midjourney"
msgstr "Midjourney"

msgid "1024x1024"
msgstr "1024x1024"

Expand Down Expand Up @@ -115,6 +121,24 @@ msgstr "完成"
msgid "Function"
msgstr "函数"

msgid "Not Start"
msgstr "未开始"

msgid "Submitted"
msgstr "已提交"

msgid "Modal"
msgstr "模态"

msgid "In Progress"
msgstr "进程中"

msgid "Failure"
msgstr "失败"

msgid "Success"
msgstr "成功"

msgid "Unauthorized Model"
msgstr "未授权的模型"

Expand Down Expand Up @@ -364,9 +388,6 @@ msgstr "钱包充值"
msgid "TCaptcha Verify Failed"
msgstr "安全检查失败"

msgid "Success"
msgstr "成功"

msgid "Refund"
msgstr "退款"

Expand Down

0 comments on commit cf15856

Please sign in to comment.