[4abb48]: chexbert/src/label.py
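"""Label report impressions with a trained CheXbert model.

Loads a saved model checkpoint, runs it over the reports in the input csv,
and writes a csv with a label for each of the 14 conditions per report
(see save_preds for the output label encoding).
"""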

import os
import argparse
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import utils
from models.bert_labeler import bert_labeler
from bert_tokenizer import tokenize
from transformers import BertTokenizer
from collections import OrderedDict
from datasets.unlabeled_dataset import UnlabeledDataset
from constants import *
from tqdm import tqdm


def collate_fn_no_labels(sample_list):
    """Custom collate function to pad reports in each batch to the max len,
    where the reports have no associated labels

    @param sample_list (List): A list of samples. Each sample is a dictionary with
                               keys 'imp' and 'len', as returned by the __getitem__
                               function of UnlabeledDataset
    @returns batch (dictionary): A dictionary with keys 'imp' and 'len', but now
                                 'imp' is a tensor with padding and batch size as the
                                 first dimension, and 'len' is a list of the length of
                                 each sequence in the batch
    """
    tensor_list = [s['imp'] for s in sample_list]
    batched_imp = torch.nn.utils.rnn.pad_sequence(tensor_list,
                                                  batch_first=True,
                                                  padding_value=PAD_IDX)
    len_list = [s['len'] for s in sample_list]
    batch = {'imp': batched_imp, 'len': len_list}
    return batch


def load_unlabeled_data(csv_path, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                        shuffle=False):
    """Create a dataloader over an UnlabeledDataset for the input reports

    @param csv_path (string): path to csv file containing reports
    @param batch_size (int): the batch size. As per the BERT repository, the max batch size
                             that can fit on a TITAN XP is 6 if the max sequence length
                             is 512, which is the case here. We have 3 TITAN XPs
    @param num_workers (int): how many worker processes to use to load data
    @param shuffle (bool): whether to shuffle the data or not

    @returns loader (dataloader): dataloader object for the reports
    """
    collate_fn = collate_fn_no_labels
    dset = UnlabeledDataset(csv_path)
    loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle,
                                         num_workers=num_workers, collate_fn=collate_fn)
    return loader


def label(checkpoint_path, csv_path):
    """Labels a dataset of reports

    @param checkpoint_path (string): location of saved model checkpoint
    @param csv_path (string): location of csv with reports

    @returns y_pred (List[List[int]]): Labels for each of the 14 conditions, per report
    """
    ld = load_unlabeled_data(csv_path)

    model = bert_labeler()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 0:  # works even if only 1 GPU is available
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))  # to utilize multiple GPUs
        model = model.to(device)
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_state_dict'].items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

    was_training = model.training
    model.eval()
    y_pred = [[] for _ in range(len(CONDITIONS))]

    print("\nBegin report impression labeling. The progress bar counts the # of batches completed:")
    print("The batch size is %d" % BATCH_SIZE)
    with torch.no_grad():
        for i, data in enumerate(tqdm(ld)):
            batch = data['imp']  # (batch_size, max_len)
            batch = batch.to(device)
            src_len = data['len']
            batch_size = batch.shape[0]
            attn_mask = utils.generate_attention_masks(batch, src_len, device)
            out = model(batch, attn_mask)

            for j in range(len(out)):
                curr_y_pred = out[j].argmax(dim=1)  # shape is (batch_size)
                y_pred[j].append(curr_y_pred)

        for j in range(len(y_pred)):
            y_pred[j] = torch.cat(y_pred[j], dim=0)

    if was_training:
        model.train()

    y_pred = [t.tolist() for t in y_pred]
    return y_pred


def save_preds(y_pred, csv_path, out_path):
    """Save the labeled reports as a csv file at out_path

    @param y_pred (List[List[int]]): list of predictions for each report
    @param csv_path (string): path to csv containing reports
    @param out_path (string): path of the output csv file
    """
    y_pred = np.array(y_pred)
    y_pred = y_pred.T

    df = pd.DataFrame(y_pred, columns=CONDITIONS)
    findings = pd.read_csv(csv_path, header=None)[0]
    # dicom was used for labeling the training set, but is not available for labeling predictions
    # dicom_ids = pd.read_csv(csv_path)['dicom_id']
    # df['dicom_id'] = dicom_ids.tolist()
    df['findings'] = findings.tolist()
    new_cols = ['findings'] + CONDITIONS  # ['dicom_id']
    df = df[new_cols]

    df.replace(0, np.nan, inplace=True)  # blank class is NaN
    df.replace(3, -1, inplace=True)      # uncertain class is -1
    df.replace(2, 0, inplace=True)       # negative class is 0
    df.to_csv(out_path, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Label a csv file containing radiology reports')
    parser.add_argument('-d', '--data', type=str, nargs='?', required=True,
                        help='path to csv containing reports. The reports should be \
                              under the \"Report Impression\" column')
    parser.add_argument('-o', '--output_dir', type=str, nargs='?', required=True,
                        help='path for the output csv of labeled reports')
    parser.add_argument('-c', '--checkpoint', type=str, nargs='?', required=True,
                        help='path to the pytorch checkpoint')
    args = parser.parse_args()
    csv_path = args.data
    out_path = args.output_dir
    checkpoint_path = args.checkpoint

    y_pred = label(checkpoint_path, csv_path)
    save_preds(y_pred, csv_path, out_path)
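
# Example invocation using the flags defined above (file names are illustrative,
# not part of the repo):
#   python label.py -d my_reports.csv -o my_labeled_reports.csv -c chexbert_checkpoint.pth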