[7fc5df]: /scripts/benchmark.py

"""Benchmark inference speed of the deidentify taggers on Dutch CoNLL sentences."""

import argparse
from timeit import default_timer as timer
from typing import List

import numpy as np
import pandas as pd
import torch
from flair.datasets import CONLL_03_DUTCH
from loguru import logger
from tqdm import tqdm

from deidentify.base import Document
from deidentify.taggers import CRFTagger, DeduceTagger, FlairTagger, TextTagger
from deidentify.tokenizer import TokenizerFactory

# Number of timed repetitions per tagger.
N_REPETITIONS = 5
# Number of corpus sentences used as benchmark documents.
N_SENTS = 5000


def load_data():
    corpus = CONLL_03_DUTCH()
    sentences = corpus.train[:N_SENTS]
    tokens = sum(len(sent) for sent in sentences)

    docs = [Document(name='', text=sent.to_plain_string(), annotations=[]) for sent in sentences]
    return docs, tokens


def benchmark_tagger(tagger: TextTagger, docs: List[Document], num_tokens: int):
    durations = []

    for _ in tqdm(range(0, N_REPETITIONS), desc='Repetitions'):
        start = timer()
        tagger.annotate(docs)
        end = timer()
        durations.append(end - start)

        # Release cached GPU memory between repetitions so each run starts from a clean state.
        if isinstance(tagger, FlairTagger) and torch.cuda.is_available():
            torch.cuda.empty_cache()

    return {
        'mean': np.mean(durations),
        'std': np.std(durations),
        'tokens/s': num_tokens / np.mean(durations),
        'docs/s': len(docs) / np.mean(durations),
        'num_docs': len(docs),
        'num_tokens': num_tokens
    }


def main(args):
    logger.info('Load data...')
    documents, num_tokens = load_data()

    logger.info('Initialize taggers...')
    tokenizer_crf = TokenizerFactory().tokenizer(corpus='ons', disable=())
    tokenizer_bilstm = TokenizerFactory().tokenizer(corpus='ons', disable=("tagger", "ner"))

    taggers = [
        ('DEDUCE', DeduceTagger(verbose=True)),
        ('CRF', CRFTagger(
            model='model_crf_ons_tuned-v0.1.0',
            tokenizer=tokenizer_crf,
            verbose=True
        )),
        ('BiLSTM-CRF (large)', FlairTagger(
            model='model_bilstmcrf_ons_large-v0.1.0',
            tokenizer=tokenizer_bilstm,
            mini_batch_size=args.bilstmcrf_large_batch_size,
            verbose=True
        )),
        ('BiLSTM-CRF (fast)', FlairTagger(
            model='model_bilstmcrf_ons_fast-v0.1.0',
            tokenizer=tokenizer_bilstm,
            mini_batch_size=args.bilstmcrf_fast_batch_size,
            verbose=True
        ))
    ]

    benchmark_results = []
    tagger_names = []

    for tagger_name, tagger in taggers:
        logger.info(f'Benchmark inference for tagger: {tagger_name}')
        scores = benchmark_tagger(tagger, documents, num_tokens)
        benchmark_results.append(scores)
        tagger_names.append(tagger_name)

    df = pd.DataFrame(data=benchmark_results, index=tagger_names)
    df.to_csv(f'{args.benchmark_name}.csv')
    logger.info('\n{}', df)


def arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark_name", type=str, help="Name of the benchmark.")
    parser.add_argument(
        "--bilstmcrf_large_batch_size",
        type=int,
        help="Batch size to use with the large model.",
        default=256
    )
    parser.add_argument(
        "--bilstmcrf_fast_batch_size",
        type=int,
        help="Batch size to use with the fast model.",
        default=256
    )
    return parser.parse_args()


if __name__ == '__main__':
    main(arg_parser())
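
A minimal example invocation, assuming deidentify and its pretrained models are installed and the script is run from the repository root (the benchmark name and batch size below are placeholders; the batch-size flags are optional and default to 256):

    python scripts/benchmark.py speed_test --bilstmcrf_large_batch_size 128

The timing table is written to speed_test.csv and also logged to the console.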