Diff of /scripts/train.py [000000] .. [d9566e]
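"""Train or evaluate the Sybil model (SybilNet) with PyTorch Lightning.

Dispatches to train() or test() based on the arguments returned by
sybil.parsing.parse_args.
"""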

from collections import OrderedDict
from argparse import Namespace
import pickle
import os
import sys

import pytorch_lightning as pl
import torch

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from sybil.utils.helpers import get_dataset
import sybil.utils.losses as losses
import sybil.utils.metrics as metrics
import sybil.utils.loading as loaders
import sybil.models.sybil as model
from sybil.parsing import parse_args


class SybilLightning(pl.LightningModule):
    """
    Lightning Module

    Methods:
        .log/.log_dict: log inputs to the logger

    Notes:
        *_epoch_end methods return None
        additional data structures can be logged to the logger with
        self.logger.experiment.log_*
        (* = 'text', 'image', 'audio', 'confusion_matrix', 'histogram')
    """

    def __init__(self, args):
        super(SybilLightning, self).__init__()
        if isinstance(args, dict):
            args = Namespace(**args)
        self.args = args
        self.model = model.SybilNet(args)
        self.save_prefix = "default"
        self.save_hyperparameters(args)
        self._list_of_metrics = [
            metrics.get_classification_metrics,
            metrics.get_survival_metrics,
            metrics.get_risk_metrics,
        ]

    def set_finetune(self, finetune_flag):
        return

    def forward(self, x):
        return self.model(x)

    def step(self, batch, batch_idx, optimizer_idx, log_key_prefix=""):
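        """Shared forward/loss computation used by the train, val, and test steps.

        Returns (loss, logging_dict, predictions_dict, model_output); dict keys
        are prefixed with ``log_key_prefix`` (e.g. "train_").
        """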
        model_output = self(batch["x"])
        logging_dict, predictions_dict = OrderedDict(), OrderedDict()

        if "exam" in batch:
            predictions_dict["exam"] = batch["exam"]
        if "y" in batch:
            predictions_dict["golds"] = batch["y"]

        if self.args.save_attention_scores:
            attentions = {k: v for k, v in model_output.items() if "attention" in k}
            predictions_dict.update(attentions)

        loss_fns = self.get_loss_functions(self.args)
        loss = 0
        for loss_fn in loss_fns:
            local_loss, local_log_dict, local_predictions_dict = loss_fn(
                model_output, batch, self, self.args
            )
            loss += local_loss
            logging_dict.update(local_log_dict)
            predictions_dict.update(local_predictions_dict)
        logging_dict = prefix_dict(logging_dict, log_key_prefix)
        predictions_dict = prefix_dict(predictions_dict, log_key_prefix)
        return loss, logging_dict, predictions_dict, model_output

    def training_step(self, batch, batch_idx, optimizer_idx=None):
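        """Run one training step and return a dict whose "loss" entry drives backprop."""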
        result = OrderedDict()
        loss, logging_dict, predictions_dict, _ = self.step(
            batch, batch_idx, optimizer_idx, log_key_prefix="train_"
        )
        logging_dict["train_loss"] = loss.detach()
        self.log_dict(logging_dict, prog_bar=False, on_step=True, on_epoch=True)
        result["logs"] = logging_dict
        self.log_tensor_dict(predictions_dict, prog_bar=False, logger=False)
        result.update(predictions_dict)
        # Lightning expects a 'loss' key in the output dict, otherwise loss is None
        result["loss"] = loss
        return result

    def validation_step(self, batch, batch_idx, optimizer_idx=None):
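        """Run one validation step; under DDP, predictions are gathered across ranks."""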
        result = OrderedDict()
        loss, logging_dict, predictions_dict, _ = self.step(
            batch, batch_idx, optimizer_idx, log_key_prefix="val_"
        )
        logging_dict["val_loss"] = loss.detach()
        self.log_dict(logging_dict, prog_bar=True, sync_dist=True)
        result["logs"] = logging_dict
        if self.args.accelerator == "ddp":
            predictions_dict = gather_predictions_dict(predictions_dict)
        self.log_tensor_dict(predictions_dict, prog_bar=False, logger=False)
        result.update(predictions_dict)
        return result

    def test_step(self, batch, batch_idx, optimizer_idx=None):
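        """Run one test step; the total loss is logged under ``self.save_prefix``."""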
        result = OrderedDict()
        loss, logging_dict, predictions_dict, model_output = self.step(
            batch, batch_idx, optimizer_idx, log_key_prefix="test_"
        )
        logging_dict["{}_loss".format(self.save_prefix)] = loss.detach()
        result["logs"] = logging_dict

        if self.args.accelerator == "ddp":
            predictions_dict = gather_predictions_dict(predictions_dict)

        self.log_tensor_dict(predictions_dict, prog_bar=False, logger=False)
        result.update(predictions_dict)
        return result

    def training_epoch_end(self, outputs):
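        """Aggregate step outputs and log epoch-level training metrics."""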
        if len(outputs) == 0:
            return
        outputs = gather_step_outputs(outputs)
        # loss is already logged via the progress bar (get_progress_bar_dict());
        # logging it twice causes issues
        del outputs["loss"]
        epoch_metrics = compute_epoch_metrics(
            self._list_of_metrics, outputs, self.args, self.device, key_prefix="train_"
        )
        for k, v in outputs["logs"].items():
            epoch_metrics[k] = v.mean()
        self.log_dict(epoch_metrics, prog_bar=True, logger=True)

    def validation_epoch_end(self, outputs):
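        """Aggregate step outputs and log epoch-level validation metrics."""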
        if len(outputs) == 0:
            return
        outputs = gather_step_outputs(outputs)
        epoch_metrics = compute_epoch_metrics(
            self._list_of_metrics, outputs, self.args, self.device, key_prefix="val_"
        )
        for k, v in outputs["logs"].items():
            epoch_metrics[k] = v.mean()
        self.log_dict(epoch_metrics, prog_bar=True, logger=True)

    def test_epoch_end(self, outputs):
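        """Aggregate test outputs, log metrics, and dump metrics (and, optionally,
        predictions) to disk.
        """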
        self.save_prefix = "test"
        if len(outputs) == 0:
            return
        outputs = gather_step_outputs(outputs)
        epoch_metrics = compute_epoch_metrics(
            self._list_of_metrics, outputs, self.args, self.device, key_prefix="test_"
        )

        for k, v in outputs["logs"].items():
            epoch_metrics[k] = v.mean()

        self.log_dict(epoch_metrics, prog_bar=True, logger=True)

        # Dump metrics for use by dispatcher
        metrics_dict = {
            k[len(self.save_prefix) :]: v.mean().item()
            for k, v in outputs.items()
            if "loss" in k
        }
        metrics_dict.update(
            {
                k[len(self.save_prefix) :]: v.mean().item()
                for k, v in epoch_metrics.items()
            }
        )
        metrics_filename = "{}.{}.metrics".format(
            self.args.results_path, self.save_prefix
        )
        with open(metrics_filename, "wb") as f:
            pickle.dump(metrics_dict, f)
        if self.args.save_predictions and self.global_rank == 0:
            predictions_dict = {
                k: v.cpu() if isinstance(v, torch.Tensor) else v
                for k, v in outputs.items()
            }
            predictions_filename = "{}.{}.predictions".format(
                self.args.results_path, self.save_prefix
            )
            with open(predictions_filename, "wb") as f:
                pickle.dump(predictions_dict, f)

    def configure_optimizers(self):
        """
        Build the optimizer and ReduceLROnPlateau scheduler specified by args.
        """
        params = [param for param in self.model.parameters() if param.requires_grad]
        if self.args.optimizer == "adam":
            optimizer = torch.optim.Adam(
                params, lr=self.args.lr, weight_decay=self.args.weight_decay
            )
        elif self.args.optimizer == "adagrad":
            optimizer = torch.optim.Adagrad(
                params, lr=self.args.lr, weight_decay=self.args.weight_decay
            )
        elif self.args.optimizer == "sgd":
            optimizer = torch.optim.SGD(
                params,
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
                momentum=self.args.momentum,
            )
        else:
            raise Exception("Optimizer {} not supported!".format(self.args.optimizer))

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                patience=self.args.patience,
                factor=self.args.lr_decay,
                mode="min" if "loss" in self.args.tuning_metric else "max",
            ),
            "monitor": "val_{}".format(self.args.tuning_metric),
            "interval": "epoch",
            "frequency": 1,
        }
        return [optimizer], [scheduler]

    def log_tensor_dict(
        self,
        output,
        prog_bar=False,
        logger=True,
        on_step=None,
        on_epoch=None,
        sync_dist=False,
    ):
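        """Log only the tensor-valued entries of ``output``, cast to float."""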
        dict_of_tensors = {
            k: v.float() for k, v in output.items() if isinstance(v, torch.Tensor)
        }
        self.log_dict(
            dict_of_tensors,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            sync_dist=sync_dist,
        )

    def get_loss_functions(self, args):
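        """Survival loss always; annotation loss when args.use_annotations is set."""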
        loss_fns = [losses.get_survival_loss]
        if args.use_annotations:
            loss_fns.append(losses.get_annotation_loss)

        return loss_fns


def prefix_dict(d, prefix):
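    """Return a copy of ``d`` with ``prefix`` prepended to every key."""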
    r = OrderedDict()
    for k, v in d.items():
        r[prefix + k] = v
    return r


def gather_predictions_dict(predictions):
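    """All-gather tensor values across ranks; non-tensors pass through unchanged."""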
    gathered_preds = {
        k: concat_all_gather(v) if isinstance(v, torch.Tensor) else v
        for k, v in predictions.items()
    }
    return gathered_preds


def gather_step_outputs(outputs):
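    """Collate a list of per-step output dicts into one dict, stacking scalar
    tensors and concatenating batched tensors along dim 0.
    """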
    output_dict = OrderedDict()
    if isinstance(outputs[-1], list):
        outputs = outputs[0]

    for k in outputs[-1].keys():
        if k == "logs":
            output_dict[k] = gather_step_outputs([output["logs"] for output in outputs])
        elif (
            isinstance(outputs[-1][k], torch.Tensor) and len(outputs[-1][k].shape) == 0
        ):
            output_dict[k] = torch.stack([output[k] for output in outputs])
        elif isinstance(outputs[-1][k], torch.Tensor):
            output_dict[k] = torch.cat([output[k] for output in outputs], dim=0)
        else:
            output_dict[k] = [output[k] for output in outputs]
    return output_dict


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output


def compute_epoch_metrics(list_of_metrics, result_dict, args, device, key_prefix=""):
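    """Apply each metric function to the epoch's results and return a dict of
    metric tensors on ``device``, keyed with ``key_prefix``.
    """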
    stats_dict = OrderedDict()

    # Remove the prefix from keys, e.g. convert val_probs -> probs, for standard
    # handling in the metric functions.
    result_dict_wo_key_prefix = {}

    for k, v in result_dict.items():
        if isinstance(v, list) and isinstance(v[-1], torch.Tensor):
            v = torch.cat(v, dim=-1)
        if isinstance(v, torch.Tensor):
            v = v.cpu().numpy()
        if k == "meta":
            continue
        if key_prefix != "" and k.startswith(key_prefix):
            k_wo_prefix = k[len(key_prefix) :]
            result_dict_wo_key_prefix[k_wo_prefix] = v
        else:
            result_dict_wo_key_prefix[k] = v

    for k, v in result_dict["logs"].items():
        if k.startswith(key_prefix):
            result_dict_wo_key_prefix[k[len(key_prefix) :]] = v

    for metric_func in list_of_metrics:
        stats_wo_prefix = metric_func(result_dict_wo_key_prefix, args)
        for k, v in stats_wo_prefix.items():
            stats_dict[key_prefix + k] = torch.tensor(v, device=device)

    return stats_dict


def train(args):
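    """Train SybilLightning on the train/dev splits and pickle the final args
    (including the best checkpoint path) to args.results_path.
    """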
    if not args.turn_off_checkpointing:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=args.save_dir,
            save_top_k=1,
            verbose=True,
            monitor="val_{}".format(args.tuning_metric)
            if args.tuning_metric is not None
            else None,
            save_last=True,
            mode="min"
            if args.tuning_metric is not None and "loss" in args.tuning_metric
            else "max",
        )
        args.callbacks = [checkpoint_callback]
    trainer = pl.Trainer.from_argparse_args(args)
    # Remove callbacks from args for safe pickling later
    args.callbacks = None
    args.num_nodes = trainer.num_nodes
    args.num_processes = trainer.num_processes
    args.world_size = args.num_nodes * args.num_processes
    args.global_rank = trainer.global_rank
    args.local_rank = trainer.local_rank

    train_dataset = loaders.get_train_dataset_loader(
        args, get_dataset(args.dataset, "train", args)
    )
    dev_dataset = loaders.get_eval_dataset_loader(
        args, get_dataset(args.dataset, "dev", args), False
    )

    args.censoring_distribution = metrics.get_censoring_dist(train_dataset.dataset)
    module = SybilLightning(args)

    # print args
    for key, value in sorted(vars(args).items()):
        print("{} -- {}".format(key.upper(), value))

    if args.snapshot is not None:
        module = module.load_from_checkpoint(
            checkpoint_path=args.snapshot, strict=False
        )
        module.args = args

    trainer.fit(module, train_dataset, dev_dataset)
    args.model_path = trainer.checkpoint_callback.best_model_path
    print("Saving args to {}".format(args.results_path))
    with open(args.results_path, "wb") as f:
        pickle.dump(vars(args), f)


def test(args):
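    """Evaluate a snapshot (args.snapshot) on the test split and pickle the args
    to args.results_path.
    """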
    trainer = pl.Trainer.from_argparse_args(args)
    # Remove callbacks from args for safe pickling later
    args.callbacks = None
    args.num_nodes = trainer.num_nodes
    args.num_processes = trainer.num_processes
    args.world_size = args.num_nodes * args.num_processes
    args.global_rank = trainer.global_rank
    args.local_rank = trainer.local_rank

    train_dataset = loaders.get_train_dataset_loader(
        args, get_dataset(args.dataset, "train", args)
    )
    test_dataset = loaders.get_eval_dataset_loader(
        args, get_dataset(args.dataset, "test", args), False
    )

    args.censoring_distribution = metrics.get_censoring_dist(train_dataset.dataset)
    module = SybilLightning(args)
    module = module.load_from_checkpoint(
        checkpoint_path=args.snapshot, strict=False
    )
    module.args = args

    # print args
    for key, value in sorted(vars(args).items()):
        print("{} -- {}".format(key.upper(), value))

    trainer.test(module, test_dataset)

    print("Saving args to {}".format(args.results_path))
    with open(args.results_path, "wb") as f:
        pickle.dump(vars(args), f)


if __name__ == "__main__":
    args = parse_args()
    if args.train:
        train(args)
    elif args.test:
        test(args)