Diff of /train.py [000000] .. [352cae]

import os
import random
import subprocess
import argparse
import time
import numpy as np
import pandas as pd
from sksurv.metrics import concordance_index_censored

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset

from model import HECTOR
from im4MEC import Im4MEC
from utils import *
from utils_loss import NLLSurvLoss

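# NOTE: `from utils import *` is expected to provide the helpers used throughout this script, i.e.
# FeatureBagsDataset, define_data_sampling, get_bins_time_value, get_survival_data_for_BS,
# compute_surv_metrics_eval, MonitorBestModelEarlyStopping, print_model and get_lr.
# See utils.py for their implementations.
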
def set_seed():
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

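# NOTE: seed_worker is the standard helper meant to be passed as a DataLoader's worker_init_fn so
# that NumPy/random state is reseeded per worker in a reproducible way. It is not referenced
# elsewhere in this file; a minimal wiring sketch (batch size and shuffling are assumptions, the
# actual loaders are built by define_data_sampling in utils.py) would look like:
#
#   g = torch.Generator()
#   g.manual_seed(0)
#   loader = torch.utils.data.DataLoader(
#       train_split, batch_size=1, shuffle=True,
#       num_workers=args.workers, worker_init_fn=seed_worker, generator=g,
#   )
#
# See https://pytorch.org/docs/stable/notes/randomness.html for the rationale.
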
def evaluate_model(epoch, model, model_mol, device, loader, n_bins, writer, loss_fn, bins_values, train_BS, test_BS):
    model.eval()

    eval_loss = 0.

    all_survival_probs = np.zeros((len(loader), n_bins))
    all_risk_scores = np.zeros((len(loader))) # This is the computed risk score.
    all_censorships = np.zeros((len(loader))) # This is the binary censorship status: 1 censored; 0 uncensored (recurred).
    all_event_times = np.zeros((len(loader)))

    with torch.no_grad():
        for batch_idx, (data, features_flattened, label, event_time, censorship, stage, _) in enumerate(loader):
            data, label, censorship, stage = data.to(device), label.to(device), censorship.to(device), stage.to(device)
            _, _, Y_hat, _, _ = model_mol(features_flattened.to(device))

            hazards_prob, survival_prob, Y_hat, _, _ = model(data, stage, Y_hat.squeeze(1)) # Returns hazards, survival, Y_hat, A_raw, M.

            # The contribution of uncensored patient cases can be emphasized during training by
            # minimizing a weighted sum of the two loss terms; for evaluation, alpha is set to 0.
            loss = loss_fn(hazards=hazards_prob, S=survival_prob, Y=label, c=censorship, alpha=0)
            eval_loss += loss.item()

            risk = -torch.sum(survival_prob, dim=1).cpu().numpy()
            all_risk_scores[batch_idx] = risk
            all_censorships[batch_idx] = censorship.cpu().numpy()
            all_event_times[batch_idx] = event_time
            all_survival_probs[batch_idx] = survival_prob.cpu().numpy()

    eval_loss /= len(loader)

    # Compute a few survival metrics.
    c_index = concordance_index_censored(
        event_indicator=(1-all_censorships).astype(bool),
        event_time=all_event_times,
        estimate=all_risk_scores, tied_tol=1e-08)[0]

    # Years of interest can be adapted in utils.py.
    (BS, years_of_interest), (IBS, yearI_of_interest, yearF_of_interest), (_, meanAUC), (c_index_ipcw) = compute_surv_metrics_eval(bins_values, all_survival_probs, all_risk_scores, train_BS, test_BS)

    print(f'Eval epoch: {epoch}, loss: {eval_loss}, c_index: {c_index}, BS at each {years_of_interest}Y: {BS}, IBS and mean cumAUC from {yearI_of_interest}Y to {yearF_of_interest}Y: {IBS} and {meanAUC}')

    writer.add_scalar("Loss/eval", eval_loss, epoch)
    writer.add_scalar("C_index/eval", c_index, epoch)
    for i in range(len(years_of_interest)):
        writer.add_scalar(f"eval_metrics/BS_{str(years_of_interest[i])}Y", BS[i], epoch)
    writer.add_scalar(f"eval_metrics/IBS_{str(yearI_of_interest)}Y-{str(yearF_of_interest)}Y", IBS, epoch)
    writer.add_scalar(f"eval_metrics/meanAUC_{str(yearI_of_interest)}Y-{str(yearF_of_interest)}Y", meanAUC, epoch)

    return eval_loss, c_index, (BS, IBS, meanAUC, c_index_ipcw)

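# NOTE: the loss_fn used above and in train_one_epoch is NLLSurvLoss from utils_loss.py. Assuming it
# follows the usual discrete-time survival formulation (Zadeh & Schmid, 2020), with hazards
# h_k = P(T = k | T >= k) and survival S_k = prod_{j <= k} (1 - h_j) over the n_bins intervals, the
# per-case negative log-likelihood is roughly:
#
#   uncensored (c = 0):  -[ log S_{Y-1} + log h_Y ]   # event observed in bin Y
#   censored   (c = 1):  -[ log S_Y ]                 # event-free up to last follow-up in bin Y
#   total:               (1 - alpha) * (censored + uncensored) + alpha * uncensored
#
# so a larger alpha down-weights censored cases relative to uncensored ones, which is what the
# alpha_surv hyperparameter controls. Check utils_loss.py for the exact implementation.
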
def train_one_epoch(epoch, model, model_mol, device, train_loader, optimizer, n_bins, writer, loss_fn):

    model.train()
    epoch_start_time = time.time()
    train_loss = 0.

    all_risk_scores = np.zeros((len(train_loader))) # Computed risk score.
    all_censorships = np.zeros((len(train_loader))) # Binary censorship status: 1 censored; 0 uncensored.
    all_event_times = np.zeros((len(train_loader))) # Real event time t or time of last follow-up.

    batch_start_time = time.time()

    for batch_idx, (data, features_flattened, label, event_time, censorship, stage, _) in enumerate(train_loader):

        data_load_duration = time.time() - batch_start_time

        data, label, censorship, stage = data.to(device), label.to(device), censorship.to(device), stage.to(device)
        # To get the image-based molecular class, non-merged features are used because the molecular
        # model was trained that way. Merged features could be used as an alternative.
        _, _, Y_hat, _, _ = model_mol(features_flattened.to(device))

        # Returns hazards, survival, Y_hat, A_raw, M.
        hazards_prob, survival_prob, Y_hat, _, _ = model(data, stage, Y_hat.squeeze(1))

        # Loss.
        loss = loss_fn(hazards=hazards_prob, S=survival_prob, Y=label, c=censorship)
        train_loss += loss.item()

        # Store outputs.
        risk = -torch.sum(survival_prob, dim=1).detach().cpu().numpy()
        all_risk_scores[batch_idx] = risk
        all_censorships[batch_idx] = censorship.item()
        all_event_times[batch_idx] = event_time

        # Backward pass.
        loss.backward()

        # Step.
        optimizer.step()
        optimizer.zero_grad()

        batch_duration = time.time() - batch_start_time
        batch_start_time = time.time()

        writer.add_scalar("duration/data_load", data_load_duration, epoch)
        writer.add_scalar("duration/batch", batch_duration, epoch)

    epoch_duration = time.time() - epoch_start_time
    print(f"Finished training on epoch {epoch} in {epoch_duration:.2f}s")

    train_loss /= len(train_loader)

    train_c_index = concordance_index_censored(
        event_indicator=(1-all_censorships).astype(bool),
        event_time=all_event_times,
        estimate=all_risk_scores, tied_tol=1e-08)[0]

    print(f'Epoch: {epoch}, epoch_duration: {epoch_duration}, train_loss: {train_loss}, train_c_index: {train_c_index}')

    filepath = os.path.join(writer.log_dir, f"{epoch}_checkpoint.pt")
    print(f"Saving model to {filepath}")
    torch.save(model.state_dict(), filepath)

    writer.add_scalar("duration/epoch", epoch_duration, epoch)
    writer.add_scalar("LR", get_lr(optimizer), epoch)
    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("C_index/train", train_c_index, epoch)

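# NOTE: the risk score logged above is the negative sum of the discrete survival probabilities over
# all time bins. A toy illustration with n_bins = 4 (values are made up):
#
#   patient A: survival_prob = [0.95, 0.90, 0.80, 0.70]  ->  risk = -3.35
#   patient B: survival_prob = [0.60, 0.40, 0.20, 0.10]  ->  risk = -1.30
#
# Patient B has lower predicted survival in every bin and therefore the higher (less negative) risk,
# which is the ordering that concordance_index_censored scores against the observed event times,
# using event_indicator = 1 - censorship.
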
def run_train_eval_loop(train_loader, val_loader, loss_fn, hparams, run_id, BS_data, checkpoint_model_molecular):
    writer = SummaryWriter(os.path.join("./runs", run_id))
    device = torch.device("cuda")
    n_bins = hparams["n_bins"]

    model = HECTOR(
        input_feature_size=hparams["input_feature_size"],
        precompression_layer=hparams["precompression_layer"],
        feature_size_comp=hparams["feature_size_comp"],
        feature_size_attn=hparams["feature_size_attn"],
        postcompression_layer=hparams["postcompression_layer"],
        feature_size_comp_post=hparams["feature_size_comp_post"],
        dropout=True,
        p_dropout_fc=hparams["p_dropout_fc"],
        p_dropout_atn=hparams["p_dropout_atn"],
        n_classes=n_bins,

        input_stage_size=hparams["input_stage_size"],
        embedding_dim_stage=hparams["embedding_dim_stage"],
        depth_dim_stage=hparams["depth_dim_stage"],
        act_fct_stage=hparams["act_fct_stage"],
        dropout_stage=hparams["dropout_stage"],
        p_dropout_stage=hparams["p_dropout_stage"],

        input_mol_size=4,
        embedding_dim_mol=hparams["embedding_dim_mol"],
        depth_dim_mol=hparams["depth_dim_mol"],
        act_fct_mol=hparams["act_fct_mol"],
        dropout_mol=hparams["dropout_mol"],
        p_dropout_mol=hparams["p_dropout_mol"],

        fusion_type=hparams["fusion_type"],
        use_bilinear=hparams["use_bilinear"],
        gate_hist=hparams["gate_hist"],
        gate_stage=hparams["gate_stage"],
        gate_mol=hparams["gate_mol"],
        scale=hparams["scale"],
    ).to(device)
    print('model')
    print_model(model)

    # This model is instantiated with the weights trained for molecular classification and is used
    # in inference mode only.
    # NOTE: it is important that the molecular model, here im4MEC, was trained on the same training
    # patients to avoid patient-level information leakage.
    model_mol = Im4MEC(
        input_feature_size=hparams["input_feature_size"],
        precompression_layer=True,
        feature_size_comp=hparams["feature_size_comp_molecular"],
        feature_size_attn=hparams["feature_size_attn_molecular"],
        n_classes=hparams["n_classes_molecular"],
        dropout=True, # Not used in inference.
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,
    ).to(device)

    msg = model_mol.load_state_dict(torch.load(checkpoint_model_molecular, map_location=device), strict=True)
    print(msg)

    for p in model_mol.parameters():
        p.requires_grad = False
    print("HECTOR and plugged-in im4MEC are built and checkpoints loaded")
    model_mol.eval()

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=hparams["initial_lr"],
        weight_decay=hparams["weight_decay"],
    )

    # Using a multi-step LR decay routine.
    milestones = [int(x) for x in hparams["milestones"].split(",")]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=hparams["gamma_lr"]
    )

    monitor_tracker = MonitorBestModelEarlyStopping(
        patience=hparams["earlystop_patience"],
        min_epochs=hparams["earlystop_min_epochs"],
        saving_checkpoint=True,
    )

    for epoch in range(hparams["max_epochs"]):

        train_one_epoch(epoch, model, model_mol, device, train_loader, optimizer, n_bins, writer, loss_fn)

        # Evaluation on the validation set.
        print("Evaluating model on validation set...")
        eval_loss, eval_cindex, eval_other_metrics = evaluate_model(epoch, model, model_mol, device, val_loader, n_bins, writer, loss_fn, hparams["bins_values"], *BS_data)
        monitor_tracker(epoch, eval_loss, eval_cindex, eval_other_metrics, model, writer.log_dir)

        # Update LR decay.
        scheduler.step()

        if monitor_tracker.early_stop:
            print(f"Early stop criterion reached. Broke off training loop after epoch {epoch}.")
            break

    # Log the hyperparameters and best metrics of the experiment as one semicolon-separated line per run.
    runs_history = {
        "run_id" : run_id,
        "best_epoch_CI" : monitor_tracker.best_epoch_CI,
        "best_CI_score" : monitor_tracker.best_CI_score,
        "best_epoch_loss": monitor_tracker.best_epoch_loss,
        "best_evalLoss" : monitor_tracker.eval_loss_min,
        "BS" : monitor_tracker.best_metrics_score[0],
        "IBS" : monitor_tracker.best_metrics_score[1],
        "cumMeanAUC" : monitor_tracker.best_metrics_score[2],
        "CI_ipcw" : monitor_tracker.best_metrics_score[3],
        **hparams,
    }
    with open('runs_history.txt', 'a') as filehandle:
        for _, value in runs_history.items():
            filehandle.write('%s;' % value)
        filehandle.write('\n')

    writer.close()

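# NOTE: prepare_datasets below only relies on the manifest CSV providing the columns 'disc_label'
# (discretized time-bin label), 'recurrence_years' (continuous time to recurrence or last follow-up),
# 'stage' and 'split' (with values 'training'/'validation'). Any further columns, e.g. slide
# identifiers and the censorship status consumed by FeatureBagsDataset, are handled in utils.py and
# their exact names are not visible from this file.
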
def prepare_datasets(args):

    df = pd.read_csv(args.manifest)

    n_bins = len(df['disc_label'].unique())
    assert n_bins == args.n_bins, 'mismatch between the number of bins passed in args and the classes in the dataset'
    bins_values = get_bins_time_value(df, n_bins, time_col_name='recurrence_years', label_time_col_name='disc_label')
    assert len(bins_values) == n_bins
    print(f'Read {args.manifest} dataset containing {len(df)} samples with {n_bins} bins of the following values {bins_values}')

    # NOTE: you may need to use the two lines below depending on how the stage category is listed in the csv file.
    #df.stage = df.stage.apply(lambda x : 'III' if 'III' in x else ('II' if 'II' in x else 'I')).astype("category")
    #df.stage = pd.Categorical(df['stage'], categories=['I', 'II', 'III'], ordered=True).codes
    print(f'stage taxonomy used: {df.stage.unique()}')

    try:
        training_set = df[df["split"] == "training"]
        validation_set = df[df["split"] == "validation"]
    except KeyError:
        raise Exception(
            f"Could not find training and validation splits in {args.manifest}"
        )

    train_split = FeatureBagsDataset(df=training_set,
                                    data_dir=args.data_dir,
                                    input_feature_size=args.input_feature_size,
                                    stage_class=len(training_set.stage.unique()))

    val_split = FeatureBagsDataset(df=validation_set,
                                    data_dir=args.data_dir,
                                    input_feature_size=args.input_feature_size,
                                    stage_class=len(validation_set.stage.unique()))

    # To compute the Brier score (BS), censorship statuses and event times are needed in a specific format.
    _, train_BS = get_survival_data_for_BS(training_set, time_col_name='recurrence_years')
    _, test_BS = get_survival_data_for_BS(validation_set, time_col_name='recurrence_years')

    return train_split, val_split, train_BS, test_BS, bins_values, len(df.stage.unique())


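# NOTE: get_survival_data_for_BS is assumed to return the (event, time) pairs in the structured-array
# format that scikit-survival expects for its Brier-score and IPCW metrics, i.e. something along the
# lines of the sketch below (the 'censorship' column name is illustrative; see utils.py):
#
#   from sksurv.util import Surv
#   train_BS = Surv.from_arrays(event=(1 - training_set["censorship"]).astype(bool),
#                               time=training_set["recurrence_years"])
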
def main(args):

    # Set the random seed for some degree of reproducibility. See the PyTorch docs on this topic for caveats.
    # https://pytorch.org/docs/stable/notes/randomness.html#reproducibility
    set_seed()

    if not torch.cuda.is_available():
        raise Exception(
            "No CUDA device available. Training without one is not feasible."
        )

    git_sha = subprocess.check_output(["git", "describe", "--always"]).strip().decode("utf-8")
    train_run_id = f"{git_sha}_hp{args.hp}_{time.strftime('%Y%m%d-%H%M')}"

    train_split, val_split, train_BS, test_BS, bins_values, stage_taxonomy = prepare_datasets(args)

    print(f"=> Run ID {train_run_id}")
    print(f"=> Training on {len(train_split)} samples")
    print(f"=> Validating on {len(val_split)} samples")

    base_hparams = dict(
        # Preprocessing settings. These should be adjusted to match the dataset being used.
        # Values are stored here for readability.
        n_bins=args.n_bins, # Partition of the continuous time scale.
        bins_values=bins_values,
        input_feature_size=args.input_feature_size,
        features_extraction=os.path.dirname(args.data_dir),

        # Settings that can be changed in the loop:
        # Training.
        sampling_method="random",
        max_epochs=100,
        earlystop_warmup=0,
        earlystop_patience=30,
        earlystop_min_epochs=30,

        # Loss.
        alpha_surv=0.0,

        # Optimizer.
        initial_lr=0.00003,
        milestones="2, 5, 15, 25",
        gamma_lr=0.1,
        weight_decay=0.00001,

        # Model architecture parameters. See the model class for details.
        precompression_layer=True,
        feature_size_comp=512,
        feature_size_attn=256,
        postcompression_layer=True,
        feature_size_comp_post=128,
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,

        # Molecular classification model. In our case it is used for inference only.
        n_classes_molecular=args.n_classes_molecular,
        feature_size_comp_molecular=args.feature_size_comp_molecular,
        feature_size_attn_molecular=args.feature_size_attn_molecular,

        # Fusion parameters.
        input_stage_size=stage_taxonomy,
        embedding_dim_stage=16,
        depth_dim_stage=1,
        act_fct_stage='elu',
        dropout_stage=True,
        p_dropout_stage=0.25,
        embedding_dim_mol=16,
        depth_dim_mol=1,
        act_fct_mol='elu',
        dropout_mol=True,
        p_dropout_mol=0.25,
        fusion_type='bilinear',
        use_bilinear=[True,True,True],
        gate_hist=True,
        gate_stage=True,
        gate_mol=True,
        scale=[2,1,1],
    )

    hparam_sets = [
        {
            **base_hparams,
        },
    ]

    hps = hparam_sets[args.hp]

    train_loader, val_loader = define_data_sampling(
            train_split,
            val_split,
            method=hps["sampling_method"],
            workers=args.workers,
    )

    run_train_eval_loop(
            train_loader=train_loader,
            val_loader=val_loader,
            loss_fn=NLLSurvLoss(alpha=hps["alpha_surv"]), # Negative log-likelihood survival loss.
            hparams=hps,
            run_id=train_run_id,
            BS_data=(train_BS, test_BS),
            checkpoint_model_molecular=args.checkpoint_model_molecular,
    )
    print("Finished training.")

def get_args_parser():

    parser = argparse.ArgumentParser('Training script', add_help=False)

    parser.add_argument(
        "--manifest",
        type=str,
        help="CSV file listing all slides, their labels, and which split (train/test/val) they belong to.",
    )
    parser.add_argument(
        "--n_bins",
        type=int,
        help="Number of time intervals used to create the time labels. It should match the manifest.",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        help="Directory where all *_features.h5 files are stored.",
    )
    parser.add_argument(
        "--input_feature_size",
        help="Size of the input features from the feature bags. Recommended values follow the block output sizes [96, 96, 192, 192, 384, 384, 384, 384, 768, 768].",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--checkpoint_model_molecular",
        type=str,
        default='',
        help="Path to the trained im4MEC checkpoint.",
    )
    parser.add_argument(
        "--n_classes_molecular",
        type=int,
        required=True,
        help="Number of molecular classes predicted by the trained im4MEC model.",
    )
    parser.add_argument(
        "--feature_size_comp_molecular",
        type=int,
        required=True,
        help="Compression feature size of the trained im4MEC model. See im4MEC.py.",
    )
    parser.add_argument(
        "--feature_size_attn_molecular",
        type=int,
        required=True,
        help="Attention feature size of the trained im4MEC model. See im4MEC.py.",
    )
    parser.add_argument(
        "--workers",
        help="Number of workers to use for the data loaders.",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--hp",
        type=int,
        required=True,
        help="Index of the hyperparameter set to use (see hparam_sets in main).",
    )

    return parser

if __name__ == "__main__":

    parser = argparse.ArgumentParser('Training script', parents=[get_args_parser()])
    args = parser.parse_args()

    main(args)
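
# Example invocation (all paths and values are illustrative, not repository defaults):
#
#   python train.py \
#       --manifest manifest.csv \
#       --data_dir features/ \
#       --n_bins 4 \
#       --input_feature_size 384 \
#       --checkpoint_model_molecular im4mec_checkpoint.pt \
#       --n_classes_molecular 4 \
#       --feature_size_comp_molecular 512 \
#       --feature_size_attn_molecular 256 \
#       --workers 4 \
#       --hp 0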