# inference_ms.py
import argparse
import json
import time
from typing import Dict, List

import mindspore as ms
import requests
from mindspore import context
from mindnlp.peft import PeftModel
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
context.set_context(device_target="GPU")
# context.set_context(device_id=1)
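# For each supported novel, map a character's full name to the epithet used
# when prefixing generated dialogue ("<epithet>道:").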
ROLE_DICT = {
"西游记": {
"孙悟空": "悟空",
"唐僧": "唐僧",
"猪八戒": "八戒",
"沙僧": "沙僧",
},
"三国演义": {
"刘备": "玄德",
"关羽": "云长",
"张飞": "翼德",
"曹操": "曹操",
"诸葛亮": "孔明",
},
"水浒传": {
"宋江": "宋公明",
"卢俊义": "玉麒麟",
"吴用": "智多星",
"林冲": "豹子头",
},
"红楼梦": {
"贾宝玉": "宝玉",
"林黛玉": "黛玉",
"薛宝钗": "宝钗",
"王熙凤": "凤姐",
},
}
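# ChatML-style prompt templates used by Qwen2: each turn is wrapped in
# <|im_start|>{role}\n{content}<|im_end|> markers.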
prompt_system = "<|im_start|>system\n{}<|im_end|>\n"
prompt_user = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
prompt_assistant = "{}<|im_end|>\n"
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
type=str,
default="./model_zoo/Qwen/Qwen2-7B-Instruct", # Qwen/Qwen2-7B-Instruct
)
parser.add_argument("--inf_max_length", type=int, default=128)
parser.add_argument(
"--adapter_path",
type=str,
default="./ChatStyle/results/20241204LF-qwen2-7b-instruct",
)
parser.add_argument(
"--isTerminal",
action="store_true",
help="Whether to use terminal for inference",
)
parser.add_argument(
"--isWebsocket",
action="store_true",
help="Whether to use websocket for inference",
)
return parser.parse_args()
# format RAG retrieval results
def format_docs(docs, wiki_docs=None):
ans = "从古籍中检索到的信息如下:\n\n"
    for idx, doc in enumerate(docs):
        ans += f"{idx+1}. {doc.page_content}\n\n"
if wiki_docs is not None:
ans += "从维基百科中检索到的信息如下:\n\n"
ans += f'{len(docs)+1}. {wiki_docs[0].metadata["summary"]}\n\n'
# print(f'检索到的信息有:{ans}')
return ans
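# Retrieve passages relevant to the query from the local EchoRAG vector store
# and return them as a formatted context block. Note that the embedding model
# and retriever are rebuilt on every call.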
def get_RAG_prompt(book: str = "西游记", role: str = "孙悟空", query: str = ""):
    if query is None or len(query) == 0:
        return None
# TODO: update when env changed
base_rag_path = "./RAG/EchoRAG"
# TODO: change config info if necessary, such as embedding model PATH, retriever PATH, etc.
import sys
sys.path.append(base_rag_path)
from config.ConfigLoader import ConfigLoader
config = ConfigLoader(base_rag_path + "/config/config.yaml")
# import logging
# log_level = config.get("global.log_level", "INFO")
# logging.basicConfig(level = log_level,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__)
from retrievers.RetrieverCreator import RetrieverCreator
from embeddings.TorchBgeEmbeddings import EmbeddingModelCreator
# logger.info("Creating embedding model...")
# TODO: add embedding_path
embedding_path = ""
embedding_creator = EmbeddingModelCreator(config, embedding_path)
embedding_model = embedding_creator.create_embedding_model()
# logger.info("Creating retriever...")
vecDB_path = base_rag_path + "/" + config.get("vector_db.index_path")
retriever = RetrieverCreator(
config, embedding_model, vecDB_path, collection_name="four_famous"
).create_retriever()
template_retrieved = "在{book}中, {query}"
retrieved_dict = {"book": book, "query": query}
retrieved_docs = retriever.invoke(template_retrieved.format(**retrieved_dict))
retrieved_info = format_docs(retrieved_docs, None)
# template = """假如你是<{book}>中的{role},请与我对话。下面是已知信息: \n
# {retrieved_info}\n
# 请你根据这些信息回答这个问题:{query}。\n
# {spec_role}道:“"""
# input_dict = {"book": book, "role": role, "retrieved_info": retrieved_info, "query": query, "spec_role": ROLE_DICT[book][role]}
# input = template.format(**input_dict)
return retrieved_info
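# Assemble the full ChatML prompt from the message history, injecting the
# role-play instruction and any retrieved context into each user turn, and
# ending with the character's dialogue cue ("<epithet>道:").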
def get_prompt(
msgs: List[Dict], book: str = "西游记", role: str = "孙悟空", has_RAG=False
):
text = ""
for i in range(len(msgs)):
if msgs[i]["role"] == "system":
text += prompt_system.format(msgs[i]["content"])
elif msgs[i]["role"] == "user":
retrieved_info = (
get_RAG_prompt(book, role, msgs[i]["content"]) if has_RAG else ""
)
print(f"retrieved_info: {retrieved_info}")
if i == 1:
user_input = """假如你是<{book}>中的{role},请与我对话。下面是已知信息: \n
{retrieved_info}\n
请你根据这些信息回答这个问题: {query}""".format(
book=book,
role=role,
retrieved_info=retrieved_info,
query=msgs[i]["content"],
)
# user_input = """假如你是<{book}>中的{role},请与我对话。下面是已知信息: \n
# {retrieved_info}\n
# 请你根据这些信息回答这个问题: {query}""".format(
# book=book,
# role=role,
# retrieved_info=retrieved_info,
# query=msgs[i]["content"],
# )
            else:
                user_input = """下面是已知信息: \n
{retrieved_info}\n
请你根据这些信息回答这个问题: {query}""".format(
                    retrieved_info=retrieved_info,
                    query=msgs[i]["content"],
                )
text += prompt_user.format(user_input)
else:
text += prompt_assistant.format(msgs[i]["content"])
text += f"{ROLE_DICT[book][role]}道:"
return text
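# Send generated text to the local TTS HTTP service and save the returned
# audio as a WAV file.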
def processTTS(
character: str = "sunwukong",
# book: str = "西游记",
text: str = "我是孙悟空",
):
    # Server address and port; adjust to match the actual deployment
    server_url = "http://localhost:5000/tts"
    prompt_language = "中文"  # language of the reference text
    text_language = "中文"  # language of the target text
    how_to_cut = "不切"  # text segmentation strategy ("no cut")
    top_k = 20  # top-k sampling parameter
    top_p = 0.6  # top-p sampling parameter
    temperature = 0.6  # sampling temperature
    ref_free = False  # run in reference-free mode (no reference audio)
    # Prepare the payload for the POST request
data = {
"character": character,
# 'prompt_text': prompt_text,
"prompt_language": prompt_language,
"text": text,
"text_language": text_language,
"how_to_cut": how_to_cut,
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"ref_free": ref_free,
}
    # Send the POST request
response = requests.post(
server_url,
# files=files,
data=data,
)
    # Handle the returned audio
    if response.status_code == 200:
        # Save the returned audio to a file
with open("output_audio.wav", "wb") as f:
f.write(response.content)
print("Audio saved as output_audio.wav")
else:
print(f"Error: {response.status_code}, {response.text}")
def run(args):
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
ms_dtype=ms.float32, # device_map="auto"
)
    # Load the trained LoRA adapter weights on top of the base model.
    # (get_peft_model with a bare LoraConfig would attach a freshly
    # initialized adapter instead of the fine-tuned one.)
    model = PeftModel.from_pretrained(model, args.adapter_path)
model.eval()
messages = [
{
"role": "system",
"content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
},
]
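    # infer() closes over `messages`, so the chat history is shared across
    # calls and grows with every turn.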
def infer(user_inputs):
messages.append(
{"role": "user", "content": user_inputs},
)
# prompt = tokenizer.apply_chat_template(
# messages,
# tokenize=False,
# add_generation_prompt=True,
# )
text = get_prompt(
messages, book="红楼梦", role="林黛玉", has_RAG=True
) # TODO get RAG prompt
# print((f"text: {text}", f"prompt: {prompt}"))
model_inputs = tokenizer([text], return_tensors="ms")
model_inputs["max_new_tokens"] = args.inf_max_length
# print(f"{model_inputs}")
t1 = time.time()
outputs = model.generate(**model_inputs)
t2 = time.time()
outputs = [
output_ids[len(input_ids) :]
for input_ids, output_ids in zip(model_inputs["input_ids"], outputs)
]
print(
f"generate time: {t2 - t1:.4f}, tokens/s: {outputs[0].shape[0] / (t2 - t1)}"
)
# print(outputs)
text_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(f"A: {text_output.strip('“”')}")
messages.append({"role": "assistant", "content": text_output})
return text_output
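    # Two serving modes: an interactive terminal loop, or a WebSocket server
    # that replies to each client message and voices it via TTS.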
if args.isTerminal:
while True:
inputs = input("Q: ")
outputs = infer(inputs)
processTTS(character="lindaiyu", text=outputs)
elif args.isWebsocket:
def init_socket():
import asyncio
import websockets
async def echo(websocket):
async for message in websocket:
client_msg = json.loads(message)
print(f"Received message from client: {client_msg}")
response = {
"status": "success",
"echo": infer(client_msg["message"]),
}
processTTS(character="lindaiyu", text=response["echo"])
await websocket.send(json.dumps(response))
async def main():
                async with websockets.serve(
                    echo, "0.0.0.0", 6006, ping_interval=None
                ):  # 0.0.0.0 listens on all network interfaces; port 6006 must be allowed by firewall rules
                    print("WebSocket server started on ws://172.16.185.158:6006")
                    await asyncio.Future()  # run until cancelled
asyncio.run(main())
init_socket()
if __name__ == "__main__":
args = get_args()
run(args)