b/chexbert/src/bert_tokenizer.py

import pandas as pd
from transformers import BertTokenizer
import json
from tqdm import tqdm
import argparse
|
|
|
def get_impressions_from_csv(path):
    # the csv has no header row; impressions are read from the first column
    df = pd.read_csv(path, header=None)
    imp = df[0]
    # replace missing (NaN) impressions with empty strings
    imp = imp.fillna('')
    imp = imp.str.strip()
    # collapse newlines and runs of whitespace into single spaces
    imp = imp.replace(r'\s+', ' ', regex=True)
    imp = imp.str.strip()
    return imp
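# Illustrative example (not part of the original file): an impression such as
# "  No acute\n cardiopulmonary\tprocess.  " comes out of the function above
# as "No acute cardiopulmonary process.".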
|
|
|
def tokenize(impressions, tokenizer):
    new_impressions = []
    print("\nTokenizing report impressions. All reports are cut off at 512 tokens.")
    for i in tqdm(range(impressions.shape[0])):
        tokenized_imp = tokenizer.tokenize(impressions.iloc[i])
        if tokenized_imp:  # not an empty report
            res = tokenizer.encode_plus(tokenized_imp)['input_ids']
            if len(res) > 512:  # truncate to BERT's 512-token maximum
                res = res[:511] + [tokenizer.sep_token_id]
            new_impressions.append(res)
        else:  # an empty report is encoded as [CLS] [SEP] only
            new_impressions.append([tokenizer.cls_token_id, tokenizer.sep_token_id])
    return new_impressions
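# Illustrative note (not part of the original file): every entry of the
# returned list starts with tokenizer.cls_token_id and ends with
# tokenizer.sep_token_id, and is at most 512 ids long; an over-long report
# keeps its first 511 ids and gets a final [SEP] appended.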
|
|
|
def load_list(path):
    with open(path, 'r') as filehandle:
        impressions = json.load(filehandle)
    return impressions
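# Sketch of the intended round trip (hypothetical path): load_list reads back
# exactly what the json.dump call at the bottom of this file writes, e.g.
#   ids = load_list('impressions_tokenized.json')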
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Tokenize radiology report impressions and save them as a list.')
    parser.add_argument('-d', '--data', type=str, required=True,
                        help='path to a headerless csv containing the reports, one impression per row in the first column')
    parser.add_argument('-o', '--output_path', type=str, required=True,
                        help='path to the intended output file')
    args = parser.parse_args()
    csv_path = args.data
    out_path = args.output_path

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    impressions = get_impressions_from_csv(csv_path)
    new_impressions = tokenize(impressions, tokenizer)
    with open(out_path, 'w') as filehandle:
        json.dump(new_impressions, filehandle)
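
# Example invocation (hypothetical file names):
#   python bert_tokenizer.py --data reports.csv --output_path impressions_tokenized.json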