Merge pull request #116 from MxEmerson/master
feat: 牛牛夺舍plus ("Niuniu identity takeover, plus")
MistEO authored Sep 8, 2024
2 parents a1e31df + 45c90f4 commit 21e4292
Showing 7 changed files with 134 additions and 121 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -26,7 +26,7 @@ jobs:
        id: metadata
        with:
          images: |
            misteo/pallas-bot
            ${{ github.repository_owner }}/pallas-bot
          tags: |
            type=raw,value=latest
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ nonebot-plugin-gocqhttp = "^0.6.4"
nb-cli = "^1.2.0"

[tool.nonebot]
plugins = ["nonebot_plugin_apscheduler", "nonebot_plugin_gocqhttp"]
plugins = ["nonebot_plugin_apscheduler"]
plugin_dirs = ["src/plugins"]

[build-system]
4 changes: 2 additions & 2 deletions src/common/utils/array2cqcode/__init__.py
@@ -1,10 +1,10 @@
import json

from .message_segment import BaseMessageSegment
from typing import Any
from typing import Union, Any


def try_convert_to_cqcode(data: Any) -> str | Any:
def try_convert_to_cqcode(data: Any) -> Union[str, Any]:
    try:
        msg = json.loads(data)
        if not isinstance(msg, list):
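Note on the typing change above (my reading; the commit message does not say): the PEP 604 form str | Any is evaluated when the function is defined and raises a TypeError on interpreters older than Python 3.10, while typing.Union works everywhere. A minimal sketch of the compatible form, with a placeholder body:

# Sketch only: the real try_convert_to_cqcode body parses a CQ-code message array.
# On Python < 3.10 the annotation `str | Any` fails at definition time,
# so typing.Union keeps the module importable there.
from typing import Any, Union

def try_convert_to_cqcode_sketch(data: Any) -> Union[str, Any]:
    return data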
24 changes: 17 additions & 7 deletions src/plugins/chat/__init__.py
@@ -4,26 +4,37 @@
from nonebot.rule import Rule
from nonebot.typing import T_State
from nonebot import on_message, get_driver, logger
import random

from .model import chat, del_session
from src.common.config import BotConfig, GroupConfig
from src.common.config import BotConfig, GroupConfig, plugin_config

try:
    from src.common.utils.speech.text_to_speech import text_2_speech
    TTS_AVAIABLE = True
except Exception as error:
    print('TTS not available, error:', error)
    logger.error(f'TTS not available, error: {error}')
    TTS_AVAIABLE = False

try:
    from .model import Chat
except Exception as error:
    logger.error(f'Chat model import error: {error}')
    raise error

TTS_MIN_LENGTH = 10

try:
    chat = Chat(plugin_config.chat_strategy)
except Exception as error:
    logger.error(f'Chat model init error: {error}')
    raise error


@BotConfig.handle_sober_up
def on_sober_up(bot_id, group_id, drunkenness) -> bool:
    session = f'{bot_id}_{group_id}'
    logger.info(
        f'bot [{bot_id}] sober up in group [{group_id}], clear session [{session}]')
    del_session(session)
    chat.del_session(session)


def is_drunk(bot: Bot, event: Event, state: T_State) -> bool:
@@ -60,8 +71,7 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
    text = text[:50]
    if not text:
        return

    ans = await asyncify(chat)(session, text)
    ans = await asyncify(chat.chat)(session, text)
    logger.info(f'session [{session}]: {text} -> {ans}')

    if TTS_AVAIABLE and len(ans) >= TTS_MIN_LENGTH:
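The handler above now awaits asyncify(chat.chat)(session, text). The asyncify helper itself is defined elsewhere in the repository and is not part of this diff; a hypothetical sketch of what such a wrapper does, assuming it simply offloads the blocking RWKV call to a worker thread so the NoneBot event loop stays responsive:

# Hypothetical sketch, not the repository's actual implementation.
import asyncio
import functools
from typing import Any, Callable

def asyncify(func: Callable[..., Any]) -> Callable[..., Any]:
    # Wrap a blocking function so it can be awaited from an async handler.
    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
    return wrapper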
138 changes: 62 additions & 76 deletions src/plugins/chat/model.py
@@ -10,88 +10,74 @@
# This needs a ninja build environment or the like; it can greatly speed up inference. Set it up yourself if needed (CUDA GPUs only)
os.environ["RWKV_CUDA_ON"] = '0'


from rwkv.model import RWKV # pip install rwkv
from .pipeline import PIPELINE, PIPELINE_ARGS
from rwkv.model import RWKV
from .prompt import INIT_PROMPT, CHAT_FORMAT
from src.common.config import plugin_config

# You can tweak this following the original repo's instructions, e.g. to save some VRAM
STRATEGY = 'cuda fp16' if cuda else 'cpu fp32'
if plugin_config.chat_strategy:
    STRATEGY = plugin_config.chat_strategy

MODEL_DIR = Path('resource/chat/models')
MODEL_EXT = '.pth'
MODEL_PATH = None
for f in MODEL_DIR.glob('*'):
    if f.suffix != MODEL_EXT:
        continue
    MODEL_PATH = f.with_suffix('')
    break

print('Chat model:', MODEL_PATH)

if not MODEL_PATH:
    print(f'!!!!!!Chat model not found, please put it in {MODEL_DIR}!!!!!!')
    print(f'!!!!!!Chat 模型不存在,请放到 {MODEL_DIR} 文件夹下!!!!!!')
    raise Exception('Chat model not found')

TOKEN_PATH = MODEL_DIR / '20B_tokenizer.json'

if not TOKEN_PATH.exists():
    print(
        f'AI Chat updated, please put token file to {TOKEN_PATH}, download: https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json')
    print(
        f'牛牛的 AI Chat 版本更新了,把 token 文件放到 {TOKEN_PATH} 里再启动, 下载地址:https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json')
    raise Exception('Chat token not found')

model = RWKV(model=str(MODEL_PATH), strategy=STRATEGY)
pipeline = PIPELINE(model, str(TOKEN_PATH))
args = PIPELINE_ARGS(
    temperature=1.0,
    top_p=0.7,
    alpha_frequency=0.25,
    alpha_presence=0.25,
    token_ban=[0],  # ban the generation of some tokens
    token_stop=[],  # stop generation whenever you see any token here
    ends=('\n'),
    ends_if_too_long=("。", "!", "?", "\n"))


INIT_STATE = deepcopy(pipeline.generate(
    INIT_PROMPT, token_count=200, args=args)[1])
all_state = defaultdict(lambda: deepcopy(INIT_STATE))
all_occurrence = {}

chat_locker = Lock()


def chat(session: str, text: str, token_count: int = 50) -> str:
    with chat_locker:
        state = all_state[session]
        ctx = CHAT_FORMAT.format(text)
        occurrence = all_occurrence.get(session, {})

        out, state, occurrence = pipeline.generate(
            ctx, token_count=token_count, args=args, state=state, occurrence=occurrence)

        all_state[session] = deepcopy(state)
        all_occurrence[session] = occurrence
        return out.strip()

from .pipeline import PIPELINE, PIPELINE_ARGS

def del_session(session: str):
    with chat_locker:
        if session in all_state:
            del all_state[session]
        if session in all_occurrence:
            del all_occurrence[session]
DEFAULT_STRATEGY = 'cuda fp16' if cuda else 'cpu fp32'
DEFAULT_MODEL_DIR = Path('resource/chat/models')


class Chat:
    def __init__(self, strategy=DEFAULT_STRATEGY, model_dir=DEFAULT_MODEL_DIR) -> None:
        self.STRATEGY = strategy if strategy else DEFAULT_STRATEGY
        self.MODEL_DIR = model_dir
        self.MODEL_EXT = '.pth'
        self.MODEL_PATH = None
        self.TOKEN_PATH = self.MODEL_DIR / '20B_tokenizer.json'
        for f in self.MODEL_DIR.glob('*'):
            if f.suffix != self.MODEL_EXT:
                continue
            self.MODEL_PATH = f.with_suffix('')
            break
        if not self.MODEL_PATH:
            raise Exception(f'Chat model not found in {self.MODEL_DIR}')
        if not self.TOKEN_PATH.exists():
            raise Exception(f'Chat token not found in {self.TOKEN_PATH}')
        model = RWKV(model=str(self.MODEL_PATH), strategy=self.STRATEGY)
        self.pipeline = PIPELINE(model, str(self.TOKEN_PATH))
        self.args = PIPELINE_ARGS(
            temperature=1.0,
            top_p=0.7,
            alpha_frequency=0.25,
            alpha_presence=0.25,
            token_ban=[0],  # ban the generation of some tokens
            token_stop=[],  # stop generation whenever you see any token here
            ends=('\n'),
            ends_if_too_long=("。", "!", "?", "\n"))

        INIT_STATE = deepcopy(self.pipeline.generate(
            INIT_PROMPT, token_count=200, args=self.args)[1])
        self.all_state = defaultdict(lambda: deepcopy(INIT_STATE))
        self.all_occurrence = {}

        self.chat_locker = Lock()

    def chat(self, session: str, text: str, token_count: int = 50) -> str:
        with self.chat_locker:
            state = self.all_state[session]
            ctx = CHAT_FORMAT.format(text)
            occurrence = self.all_occurrence.get(session, {})

            out, state, occurrence = self.pipeline.generate(
                ctx, token_count=token_count, args=self.args, state=state, occurrence=occurrence)

            self.all_state[session] = deepcopy(state)
            self.all_occurrence[session] = occurrence
            return out.strip()

    def del_session(self, session: str):
        with self.chat_locker:
            if session in self.all_state:
                del self.all_state[session]
            if session in self.all_occurrence:
                del self.all_occurrence[session]


if __name__ == "__main__":
    chat = Chat('cpu fp32')
    while True:
        session = "main"
        text = input('text:')
        result = chat(session, text)
        result = chat.chat(session, text)
        print(result)
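For context: the Chat instance in chat/__init__.py is constructed with plugin_config.chat_strategy, and an empty value falls back to DEFAULT_STRATEGY above. The real plugin_config lives in src/common/config and is not shown in this diff; a hypothetical sketch of the shape it would need, where only the field name chat_strategy is taken from the diff:

# Hypothetical sketch of the config side; the pydantic model and default are assumptions.
from pydantic import BaseModel

class PluginConfig(BaseModel):
    # e.g. 'cuda fp16' or 'cpu fp32'; an empty value means "use DEFAULT_STRATEGY"
    chat_strategy: str = ''

plugin_config = PluginConfig()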
33 changes: 1 addition & 32 deletions src/plugins/greeting/wiki.py
@@ -1,6 +1,5 @@
import os
import random
from src.common.utils.download_tools import DownloadTools

# The 'CN' here does not mean the voice line is Chinese; the wiki's naming is a bit odd and every language is labelled CN_xx
# The actual URL looks like 'https://static.prts.wiki/voice_jp/char_485_pallas/CN_01.wav'
@@ -30,29 +29,7 @@
voices_source = 'resource/voices'


class WikiVoice(DownloadTools):
    def download_voice_from_wiki(self, operator, url, filename):
        folder = f'{voices_source}/{operator}'
        f = f'{folder}/{filename}'
        if os.path.exists(f):
            return

        print('Downloading', url, "as", filename, "to", folder)
        content = self.request_file(url)
        if content:
            os.makedirs(folder, exist_ok=True)
            with open(f, mode='wb+') as voice:
                voice.write(content)
        else:
            print("Download failed!")

    def download_voices(self, folder, oper_id):
        base_url = f'https://static.prts.wiki/voice/{oper_id}/'
        for key, web_file in voice_dict.items():
            url = f'{base_url}{web_file}.wav'
            filename = f'{key}.wav'
            self.download_voice_from_wiki(folder, url, filename)

class WikiVoice():
    def get_voice_filename(self, operator, key):
        if key not in voice_dict:
            return None
@@ -65,11 +42,3 @@ def get_voice_filename(self, operator, key):
    def get_random_voice(self, operator, ranges):
        key = random.choice([r for r in ranges if r in voice_dict])
        return self.get_voice_filename(operator, key)


if __name__ == '__main__':
    operator = 'Pallas'
    wiki = WikiVoice()
    wiki.download_voices('Pallas', 'char_485_pallas')

    print(wiki.get_random_voice(operator))
52 changes: 50 additions & 2 deletions src/plugins/take_name/__init__.py
@@ -1,7 +1,10 @@
import random
from nonebot import require, logger, get_bot
from nonebot import require, logger, get_bot, on_notice
from nonebot.rule import Rule
from nonebot.typing import T_State
from nonebot.adapters import Bot, Event
from nonebot.exception import ActionFailed
from nonebot.adapters.onebot.v11 import Message
from nonebot.adapters.onebot.v11 import Message, NoticeEvent

from src.plugins.repeater.model import Chat
from src.common.config import BotConfig
@@ -77,3 +80,48 @@ async def change_name():
        except ActionFailed:
            # Niuniu may have left the group
            continue


async def is_change_name_notice(bot: Bot, event: NoticeEvent, state: T_State) -> bool:
    config = BotConfig(event.self_id, event.group_id)
    if event.notice_type == 'group_card' and event.user_id == config.taken_name():
        return True
    return False


watch_name = on_notice(rule=Rule(is_change_name_notice), priority=4)


@watch_name.handle()
async def watch_name_handle(bot: Bot, event: NoticeEvent, state: T_State):
    group_id = event.group_id
    user_id = event.user_id

    try:
        info = await bot.call_api('get_group_member_info', **{
            'group_id': group_id,
            'user_id': user_id,
            'no_cache': True
        })
    except ActionFailed:
        return

    card = info['card'] if info['card'] else info['nickname']

    logger.info(
        'bot [{}] watch name change by [{}] in group [{}]'.format(
            bot.self_id, user_id, group_id))

    config = BotConfig(bot.self_id, group_id)

    try:
        await bot.call_api('set_group_card', **{
            'group_id': group_id,
            'user_id': user_id,
            'card': card
        })

        config.update_taken_name(user_id)

    except ActionFailed:
        return
