-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathuseful_scripts.py
112 lines (95 loc) · 4.69 KB
/
useful_scripts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
from os import listdir
from os.path import isfile, join
from rouge import FilesRouge
from utils.data_utils import TsvProcessor
import random
# This file contains the scripts used to obtain the results of the extended experiments in my bachelor's thesis.
# They only process data files and calculate ROUGE scores.
def parse_topics(path='./data2'):
    """
    Parses the topic files generated by the MeetingParser to generate a predict.txt file that is used as input for
    the network's predict run_mode.

    Every file (per os.path.isfile) in ``path`` whose name starts with 'topcis.' and ends with '.test.txt' is read,
    in sorted file-name order. Each non-blank line is written to predict.txt. For every such line, a
    ``<file_number> <file_name>`` row is written to line-counts.txt, so line i of predict.txt can later be traced
    back to its meeting file. File numbers start at 1. Both output files are written to the current working
    directory and overwritten.

    :param path: directory that holds the topic files; defaults to './data2'.
    """
    # NOTE: 'topcis' (sic) is the prefix the rest of this script expects as well (create_tgt_summaries maps it to
    # 'summaries'), so it must not be "corrected" here alone.
    # Sorting makes the file numbering reproducible; os.listdir returns names in arbitrary order.
    topic_files = sorted(
        file_name for file_name in listdir(path)
        if file_name.startswith('topcis.')
        and file_name.endswith('.test.txt')
        and isfile(join(path, file_name))
    )
    with open('predict.txt', 'w') as predict_file, open('line-counts.txt', 'w') as line_counts:
        for file_number, file_name in enumerate(topic_files, start=1):
            with open(join(path, file_name), 'r') as data_file:
                for line in data_file:
                    text = line.strip('\n')
                    # Lines read from a file keep their '\n', so emptiness must be tested after stripping;
                    # otherwise blank lines become empty prediction inputs.
                    if not text:
                        continue
                    predict_file.write("%s\n" % text)
                    line_counts.write("%s %s\n" % (file_number, file_name))


def create_random_predictions():
    """
    Chooses a random sentence from the abstractive summaries of the training data and uses it as the summary
    of a topic's transcript.

    Writes predict-predictions-random.txt with exactly one line per line of predict.txt. Each output line is the
    lower-cased ``tgt_text`` of a training example picked with ``random.choice``, so picks are independent and may
    repeat. The training examples come from ``TsvProcessor().get_train_examples('./data')``. Only the number of
    lines in predict.txt matters; their content is ignored.
    """
    train_examples = TsvProcessor().get_train_examples('./data')
    with open('predict.txt', 'r') as predict_file, \
            open('predict-predictions-random.txt', 'w') as random_predictions_file:
        # One random training summary per topic line; the topic text itself is not used.
        for _topic_line in predict_file:
            sampled_example = random.choice(train_examples)
            random_predictions_file.write('%s\n' % sampled_example.tgt_text.lower())


def shrink_per_topic_to_per_meeting(input_file='predictions.txt', output_file='summaries.tgt.txt'):
    """
    Concatenates the lines of a per-topic file to a per-meeting file.

    line-counts.txt, as written by parse_topics, holds one ``<file_number> <file_name>`` row per line of the
    per-topic input, so its i-th row says which meeting the i-th line of ``input_file`` belongs to. Consecutive
    input lines of the same meeting are joined with single spaces and written as one newline-terminated line of
    ``output_file``. This matches the one-line-per-meeting layout create_tgt_summaries writes.

    :param input_file: per-topic file with one line per row of line-counts.txt (e.g. the network's predictions).
    :param output_file: per-meeting file to (over)write.
    """
    # NOTE(review): the default output_file is the same summaries.tgt.txt that create_tgt_summaries writes and that
    # compare_summaries reads as the tgt side. When shrinking predictions you probably want
    # output_file='summaries.src.txt'. The default is kept unchanged for backward compatibility.
    meetings = []  # one list of topic texts per meeting, in line-counts.txt order
    last_number = None
    with open(input_file, 'r') as predictions_file, open('line-counts.txt', 'r') as line_counts:
        for text in line_counts:
            number = text.split(' ', 1)[0]
            if number != last_number:
                meetings.append([])
                last_number = number
            # Rows of line-counts.txt and lines of input_file are paired up strictly in order.
            meetings[-1].append(predictions_file.readline().strip('\n'))
    # Inputs are fully read before output_file is opened, so a missing input no longer truncates the output.
    with open(output_file, 'w') as summaries_src_file:
        summaries_src_file.writelines('%s\n' % ' '.join(topics) for topics in meetings)


