Switch to unified view

a b/development/dpr-qa-test-pipline/main.py
1
import sys
2
import argparse
3
from dpr_qa_test import DPRTest
4
5
6
def _parse_bool(value):
    """Convert one command-line token into a bool for argparse.

    ``type=bool`` must not be used for this: ``bool('False')`` is ``True``,
    so every non-empty token would be parsed as ``True``.

    Args:
        value: The raw string taken from the command line.

    Returns:
        ``True`` or ``False`` depending on the spelling of *value*.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (true/false), got {!r}'.format(value))


def main(args=None):
    """Parse the command-line options and run the DPR QA report generator.

    Every list-valued option accepts zero or more values (``nargs='*'``) and
    is forwarded unchanged to ``DPRTest.report_generator``; afterwards
    ``get_report()`` is called on the returned object.

    Args:
        args: List of command-line tokens without the program name. When
            ``None``, argparse falls back to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='To get arguments')

    parser.add_argument('--max_seq_lens',
                        nargs='*',
                        type=int,
                        default=[256],
                        help='Max sequence length of one input text for the model')

    parser.add_argument('--max_seq_len_passages',
                        nargs='*',
                        type=int,
                        default=[256],
                        help='Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down.')

    parser.add_argument('--max_seq_len_queries',
                        nargs='*',
                        type=int,
                        default=[64],
                        help='Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down.')

    # A dedicated converter is required here: with type=bool the token
    # 'False' would be turned into True.
    parser.add_argument('--embed_titles',
                        nargs='*',
                        type=_parse_bool,
                        default=[True, False],
                        help='Whether to concatenate title and passage to a text pair that is then used to create the embedding.')

    parser.add_argument('--context_window_sizes',
                        nargs='*',
                        type=int,
                        default=[150, 175],
                        help='The size, in characters, of the window around the answer span that is used when displaying the context around the answer.')

    parser.add_argument('--doc_strides',
                        nargs='*',
                        type=int,
                        default=[100, 128],
                        help='Length of striding window for splitting long texts (used if len(text) > max_seq_len)')

    parser.add_argument('--retriever_top_ks',
                        nargs='*',
                        type=int,
                        default=[3, 5, 7],
                        help='How many documents to return per query.')

    parser.add_argument('--reader_top_ks',
                        nargs='*',
                        type=int,
                        default=[3, 5, 7],
                        help='The maximum number of answers to return')

    # The help texts of the remaining options used to be copies of the
    # --reader_top_ks text; they now describe the option they belong to.
    parser.add_argument('--reader_models',
                        nargs='*',
                        type=str,
                        default=['ktrapeznikov/albert-xlarge-v2-squad-v2',
                                 'deepset/roberta-base-squad2',
                                 'deepset/minilm-uncased-squad2',
                                 'ahotrod/albert_xxlargev1_squad2_512'],
                        help='Names of the reader models to use.')

    parser.add_argument('--text_datasets',
                        nargs='*',
                        type=str,
                        default=['titleText-threeSentences.csv',
                                 'titleText-paragraphs.csv'],
                        help='Text dataset files to use.')

    parser.add_argument('--qa_datasets',
                        nargs='*',
                        type=str,
                        default=['qa-SQUAD.json'],
                        help='Question-answering dataset files to use.')

    parser.add_argument('--report_out_dir',
                        type=str,
                        default='dpr-qa-report.csv',
                        help='Output location of the generated report.')

    parser.add_argument('--sample_out_dir',
                        type=str,
                        default='dpr-qa-sample.json',
                        help='Output location of the generated sample.')

    # Parse the tokens handed to main() instead of silently re-reading
    # sys.argv, so the function honours its own parameter.
    options = parser.parse_args(args)

    report_generator = DPRTest.report_generator(
        max_seq_lens=options.max_seq_lens,
        max_seq_len_passages=options.max_seq_len_passages,
        max_seq_len_queries=options.max_seq_len_queries,
        embed_titles=options.embed_titles,
        context_window_sizes=options.context_window_sizes,
        doc_strides=options.doc_strides,
        retriever_top_ks=options.retriever_top_ks,
        reader_top_ks=options.reader_top_ks,
        reader_models=options.reader_models,
        text_datasets=options.text_datasets,
        qa_datasets=options.qa_datasets,
        # Previously args.qa_datasets was passed here by mistake, so the
        # --report_out_dir option had no effect.
        report_out_dir=options.report_out_dir,
        sample_out_dir=options.sample_out_dir,
    )

    # get_report() is unpacked into two values; both are discarded here.
    _, _ = report_generator.get_report()
120
121
122
if __name__ == "__main__":
    # Executed as a script: hand main() the command-line tokens without the
    # program name.
    cli_tokens = sys.argv[1:]
    main(cli_tokens)