--- /dev/null
+++ b/development/dpr-qa-test-pipline/main.py
@@ -0,0 +1,133 @@
+import sys
+import argparse
+from dpr_qa_test import DPRTest
+
+
+def str2bool(value):
+    # argparse's built-in bool() treats any non-empty string (including 'False') as True,
+    # so parse the common truthy/falsy spellings explicitly.
+    if value.lower() in ('true', '1', 'yes'):
+        return True
+    if value.lower() in ('false', '0', 'no'):
+        return False
+    raise argparse.ArgumentTypeError(f"Boolean value expected, got '{value}'")
+
+
+def main(argv):
+    parser = argparse.ArgumentParser(description='Arguments for the DPR QA test pipeline')
+
+    parser.add_argument('--max_seq_lens',
+                        nargs='*',
+                        type=int,
+                        default=[256],
+                        help='Max sequence length of one input text for the model.',
+                        required=False)
+
+    parser.add_argument('--max_seq_len_passages',
+                        nargs='*',
+                        type=int,
+                        default=[256],
+                        help='Maximum number of tokens for each passage/context text. Longer passages are truncated.',
+                        required=False)
+
+    parser.add_argument('--max_seq_len_queries',
+                        nargs='*',
+                        type=int,
+                        default=[64],
+                        help='Maximum number of tokens for each query text. Longer queries are truncated.',
+                        required=False)
+
+    parser.add_argument('--embed_titles',
+                        nargs='*',
+                        type=str2bool,
+                        default=[True, False],
+                        help='Whether to concatenate title and passage into a text pair that is then used to create the embedding.',
+                        required=False)
+
+    parser.add_argument('--context_window_sizes',
+                        nargs='*',
+                        type=int,
+                        default=[150, 175],
+                        help='Size, in characters, of the window around the answer span used when displaying the context around the answer.',
+                        required=False)
+
+    parser.add_argument('--doc_strides',
+                        nargs='*',
+                        type=int,
+                        default=[100, 128],
+                        help='Length of the striding window for splitting long texts (used if len(text) > max_seq_len).',
+                        required=False)
+
+    parser.add_argument('--retriever_top_ks',
+                        nargs='*',
+                        type=int,
+                        default=[3, 5, 7],
+                        help='How many documents the retriever returns per query.',
+                        required=False)
+
+    parser.add_argument('--reader_top_ks',
+                        nargs='*',
+                        type=int,
+                        default=[3, 5, 7],
+                        help='Maximum number of answers the reader returns.',
+                        required=False)
+
+    parser.add_argument('--reader_models',
+                        nargs='*',
+                        type=str,
+                        default=['ktrapeznikov/albert-xlarge-v2-squad-v2',
+                                 'deepset/roberta-base-squad2',
+                                 'deepset/minilm-uncased-squad2',
+                                 'ahotrod/albert_xxlargev1_squad2_512'],
+                        help='Reader models to evaluate (Hugging Face model names or local paths).',
+                        required=False)
+
+    parser.add_argument('--text_datasets',
+                        nargs='*',
+                        type=str,
+                        default=['titleText-threeSentences.csv',
+                                 'titleText-paragraphs.csv'],
+                        help='CSV files with the passages/contexts to index.',
+                        required=False)
+
+    parser.add_argument('--qa_datasets',
+                        nargs='*',
+                        type=str,
+                        default=['qa-SQUAD.json'],
+                        help='SQuAD-style JSON files with the question/answer pairs to evaluate against.',
+                        required=False)
+
+    parser.add_argument('--report_out_dir',
+                        type=str,
+                        default='dpr-qa-report.csv',
+                        help='Path of the CSV report to write.',
+                        required=False)
+
+    parser.add_argument('--sample_out_dir',
+                        type=str,
+                        default='dpr-qa-sample.json',
+                        help='Path of the JSON file with sample answers to write.',
+                        required=False)
+
+    args = parser.parse_args(argv)
+
+    report_generator = DPRTest.report_generator(max_seq_lens=args.max_seq_lens,
+                                                max_seq_len_passages=args.max_seq_len_passages,
+                                                max_seq_len_queries=args.max_seq_len_queries,
+                                                embed_titles=args.embed_titles,
+                                                context_window_sizes=args.context_window_sizes,
+                                                doc_strides=args.doc_strides,
+                                                retriever_top_ks=args.retriever_top_ks,
+                                                reader_top_ks=args.reader_top_ks,
+                                                reader_models=args.reader_models,
+                                                text_datasets=args.text_datasets,
+                                                qa_datasets=args.qa_datasets,
+                                                report_out_dir=args.report_out_dir,
+                                                sample_out_dir=args.sample_out_dir,
+                                                )
+
+    _, _ = report_generator.get_report()
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
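
Example invocation (a sketch, not part of the change: it assumes `dpr_qa_test.DPRTest` is importable and that the default dataset files such as `titleText-paragraphs.csv` and `qa-SQUAD.json` sit next to the script; the model names and paths are illustrative and should be adjusted to your setup):

    python main.py \
        --reader_models deepset/roberta-base-squad2 \
        --retriever_top_ks 3 5 \
        --reader_top_ks 3 \
        --embed_titles true \
        --report_out_dir dpr-qa-report.csv \
        --sample_out_dir dpr-qa-sample.json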