def create_tgt_summaries(path='./data2'):
    """
    Parses the summary files generated by the MeetingParser to generate a tgt summaries file that can be used
    to compute ROUGE scores.

    line-counts.txt, as written by parse_topics, has one ``<file_number> <file_name>`` row per topic line. Each run
    of rows with the same file number is handled once. The matching summary file is the topic file name with
    'topcis' replaced by 'summaries', looked up in ``path``. Its first line is lower-cased and written to
    summaries.tgt.txt, giving one newline-terminated line per meeting.

    :param path: directory that holds the summary files; defaults to './data2'.
    """
    with open('summaries.tgt.txt', 'w') as summaries_tgt_file, open('line-counts.txt', 'r') as line_counts:
        last_number = None
        for text in line_counts:
            number, _, file_name = text.partition(' ')
            if number == last_number:
                continue  # another topic line of a meeting whose summary was already written
            last_number = number
            summary_name = file_name.rstrip('\n').replace('topcis', 'summaries')
            with open(os.path.join(path, summary_name), 'r') as summary:
                # readline() keeps the line's '\n'. Strip it so '%s\n' does not emit an extra blank line, which
                # would shift every following reference summary out of line with its meeting.
                first_line = summary.readline().rstrip('\n')
            summaries_tgt_file.write('%s\n' % first_line.lower())


def compare_summaries():
    """
    Scores summaries.src.txt against summaries.tgt.txt with ROUGE and prints the averaged scores as a table.

    The two paths are passed to ``FilesRouge`` in that order. ``get_scores(avg=True)`` returns the averaged scores,
    which ``print_rouge_scores`` then renders.
    """
    src_path, tgt_path = 'summaries.src.txt', 'summaries.tgt.txt'
    averaged_scores = FilesRouge(src_path, tgt_path).get_scores(avg=True)
    print_rouge_scores(averaged_scores)


def print_rouge_scores(scores):
    """
    Prints the rouge scores in a nice, human-readable format.

    :param scores: mapping with the keys 'rouge-1', 'rouge-2' and 'rouge-l'. Each value maps 'p' (precision),
        'r' (recall) and 'f' (F-score) to a number. compare_summaries passes the result of
        ``FilesRouge.get_scores(avg=True)``.
    :raises KeyError: if a metric or one of 'p'/'r'/'f' is missing; nothing is printed in that case.
    """
    # Cell widths include one space of padding on each side: ' ROUGE-1 ' is 9 characters and a score formatted as
    # ' %.4f ' is 8 (for scores below 10). Deriving borders and cells from the same widths keeps the table aligned.
    widths = (9, 8, 8, 8)

    def rule(left, middle, right):
        # Horizontal border line, e.g. '┌─────────┬────────┬────────┬────────┐'.
        return left + middle.join('─' * width for width in widths) + right

    def row(cells):
        # Table row with each cell padded to its column width.
        return '│' + '│'.join((' %s ' % cell).ljust(width) for cell, width in zip(cells, widths)) + '│'

    # Format every data row before printing anything, so malformed `scores` does not leave half a table behind.
    data_rows = [
        row([label] + ['%.4f' % scores[key][part] for part in ('p', 'r', 'f')])
        for label, key in (('ROUGE-1', 'rouge-1'), ('ROUGE-2', 'rouge-2'), ('ROUGE-L', 'rouge-l'))
    ]
    table = [rule('┌', '┬', '┐'), row(['Metric', 'Pre', 'Rec', 'F']), rule('├', '┼', '┤')]
    table += data_rows
    table.append(rule('└', '┴', '┘'))
    for line in table:
        print(line)
# compare_summaries()