import pandas as pd
from transformers import BertTokenizer, AutoTokenizer
import json
from tqdm import tqdm
import argparse

def get_impressions_from_csv(path):
        """Read report impressions from the first column of a headerless CSV.

        Args:
                path: path to a CSV file with no header row; impressions are
                        taken from column 0.

        Returns:
                pandas.Series of cleaned impression strings: NaN cells become
                '', newlines and runs of whitespace collapse to single spaces,
                and leading/trailing whitespace is stripped.
        """
        df = pd.read_csv(path, header=None)
        imp = df[0]
        # Empty cells read back as NaN; normalize to '' so the .str accessor
        # below never sees floats.
        imp = imp.fillna('')
        imp = imp.str.strip()
        # Raw strings for the regex patterns: '\s' in a plain literal is an
        # invalid escape sequence (warns on modern Python); the match
        # behavior is unchanged.
        imp = imp.replace(r'\n', ' ', regex=True)
        imp = imp.replace(r'\s+', ' ', regex=True)
        imp = imp.str.strip()
        return imp
def tokenize(impressions, tokenizer):
        """Tokenize and encode a collection of report impressions.

        Args:
                impressions: iterable of report strings (a pandas Series from
                        get_impressions_from_csv; '' marks a blank report).
                tokenizer: a HuggingFace tokenizer exposing tokenize,
                        encode_plus, cls_token_id and sep_token_id.

        Returns:
                list of token-id lists, each at most 512 ids long. Over-long
                reports are truncated to 511 ids plus a trailing [SEP]; blank
                reports encode as [CLS, SEP] only.
        """
        new_impressions = []
        print("\nTokenizing report impressions. All reports are cut off at 512 tokens.")
        # Iterate values directly rather than range(shape[0]) + .iloc — same
        # order, same results, and accepts any iterable of strings.
        for report in tqdm(impressions):
                tokenized_imp = tokenizer.tokenize(report)
                if not tokenized_imp:
                        # Blank report: keep only the special tokens.
                        new_impressions.append([tokenizer.cls_token_id, tokenizer.sep_token_id])
                        continue
                res = tokenizer.encode_plus(tokenized_imp)['input_ids']
                if len(res) > 512:
                        # BERT's maximum sequence length is 512: keep the
                        # first 511 ids and re-append [SEP] so the encoded
                        # sequence stays well-formed.
                        res = res[:511] + [tokenizer.sep_token_id]
                new_impressions.append(res)
        return new_impressions
def load_list(path):
        """Deserialize and return the JSON list stored at *path*."""
        with open(path, 'r') as filehandle:
                contents = json.load(filehandle)
        return contents
if __name__ == "__main__":
        # CLI entry point: read a CSV of reports, tokenize them with
        # bert-base-uncased, and dump the token-id lists as JSON.
        parser = argparse.ArgumentParser(description='Tokenize radiology report impressions and save as a list.')
        parser.add_argument('-d', '--data', type=str, nargs='?', required=True,
                            help='path to csv containing reports. The reports should be \
                            under the \"Report Impression\" column')
        parser.add_argument('-o', '--output_path', type=str, nargs='?', required=True,
                            help='path to intended output file')
        args = parser.parse_args()

        # Load the pretrained tokenizer only after argument parsing so bad
        # CLI usage fails fast.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        encoded_reports = tokenize(get_impressions_from_csv(args.data), tokenizer)
        with open(args.output_path, 'w') as filehandle:
                json.dump(encoded_reports, filehandle)