convert_pass.py
#!/usr/bin/env python
#
# Copyright 2014+ Carnegie Mellon University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Convert a Wikipedia corpus used in the Facebook DPR project:
https://github.com/facebookresearch/DPR/tree/master/data
Optionally, one can specify a subset of the corpus by providing
a numpy array with passage IDs to include (otherwise we will
use all the passages).
The input is a TAB-separated file with three columns: id, passage text, title
This conversion script preserves original passages, but it also tokenizes them.
"""
import numpy as np
import json
import argparse
import multiprocessing
from flexneuart.io import FileWrapper
from flexneuart.io.stopwords import read_stop_words, STOPWORD_FILE
from flexneuart.text_proc.parse import SpacyTextParser, add_retokenized_field
from flexneuart.data_convert import add_bert_tok_args, create_bert_tokenizer_if_needed
from flexneuart.config import IMAP_PROC_CHUNK_QTY, REPORT_QTY, \
    TEXT_BERT_TOKENIZED_NAME, \
    TEXT_FIELD_NAME, DOCID_FIELD, \
    TEXT_RAW_FIELD_NAME, TEXT_UNLEMM_FIELD_NAME, \
    TITLE_RAW_FIELD_NAME, TITLE_FIELD_NAME, TITLE_UNLEMM_FIELD_NAME, \
    SPACY_MODEL

parser = argparse.ArgumentParser(description='Convert a Wikipedia corpus downloaded from github.com/facebookresearch/DPR.')

parser.add_argument('--input_file', metavar='input file',
                    help='input TAB-separated file (id, passage text, title)',
                    type=str, required=True)
parser.add_argument('--passage_ids', metavar='optional passage ids',
                    type=str, default=None, help='an optional numpy array with passage ids to select')
parser.add_argument('--out_file', metavar='output file',
                    help='output JSONL file',
                    type=str, required=True)
# Default is the number of cores minus one for the spawning process
parser.add_argument('--proc_qty', metavar='# of processes', help='# of NLP processes to spawn',
                    type=int, default=multiprocessing.cpu_count() - 1)
add_bert_tok_args(parser)
args = parser.parse_args()
arg_vars = vars(args)
print(args)
bert_tokenizer = create_bert_tokenizer_if_needed(args)
# Lower cased
stop_words = read_stop_words(STOPWORD_FILE, lower_case=True)
print(stop_words)
flt_pass_ids = None
if args.passage_ids is not None:
    flt_pass_ids = set(np.load(args.passage_ids))
    print(f'Restricting parsing to {len(flt_pass_ids)} passage IDs')
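# A sketch of how such a filter file could be created (the file name below is hypothetical;
# note that IDs are matched against the first TSV column, which this script reads as strings):
#
#   np.save('subset_ids.npy', np.array(['12', '57', '103']))
#
# and the script would then be invoked with --passage_ids subset_ids.npy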
fields = [TEXT_FIELD_NAME, TEXT_UNLEMM_FIELD_NAME, TITLE_UNLEMM_FIELD_NAME, TEXT_RAW_FIELD_NAME]
# Lower cased
text_processor = SpacyTextParser(SPACY_MODEL, stop_words,
                                 keep_only_alpha_num=True, lower_case=True,
                                 enable_pos=True)

class PassParseWorker:
    def __call__(self, line):
        if not line:
            return None

        line = line.strip()
        if not line:
            return None

        fields = line.split('\t')
        # Skip the TSV header line, if present
        if ' '.join(fields) == 'id text title':
            return ''

        assert len(fields) == 3, f"Wrong format line: {line}"

        # The passage text is not lower-cased, please keep it this way.
        pass_id, raw_text, raw_title = fields

        # Ignore passages that are not in the optional filter set
        if flt_pass_ids is not None:
            if pass_id not in flt_pass_ids:
                return ''

        text_lemmas, text_unlemm = text_processor.proc_text(raw_text)
        title_lemmas, title_unlemm = text_processor.proc_text(raw_title)

        doc = {DOCID_FIELD: pass_id,
               TEXT_FIELD_NAME: title_lemmas + ' ' + text_lemmas,
               TITLE_UNLEMM_FIELD_NAME: title_unlemm,
               TEXT_UNLEMM_FIELD_NAME: text_unlemm,
               TEXT_RAW_FIELD_NAME: raw_title + ' ' + raw_text}
        add_retokenized_field(doc, TEXT_RAW_FIELD_NAME, TEXT_BERT_TOKENIZED_NAME, bert_tokenizer)

        return json.dumps(doc)

inp_file = FileWrapper(args.input_file)
out_file = FileWrapper(args.out_file, 'w')
proc_qty = args.proc_qty
print(f'Spawning {proc_qty} processes')
pool = multiprocessing.Pool(processes=proc_qty)
ln = 0
ln_ign = 0
# The worker returns None for a malformed line and an empty string for a skipped one
for doc_str in pool.imap(PassParseWorker(), inp_file, IMAP_PROC_CHUNK_QTY):
    ln = ln + 1

    if doc_str is not None:
        if doc_str:
            out_file.write(doc_str + '\n')
        else:
            ln_ign += 1
    else:
        print('Ignoring misformatted line %d' % ln)

    if ln % REPORT_QTY == 0:
        print('Read %d passages, processed %d passages' % (ln, ln - ln_ign))

print('Processed %d passages' % ln)
inp_file.close()
out_file.close()