forked from zjunlp/DeepKE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
149 lines (126 loc) · 5.1 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import sys
import torch
import logging
import hydra
import models
from hydra import utils
from utils import load_pkl, load_csv
from serializer import Serializer
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
import matplotlib.pyplot as plt
logger = logging.getLogger(__name__)
def _preprocess_data(data, cfg):
    """Serialize, index and position-encode prediction instances.

    Args:
        data (list[dict]): instances with 'sentence', 'head', 'tail' (and
            entity-type) keys, as produced by ``_get_predict_instance``.
        cfg: hydra config. Reads ``cwd``/``out_path``/``data_path`` and the
            serialization flags; writes ``cfg.vocab_size`` as a side effect.

    Returns:
        tuple: ``(data, rels)`` — the instances with preprocessing fields
        added in place, and the relation mapping loaded from relation.csv.
    """
    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
    relation_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'relation.csv'), verbose=False)
    rels = _handle_relation_data(relation_data)
    cfg.vocab_size = vocab.count
    serializer = Serializer(do_chinese_split=cfg.chinese_split)
    serial = serializer.serialize

    # BUG FIX: this message was previously logged *after* all preprocessing
    # had already completed; emit it before the work actually starts.
    logger.info('start sentence preprocess...')
    _serialize_sentence(data, serial, cfg)
    _convert_tokens_into_index(data, vocab)
    _add_pos_seq(data, cfg)

    formats = '\nsentence: {}\nchinese_split: {}\nreplace_entity_with_type: {}\nreplace_entity_with_scope: {}\n' \
              'tokens: {}\ntoken2idx: {}\nlength: {}\nhead_idx: {}\ntail_idx: {}'
    logger.info(
        formats.format(data[0]['sentence'], cfg.chinese_split, cfg.replace_entity_with_type,
                       cfg.replace_entity_with_scope, data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
                       data[0]['head_idx'], data[0]['tail_idx']))
    return data, rels
def _get_predict_instance(cfg):
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
flag = flag.strip().lower()
if flag == 'y' or flag == 'yes':
sentence = '《乡村爱情》是一部由知名导演赵本山在1985年所拍摄的农村青春偶像剧。'
head = '乡村爱情'
tail = '赵本山'
head_type = '电视剧'
tail_type = '人物'
elif flag == 'n' or flag == 'no':
sentence = input('请输入句子:')
head = input('请输入句中需要预测关系的头实体:')
head_type = input('请输入头实体类型(可以为空,按enter跳过):')
tail = input('请输入句中需要预测关系的尾实体:')
tail_type = input('请输入尾实体类型(可以为空,按enter跳过):')
elif flag == 'exit':
sys.exit(0)
else:
print('please input yes or no, or exit!')
_get_predict_instance()
instance = dict()
instance['sentence'] = sentence.strip()
instance['head'] = head.strip()
instance['tail'] = tail.strip()
if head_type.strip() == '' or tail_type.strip() == '':
cfg.replace_entity_with_type = False
instance['head_type'] = 'None'
instance['tail_type'] = 'None'
else:
instance['head_type'] = head_type.strip()
instance['tail_type'] = tail_type.strip()
return instance
# Path to the trained model checkpoint (edit this to point at your own checkpoint).
fp = 'xxx/checkpoints/2019-12-03_17-35-30/cnn_epoch21.pth'
@hydra.main(config_path='conf/config.yaml')
def main(cfg):
    """Load a trained model and predict the relation for one user-supplied instance.

    Reads the checkpoint path from the module-level ``fp``, instantiates the
    model class named by ``cfg.model_name``, and logs the predicted relation
    with its confidence. When ``cfg.predict_plot`` is set, also shows a bar
    chart of per-relation confidences.
    """
    cwd = utils.get_original_cwd()
    cfg.cwd = cwd
    # Position embedding table size: positions in [-pos_limit, pos_limit] plus padding.
    cfg.pos_size = 2 * cfg.pos_limit + 2
    print(cfg.pretty())

    # get predict instance
    instance = _get_predict_instance(cfg)
    data = [instance]

    # preprocess data
    data, rels = _preprocess_data(data, cfg)

    # map config model name -> model class
    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # Prediction is usually best done on CPU.
    # cfg.use_gpu = False
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    model = __Model__[cfg.model_name](cfg)
    logger.info(f'model name: {cfg.model_name}')
    logger.info(f'\n {model}')
    model.load(fp, device=device)
    model.to(device)
    model.eval()

    # Single-instance batch: wrap each field in a length-1 list before tensorizing.
    x = dict()
    x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
    if cfg.model_name != 'lm':
        # Non-LM models additionally need relative-position features.
        x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
        if cfg.model_name == 'cnn' and cfg.use_pcnn:
            x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
    for key in x:
        x[key] = x[key].to(device)

    with torch.no_grad():
        y_pred = model(x)
        y_pred = torch.softmax(y_pred, dim=-1)[0]
        prob = y_pred.max().item()
        # rels preserves relation.csv order, matching the model's output indices.
        prob_rel = list(rels.keys())[y_pred.argmax().item()]
        logger.info(f"\"{data[0]['head']}\" 和 \"{data[0]['tail']}\" 在句中关系为:\"{prob_rel}\",置信度为{prob:.2f}。")

    if cfg.predict_plot:
        # matplotlib's default font cannot render CJK glyphs.
        plt.rcParams["font.family"] = 'Arial Unicode MS'
        # BUG FIX: the original reused the name `x` (the model input dict) for
        # the plot labels and then again as the loop variable, shadowing it
        # twice; use distinct names instead.
        labels = list(rels.keys())
        height = list(y_pred.cpu().numpy())
        plt.bar(labels, height)
        for label, h in zip(labels, height):
            plt.text(label, h, '%.2f' % h, ha="center", va="bottom")
        plt.xlabel('关系')
        plt.ylabel('置信度')
        plt.xticks(rotation=315)
        plt.show()
# Script entry point; hydra supplies the config to main().
if __name__ == '__main__':
    main()