import logging
import os
import re
import sys
from pathlib import Path
from typing import Any

import IPython
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

from llama_index.core import PromptTemplate, Settings, load_index_from_storage
from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.core.indices.knowledge_graph.base import KnowledgeGraphIndex
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from llama_index.core.service_context import ServiceContext, set_global_service_context
from llama_index.core.storage.storage_context import StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.graph_stores.neo4j import Neo4jGraphStore
from llama_index.readers.file import UnstructuredReader
from pyvis.network import Network
# Pin the process to the second GPU; this must be set before the first CUDA
# context is created for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Generation limits reported to llama_index via LLMMetadata below.
context_window = 3900  # prompt context window size (tokens)
num_output = 256       # maximum number of output tokens

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Load the chat model once at module level (outside the LLM wrapper class)
# to avoid holding duplicate copies of the weights in memory.
model_name = "internlm2-chat-20b"
model_path = "/model/Weight/internlm2-chat-20b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16
).eval()
model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)

# BGE-M3 embeddings, also loaded from a local checkpoint.
embed_model = HuggingFaceEmbedding(model_name="/model/Weight/BAAI/bge-m3")
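
# llama_index drives retrieval and synthesis through an LLM object; the
# wrapper below adapts the locally loaded InternLM2 chat model to the
# CustomLLM interface. Only `metadata` and `complete()` are needed by the
# query engines used here; `stream_complete` is declared but unimplemented
# because nothing in this script streams tokens.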
class LLM(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window,
            num_output=num_output,
            model_name=model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # InternLM2's chat() already returns only the newly generated text.
        text, _ = model.chat(tokenizer, prompt, history=[])
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()


llm = LLM()

# Evaluation set: each row of the 评测问题 column holds the question on its
# first line and the announcement text on its second. Only the first three
# rows are kept, apparently as a quick smoke test.
exl = pd.read_excel('/code/open_clip-main/IneternLM/金融大赛(公告)/2024-02-28-公告测评集(有选项).xls', engine='xlrd')
print(exl[:5])
exl = exl[:3]
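
# Classification prompt (kept in Chinese because the regexes further down
# match the Chinese field names verbatim). It asks the model to fill in a
# fixed template: actual domain, role, task, a 0-1 relevance score for each
# candidate domain, and whether the question is a calculation problem.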
re_prompt = """
可能涉及的领域有:'企业管理与决策分析', '财务与投资分析', '市场与法律合规', '风险与影响评估', '经营分析','人力资源管理'
请根据文本中的信息,告知身份与任务,领域,相关性(0-1)模板为:
实际领域:
身份:
任务:
企业管理与决策分析相关性:
财务与投资分析相关性:
市场与法律合规相关性:
风险与影响评估相关性:
文本分析与逻辑推理经营效率分析相关性:
人力资源管理相关性:
是否为计算题:
完成分析
请严格按照模板格式,无需新增内容
"""
prompt_q = '''\n请仔细阅读公告内容,并抽取其中的关键事件。针对每个事件,请严格按照格式生成一个问句(无需作出回答)。
格式为:
问题:事件,可能导致的后果是?
问题:事件,可能导致的后果是?
请输出:'''
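
# Register the local model and embedder as llama_index's global settings.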
Settings.llm = llm
Settings.embed_model = embed_model
Settings.chunk_size = 512
username = "neo4j"
password = "12345678"
url_1 = "neo4j://10.6.44.224:1111"
url_2 = "neo4j://10.6.44.224:4444"
url_3 = "neo4j://10.6.44.224:5555"
url_4 = "neo4j://10.6.44.224:3333"
url_5 = "neo4j://10.6.44.224:2222"
url_6 = "neo4j://10.6.44.224:7687"
database = "neo4j"
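
# Six pre-built knowledge-graph indexes, one per analysis domain (strategy
# management, financial management, law, investment & risk, operations
# analysis, HR). Each appears to run in its own Neo4j instance (hence the
# different ports), with its llama_index docstore persisted under book/.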
kg_sources = [
    (url_1, password, "/code/open_clip-main/IneternLM/book/战略管理"),
    (url_2, password, "/code/open_clip-main/IneternLM/book/财务管理"),
    (url_3, password, "/code/open_clip-main/IneternLM/book/法律"),
    (url_4, password, "/code/open_clip-main/IneternLM/book/投资与风险"),
    (url_5, password, "/code/open_clip-main/IneternLM/book/经营分析"),
    (url_6, "neo4j", "/code/open_clip-main/IneternLM/book/人力"),  # this instance uses a different password
]
kg_engines = []
for kg_url, kg_password, persist_dir in kg_sources:
    graph_store = Neo4jGraphStore(username=username, password=kg_password, url=kg_url, database=database)
    storage_context = StorageContext.from_defaults(persist_dir=persist_dir, graph_store=graph_store)
    loaded_index = load_index_from_storage(storage_context)
    kg_engines.append(loaded_index.as_query_engine(include_text=False, response_mode="tree_summarize"))
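
# Main loop: for each evaluation row, (1) classify the question with
# re_prompt, (2) generate event questions with prompt_q, (3) answer each
# question against every domain knowledge graph, (4) aggregate the answers
# into a per-question analysis file, and (5) build and visualise a knowledge
# graph of that analysis.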
for i in tqdm(range(len(exl))):
    text = exl['评测问题'][i].split('\n')
    n_t = '文本为:' + text[0] + '\n' + re_prompt
    result = llm.complete(n_t)
    # Response-synthesis template: inject the announcement alongside the
    # graph context retrieved for each generated question.
    new_summary_tmpl_str2 = (
        "公告信息如下:\n" + str(text[1]) +
        "\n提供的信息如下:\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "问题: {query_str}\n"
        "请根据公告信息、提供的信息针对问题给出尽可能详细的答案(仅输出答案即可)"
    )
    new_summary_tmpl2 = PromptTemplate(new_summary_tmpl_str2)
    # Pull the template fields out of the classification answer, falling back
    # to defaults whenever the model strays from the template (re.search
    # returns None, so .group raises AttributeError).
    try:
        space = re.search(r'实际领域:(.*?)\n', str(result)).group(1)
    except AttributeError:
        space = '未知'
    try:
        job = re.search(r'身份:(.*?)\n', str(result)).group(1)
    except AttributeError:
        job = '专家'
    try:
        task = re.search(r'任务:(.*?)\n', str(result)).group(1)
    except AttributeError:
        task = '分析以下文本'
    try:
        s1 = re.search(r'企业管理与决策分析相关性:(.*?)\n', str(result)).group(1)
        s2 = re.search(r'财务与投资分析相关性:(.*?)\n', str(result)).group(1)
        s3 = re.search(r'市场与法律合规相关性:(.*?)\n', str(result)).group(1)
        s4 = re.search(r'风险与影响评估相关性:(.*?)\n', str(result)).group(1)
        s5 = re.search(r'文本分析与逻辑推理经营效率分析相关性:(.*?)\n', str(result)).group(1)
    except AttributeError:
        s1 = s2 = s3 = s4 = s5 = '1'
    spaces = ['企业管理与决策分析', '财务与投资分析', '市场与法律合规', '风险与影响评估', '经营分析', '人力资源管理']
    # Parsed relevance scores, collected but not used below; they would feed
    # the persona prompt that is currently commented out.
    sn = [s1, s2, s3, s4, s5]
    results = '多方面公告理解为:'
    for sp in range(len(spaces)):
        # prompt = '请以' + str(spaces[sp]) + '的知识,并根据任务:' + task + '\n超级详细的详细解析公告,并超级详细说明带来的直接影响:' + text[1] + '\n仅输出解析即可'
        # print(prompt)
        # Ask the LLM to turn the announcement into event questions, then
        # answer each question against this domain's knowledge graph.
        query_sp = str(text[1]) + prompt_q
        ress2 = str(llm.complete(query_sp))
        query2 = ress2.split('问题')
        engine = kg_engines[sp]
        engine.update_prompts({"response_synthesizer:summary_template": new_summary_tmpl2})
        sp_result = ''
        for q in query2:
            if not q.strip():
                # skip empty fragments left over from the split
                continue
            ress3 = engine.query(q)
            sp_result = sp_result + str(ress3) + '\n'
        results = results + '\n' + spaces[sp] + '方面分析意见为:' + '\n' + str(sp_result) + '\n'
    result_fin = '公告:' + text[1] + '\n' + results
    text_path = '/code/open_clip-main/IneternLM/chat_text/question' + str(i) + '.txt'
    with open(text_path, 'w', encoding='utf-8') as f:
        f.write(str(result_fin))
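
    # Build a fresh in-memory knowledge graph from the aggregated analysis
    # text and render it with pyvis for inspection.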
    # Re-apply the global settings (redundant after the first iteration, but
    # harmless) and build a ServiceContext for the knowledge-graph extractor.
    Settings.llm = llm
    Settings.embed_model = embed_model
    Settings.chunk_size = 512
    service_context = ServiceContext.from_defaults(
        llm=llm,
        embed_model=embed_model,
        chunk_size=512
    )
    set_global_service_context(service_context)
    loader = UnstructuredReader()
    documents = loader.load_data(file=Path(text_path))
    graph_store = SimpleGraphStore()  # in-memory, not persisted to Neo4j
    storage_context = StorageContext.from_defaults(graph_store=graph_store)
    index = KnowledgeGraphIndex.from_documents(
        documents,
        max_triplets_per_chunk=50,
        storage_context=storage_context,
        service_context=service_context,
        show_progress=True,
        # include_embeddings=True,  # enable to query the graph with embeddings
    )
    g = index.get_networkx_graph()
    net = Network(notebook=True, cdn_resources="in_line", directed=True)
    net.from_nx(g)
    graph_path = '/code/open_clip-main/IneternLM/chat_graph/question' + str(i) + '.html'
    net.show(graph_path)
    # Only renders when run inside a notebook; in a plain script the HTML
    # file written by net.show() is the useful artifact.
    IPython.display.HTML(filename=graph_path)
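
# Write the (truncated) evaluation table back out; the per-question analyses
# themselves live in the chat_text/*.txt files written above.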
exl.to_excel('/code/open_clip-main/IneternLM/金融大赛(公告)/chat.xls', index=False)