Download this file

89 lines (72 with data), 3.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys
import argparse
from summarizer_test import summarizertest
def _build_parser():
    """Build the argument parser for the summarizer report script.

    All ``--generator_*`` options and ``--summarizer_models`` use
    ``nargs='*'``, so each one is parsed into a list of values.
    """
    parser = argparse.ArgumentParser(description='To get arguments')
    parser.add_argument('--generator_max_length',
                        nargs='*',
                        type=int,
                        default=[80, 110],
                        help='Maximum length of generator text')
    parser.add_argument('--generator_min_length',
                        nargs='*',
                        type=int,
                        default=[10, 20],
                        help='Minimum length of generator text')
    parser.add_argument('--generator_top_k',
                        nargs='*',
                        type=int,
                        default=[50, 100],
                        help='top_k of generator model')
    # The default contains 0.8, so values must be parsed as floats; with
    # type=int a fractional penalty given on the command line was rejected.
    parser.add_argument('--generator_length_penalty',
                        nargs='*',
                        type=float,
                        default=[0.8, 1],
                        help='length_penalty of generator model')
    parser.add_argument('--generator_no_repeat_ngram_size',
                        nargs='*',
                        type=int,
                        default=[2, 3],
                        help='no_repeat_ngram_size of generator model')
    parser.add_argument('--generator_sequences',
                        nargs='*',
                        type=int,
                        default=[2],
                        help='count of sequences that model returned')
    parser.add_argument('--summarizer_models',
                        nargs='*',
                        type=str,
                        default=["sshleifer/distilbart-cnn-12-6"],
                        help='Models that want test in data')
    parser.add_argument('--report_out_dir',
                        type=str,
                        default='summarizer-report.json',
                        help='The report of run all models with diffrent parameters')
    parser.add_argument('--data_dir',
                        type=str,
                        default='summarizer-context.json',
                        help='dataset that include long text for summarization')
    return parser


def main(args):
    """Parse command-line options and run the summarizer report.

    Args:
        args: Sequence of command-line argument strings, without the program
            name (e.g. ``sys.argv[1:]``). Passing ``None`` makes argparse
            fall back to ``sys.argv[1:]``.

    The parsed options are forwarded to ``summarizertest.ReportGenerator``
    and ``get_report()`` is called on the resulting object.
    NOTE(review): what ``ReportGenerator`` does with the value lists
    (presumably a sweep over the combinations) is defined outside this
    file -- confirm there.
    """
    # Parse the list we were given rather than implicitly re-reading
    # sys.argv, so main() can be driven programmatically.
    options = _build_parser().parse_args(args)

    report_generator = summarizertest.ReportGenerator(
        models_names=options.summarizer_models,
        val_contexts_path=options.data_dir,
        report_path=options.report_out_dir,
        max_lengths=options.generator_max_length,
        min_lengths=options.generator_min_length,
        top_k=options.generator_top_k,
        penalty_l=options.generator_length_penalty,
        no_repeat_ngram_size=options.generator_no_repeat_ngram_size,
        num_return_sequences=options.generator_sequences,
    )
    report_generator.get_report()
if __name__ == "__main__":
    # Script entry point: hand main() everything after the program name.
    cli_arguments = sys.argv[1:]
    main(cli_arguments)