Switch to unified view

a b/development/summarizer-test-pipeline/main.py
1
import sys
2
import argparse
3
from summarizer_test import summarizertest
4
5
6
def main(args):
7
    parser = argparse.ArgumentParser(description='To get arguments')
8
9
    parser.add_argument('--generator_max_length',
10
                        nargs='*',
11
                        type=int,
12
                        default=[80, 110],
13
                        help='Maximum length of generator text',
14
                        required=False)
15
16
    parser.add_argument('--generator_min_length',
17
                        nargs='*',
18
                        type=int,
19
                        default=[10, 20],
20
                        help='Minimum length of generator text',
21
                        required=False)
22
23
    parser.add_argument('--generator_top_k',
24
                        nargs='*',
25
                        type=int,
26
                        default=[50, 100],
27
                        help='top_k of generator model',
28
                        required=False)
29
30
    parser.add_argument('--generator_length_penalty',
31
                        nargs='*',
32
                        type=int,
33
                        default=[0.8, 1],
34
                        help='length_penalty of generator model',
35
                        required=False)
36
37
    parser.add_argument('--generator_no_repeat_ngram_size',
38
                        nargs='*',
39
                        type=int,
40
                        default=[2, 3],
41
                        help='no_repeat_ngram_size of generator model',
42
                        required=False)
43
44
    parser.add_argument('--generator_sequences',
45
                        nargs='*',
46
                        type=int,
47
                        default=[2],
48
                        help='count of sequences that model returned',
49
                        required=False)
50
51
    parser.add_argument('--summarizer_models',
52
                        nargs='*',
53
                        type=str,
54
                        default=["sshleifer/distilbart-cnn-12-6"],
55
                        help='Models that want test in data',
56
                        required=False)
57
58
    parser.add_argument('--report_out_dir',
59
                        type=str,
60
                        default='summarizer-report.json',
61
                        help='The report of run all models with diffrent parameters',
62
                        required=False)
63
64
    parser.add_argument('--data_dir',
65
                        type=str,
66
                        default='summarizer-context.json',
67
                        help='dataset that include long text for summarization',
68
                        required=False)
69
70
    args = parser.parse_args()
71
72
    report_generator = summarizertest.ReportGenerator(
73
        models_names=args.summarizer_models,
74
        val_contexts_path=args.data_dir,
75
        report_path=args.report_out_dir,
76
        max_lengths=args.generator_max_length,
77
        min_lengths=args.generator_min_length,
78
        top_k=args.generator_top_k,
79
        penalty_l=args.generator_length_penalty,
80
        no_repeat_ngram_size=args.generator_no_repeat_ngram_size,
81
        num_return_sequences=args.generator_sequences,
82
        )
83
84
    report_generator.get_report()
85
86
87
if __name__ == "__main__":
88
    main(sys.argv[1:])