|
a |
|
b/development/dpr-qa-test-pipline/main.py |
|
|
1 |
import sys |
|
|
2 |
import argparse |
|
|
3 |
from dpr_qa_test import DPRTest |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def main(args): |
|
|
7 |
parser = argparse.ArgumentParser(description='To get arguments') |
|
|
8 |
|
|
|
9 |
parser.add_argument('--max_seq_lens', |
|
|
10 |
nargs='*', |
|
|
11 |
type=int, |
|
|
12 |
default=[256], |
|
|
13 |
help='Max sequence length of one input text for the model', |
|
|
14 |
required=False) |
|
|
15 |
|
|
|
16 |
parser.add_argument('--max_seq_len_passages', |
|
|
17 |
nargs='*', |
|
|
18 |
type=int, |
|
|
19 |
default=[256], |
|
|
20 |
help='Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down.', |
|
|
21 |
required=False) |
|
|
22 |
|
|
|
23 |
parser.add_argument('--max_seq_len_queries', |
|
|
24 |
nargs='*', |
|
|
25 |
type=int, |
|
|
26 |
default=[64], |
|
|
27 |
help='Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down.', |
|
|
28 |
required=False) |
|
|
29 |
|
|
|
30 |
parser.add_argument('--embed_titles', |
|
|
31 |
nargs='*', |
|
|
32 |
type=bool, |
|
|
33 |
default=[True, False], |
|
|
34 |
help='Whether to concatenate title and passage to a text pair that is then used to create the embedding.', |
|
|
35 |
required=False) |
|
|
36 |
|
|
|
37 |
parser.add_argument('--context_window_sizes', |
|
|
38 |
nargs='*', |
|
|
39 |
type=int, |
|
|
40 |
default=[150, 175], |
|
|
41 |
help='The size, in characters, of the window around the answer span that is used when displaying the context around the answer.', |
|
|
42 |
required=False) |
|
|
43 |
|
|
|
44 |
parser.add_argument('--doc_strides', |
|
|
45 |
nargs='*', |
|
|
46 |
type=int, |
|
|
47 |
default=[100, 128], |
|
|
48 |
help='Length of striding window for splitting long texts (used if len(text) > max_seq_len)', |
|
|
49 |
required=False) |
|
|
50 |
|
|
|
51 |
parser.add_argument('--retriever_top_ks', |
|
|
52 |
nargs='*', |
|
|
53 |
type=int, |
|
|
54 |
default=[3, 5, 7], |
|
|
55 |
help='How many documents to return per query.', |
|
|
56 |
required=False) |
|
|
57 |
|
|
|
58 |
parser.add_argument('--reader_top_ks', |
|
|
59 |
nargs='*', |
|
|
60 |
type=int, |
|
|
61 |
default=[3, 5, 7], |
|
|
62 |
help='The maximum number of answers to return', |
|
|
63 |
required=False) |
|
|
64 |
|
|
|
65 |
parser.add_argument('--reader_models', |
|
|
66 |
nargs='*', |
|
|
67 |
type=str, |
|
|
68 |
default=['ktrapeznikov/albert-xlarge-v2-squad-v2', |
|
|
69 |
'deepset/roberta-base-squad2', |
|
|
70 |
'deepset/minilm-uncased-squad2', |
|
|
71 |
'ahotrod/albert_xxlargev1_squad2_512'], |
|
|
72 |
help='The maximum number of answers to return', |
|
|
73 |
required=False) |
|
|
74 |
|
|
|
75 |
parser.add_argument('--text_datasets', |
|
|
76 |
nargs='*', |
|
|
77 |
type=str, |
|
|
78 |
default=['titleText-threeSentences.csv', |
|
|
79 |
'titleText-paragraphs.csv'], |
|
|
80 |
help='The maximum number of answers to return', |
|
|
81 |
required=False) |
|
|
82 |
|
|
|
83 |
parser.add_argument('--qa_datasets', |
|
|
84 |
nargs='*', |
|
|
85 |
type=str, |
|
|
86 |
default=['qa-SQUAD.json'], |
|
|
87 |
help='The maximum number of answers to return', |
|
|
88 |
required=False) |
|
|
89 |
|
|
|
90 |
parser.add_argument('--report_out_dir', |
|
|
91 |
type=str, |
|
|
92 |
default='dpr-qa-report.csv', |
|
|
93 |
help='The maximum number of answers to return', |
|
|
94 |
required=False) |
|
|
95 |
|
|
|
96 |
parser.add_argument('--sample_out_dir', |
|
|
97 |
type=str, |
|
|
98 |
default='dpr-qa-sample.json', |
|
|
99 |
help='The maximum number of answers to return', |
|
|
100 |
required=False) |
|
|
101 |
|
|
|
102 |
args = parser.parse_args() |
|
|
103 |
|
|
|
104 |
report_generator = DPRTest.report_generator(max_seq_lens=args.max_seq_lens, |
|
|
105 |
max_seq_len_passages=args.max_seq_len_passages, |
|
|
106 |
max_seq_len_queries=args.max_seq_len_queries, |
|
|
107 |
embed_titles=args.embed_titles, |
|
|
108 |
context_window_sizes=args.context_window_sizes, |
|
|
109 |
doc_strides=args.doc_strides, |
|
|
110 |
retriever_top_ks=args.retriever_top_ks, |
|
|
111 |
reader_top_ks=args.reader_top_ks, |
|
|
112 |
reader_models=args.reader_models, |
|
|
113 |
text_datasets=args.text_datasets, |
|
|
114 |
qa_datasets=args.qa_datasets, |
|
|
115 |
report_out_dir=args.qa_datasets, |
|
|
116 |
sample_out_dir=args.sample_out_dir, |
|
|
117 |
) |
|
|
118 |
|
|
|
119 |
_, _ = report_generator.get_report() |
|
|
120 |
|
|
|
121 |
|
|
|
122 |
if __name__ == "__main__": |
|
|
123 |
main(sys.argv[1:]) |