Diff of /chexbert/src/run_bert.py [000000] .. [4abb48]

import os
import argparse
import time
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import utils
from models.bert_labeler import bert_labeler
from datasets.impressions_dataset import ImpressionsDataset
from constants import *

def collate_fn_labels(sample_list):
     """Custom collate function to pad the reports in each batch to the max length.
     @param sample_list (List): A list of samples. Each sample is a dictionary with
                                keys 'imp', 'label', 'len', as returned by the __getitem__
                                function of ImpressionsDataset

     @returns batch (dictionary): A dictionary with keys 'imp', 'label', 'len', where
                                  'imp' is now a padded tensor with batch size as the
                                  first dimension, 'label' is a stacked tensor of labels
                                  for the whole batch with batch size as the first
                                  dimension, and 'len' is a list of the lengths of the
                                  sequences in the batch
     """
     tensor_list = [s['imp'] for s in sample_list]
     batched_imp = torch.nn.utils.rnn.pad_sequence(tensor_list,
                                                   batch_first=True,
                                                   padding_value=PAD_IDX)
     label_list = [s['label'] for s in sample_list]
     batched_label = torch.stack(label_list, dim=0)
     len_list = [s['len'] for s in sample_list]

     batch = {'imp': batched_imp, 'label': batched_label, 'len': len_list}
     return batch
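
# Illustrative sketch (sample_a and sample_b are assumed inputs, not part of this
# file): for two impressions of lengths 3 and 5, collate_fn_labels pads 'imp' to
# the longer length with PAD_IDX and stacks the 14-dimensional label vectors:
#   batch = collate_fn_labels([sample_a, sample_b])
#   batch['imp'].shape    -> torch.Size([2, 5])
#   batch['label'].shape  -> torch.Size([2, 14])
#   batch['len']          -> [3, 5]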

def load_data(train_csv_path, train_list_path, dev_csv_path,
              dev_list_path, train_weights=None, batch_size=BATCH_SIZE,
              shuffle=True, num_workers=NUM_WORKERS):
     """ Create ImpressionsDataset objects and DataLoaders for the train and dev data
     @param train_csv_path (string): path to training csv file containing labels
     @param train_list_path (string): path to list of encoded impressions for train set
     @param dev_csv_path (string): same as train_csv_path but for dev set
     @param dev_list_path (string): same as train_list_path but for dev set
     @param train_weights (torch.Tensor): Tensor of shape (train_set_size) containing weights
                                          for each training example, for the purposes of batch
                                          sampling with replacement
     @param batch_size (int): the batch size. As per the BERT repository, the max batch size
                              that can fit on a TITAN XP is 6 when the max sequence length
                              is 512, which is our case. We have 3 TITAN XPs
     @param shuffle (bool): Whether to shuffle data before each epoch; ignored if train_weights
                            is not None
     @param num_workers (int): How many worker processes to use to load data

     @returns dataloaders (tuple): tuple of two DataLoader objects, for the train and dev sets
     """
     collate_fn = collate_fn_labels
     train_dset = ImpressionsDataset(train_csv_path, train_list_path)
     dev_dset = ImpressionsDataset(dev_csv_path, dev_list_path)

     if train_weights is None:
          train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=shuffle,
                                                     num_workers=num_workers, collate_fn=collate_fn)
     else:
          sampler = torch.utils.data.WeightedRandomSampler(weights=train_weights,
                                                           num_samples=len(train_weights),
                                                           replacement=True)
          train_loader = torch.utils.data.DataLoader(train_dset,
                                                     batch_size=batch_size,
                                                     num_workers=num_workers,
                                                     collate_fn=collate_fn,
                                                     sampler=sampler)
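          # 'shuffle' is deliberately omitted here: DataLoader raises an error if
          # both a sampler and shuffle=True are passed, and the weighted sampler
          # already randomizes the order of examples.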

     dev_loader = torch.utils.data.DataLoader(dev_dset, batch_size=batch_size, shuffle=shuffle,
                                              num_workers=num_workers, collate_fn=collate_fn)
     dataloaders = (train_loader, dev_loader)
     return dataloaders

def load_test_data(test_csv_path, test_list_path, batch_size=BATCH_SIZE,
                   num_workers=NUM_WORKERS, shuffle=False):
     """ Create an ImpressionsDataset object and DataLoader for the test set
     @param test_csv_path (string): path to test csv file containing labels
     @param test_list_path (string): path to list of encoded impressions
     @param batch_size (int): the batch size. As per the BERT repository, the max batch size
                              that can fit on a TITAN XP is 6 when the max sequence length
                              is 512, which is our case. We have 3 TITAN XPs
     @param num_workers (int): how many worker processes to use to load data
     @param shuffle (bool): whether to shuffle the data or not

     @returns test_loader (dataloader): DataLoader object for the test set
     """
     collate_fn = collate_fn_labels
     test_dset = ImpressionsDataset(test_csv_path, test_list_path)
     test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, shuffle=shuffle,
                                               num_workers=num_workers, collate_fn=collate_fn)
     return test_loader

def train(save_path, dataloaders, f1_weights, model=None, device=None,
          optimizer=None, lr=LEARNING_RATE, log_every=LOG_EVERY,
          valid_niter=VALID_NITER, best_metric=0.0):
     """ Main training loop for the labeler
     @param save_path (string): Directory in which model weights are stored
     @param dataloaders (tuple): tuple of dataloader objects as returned by load_data
     @param f1_weights (dictionary): maps conditions to weights for blank, negation,
                                     uncertain and positive f1 task averaging
     @param model (nn.Module): the labeler model to train, if applicable
     @param device (torch.device): device for the model. If model is not None, this
                                   parameter is required
     @param optimizer (torch.optim.Optimizer): the optimizer to use, if applicable
     @param lr (float): learning rate to use in the optimizer, ignored if optimizer
                        is not None
     @param log_every (int): number of iterations between loss log messages
     @param valid_niter (int): number of iterations after which to evaluate the model and
                               save it if it is better than the old best model
     @param best_metric (float): save checkpoints only if dev set performance is higher
                                 than best_metric
     """
     if model and not device:
          print("train function error: model specified but no device given")
          return

     if model is None:
          model = bert_labeler(pretrain_path=PRETRAIN_PATH)
          model.train()   #put the model into train mode
          device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
          if torch.cuda.device_count() > 1:
               print("Using", torch.cuda.device_count(), "GPUs!")
               model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) #to utilize multiple GPUs
          model = model.to(device)
     else:
          model.train()

     if optimizer is None:
          optimizer = torch.optim.Adam(model.parameters(), lr=lr)

     begin_time = time.time()
     report_examples = 0
     report_loss = 0.0
     train_ld = dataloaders[0]
     dev_ld = dataloaders[1]
     loss_func = nn.CrossEntropyLoss(reduction='sum')

     print('begin labeler training')
     for epoch in range(NUM_EPOCHS):
          for i, data in enumerate(train_ld, 0):
               batch = data['imp'] #(batch_size, max_len)
               batch = batch.to(device)
               label = data['label'] #(batch_size, 14)
               label = label.permute(1, 0).to(device)
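               # after the permute, label has shape (14, batch_size), so label[j]
               # lines up with out[j], the logits of the j-th condition head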
               src_len = data['len']
               batch_size = batch.shape[0]
               attn_mask = utils.generate_attention_masks(batch, src_len, device)

               optimizer.zero_grad()
               out = model(batch, attn_mask) #list of 14 tensors

               batch_loss = 0.0
               for j in range(len(out)):
                    batch_loss += loss_func(out[j], label[j])
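               # batch_loss is the cross-entropy summed over all 14 heads and all
               # examples in the batch; it is normalized by batch_size before backprop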

               report_loss += batch_loss.item() #accumulate the scalar only, so the autograd graph is not retained
               report_examples += batch_size
               loss = batch_loss / batch_size
               loss.backward()
               optimizer.step()

               if (i+1) % log_every == 0:
                    print('epoch %d, iter %d, avg_loss %.3f, time_elapsed %.3f sec' % (epoch+1, i+1, report_loss/report_examples,
                                                                                       time.time() - begin_time))
                    report_loss = 0.0
                    report_examples = 0

               if (i+1) % valid_niter == 0:
                    print('\n begin validation')
                    metrics = utils.evaluate(model, dev_ld, device, f1_weights)
                    weighted = metrics['weighted']
                    kappas = metrics['kappa']

                    for j in range(len(CONDITIONS)):
                         print('%s kappa: %.3f' % (CONDITIONS[j], kappas[j]))
                    print('average: %.3f' % (np.mean(kappas)))

                    #for j in range(len(CONDITIONS)):
                    #     print('%s weighted_f1: %.3f' % (CONDITIONS[j], weighted[j]))
                    #print('average of weighted_f1: %.3f' % (np.mean(weighted)))

                    for j in range(len(CONDITIONS)):
                         print('%s blank_f1: %.3f, negation_f1: %.3f, uncertain_f1: %.3f, positive_f1: %.3f' % (CONDITIONS[j],
                                                                                                                metrics['blank'][j],
                                                                                                                metrics['negation'][j],
                                                                                                                metrics['uncertain'][j],
                                                                                                                metrics['positive'][j]))

                    metric_avg = np.mean(kappas)
                    if metric_avg > best_metric: #new best network
                         print("saving new best network!\n")
                         best_metric = metric_avg
                         path = os.path.join(save_path, "model_epoch%d_iter%d" % (epoch+1, i+1))
                         torch.save({'epoch': epoch+1,
                                     'model_state_dict': model.state_dict(),
                                     'optimizer_state_dict': optimizer.state_dict()},
                                    path)

def model_from_ckpt(model, ckpt_path):
     """Load up a model checkpoint
     @param model (nn.Module): the module to be loaded
     @param ckpt_path (string): path to a checkpoint. If this is None, then the
                                model is trained from scratch

     @return (tuple): tuple containing the model, optimizer and device
     """
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     if torch.cuda.device_count() > 1:
          print("Using", torch.cuda.device_count(), "GPUs!")
          model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) #to utilize multiple GPUs
     model = model.to(device)
     optimizer = torch.optim.Adam(model.parameters())

     checkpoint = torch.load(ckpt_path)
     model.load_state_dict(checkpoint['model_state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

     return (model, optimizer, device)
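
# Note: a checkpoint saved while the model was wrapped in nn.DataParallel stores
# 'module.'-prefixed keys in its state_dict, so the model must be wrapped the
# same way here before load_state_dict is called.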

if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Train BERT-base model on the task of labeling 14 medical conditions.')
     parser.add_argument('--train_csv', type=str, nargs='?', required=True,
                         help='path to csv containing train reports.')
     parser.add_argument('--dev_csv', type=str, nargs='?', required=True,
                         help='path to csv containing dev reports.')
     parser.add_argument('--train_imp_list', type=str, nargs='?', required=True,
                         help='path to list of tokenized train set report impressions')
     parser.add_argument('--dev_imp_list', type=str, nargs='?', required=True,
                         help='path to list of tokenized dev set report impressions')
     parser.add_argument('--output_dir', type=str, nargs='?', required=True,
                         help='path to output directory where checkpoints will be saved')
     parser.add_argument('--checkpoint', type=str, nargs='?', required=False,
                         help='path to existing checkpoint to initialize weights from')
     args = parser.parse_args()
     train_csv_path = args.train_csv
     dev_csv_path = args.dev_csv
     train_imp_path = args.train_imp_list
     dev_imp_path = args.dev_imp_list
     out_path = args.output_dir
     checkpoint_path = args.checkpoint

     if checkpoint_path:
          model, optimizer, device = model_from_ckpt(bert_labeler(), checkpoint_path)
     else:
          model, optimizer, device = None, None, None
     f1_weights = utils.get_weighted_f1_weights(dev_csv_path)
     dataloaders = load_data(train_csv_path, train_imp_path, dev_csv_path, dev_imp_path)
     train(save_path=out_path,
           dataloaders=dataloaders,
           model=model,
           optimizer=optimizer,
           device=device,
           f1_weights=f1_weights)
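
# Example invocation (file names below are placeholders):
#   python run_bert.py --train_csv train.csv --dev_csv dev.csv \
#          --train_imp_list train_imp_list --dev_imp_list dev_imp_list \
#          --output_dir checkpoints/ [--checkpoint path/to/checkpoint]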