# utils/coattn_train_utils.py
import os

import numpy as np
import torch

from sksurv.metrics import concordance_index_censored

from utils.utils import *

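# The train and validation loops below call `loss_fn(hazards=..., S=..., Y=..., c=...)`,
# with the actual loss injected by the caller. As a point of reference, a compatible
# discrete-time NLL survival loss might look like the sketch below; this is an
# assumption modeled on the MCAT codebase, not a definition from this module.
def _nll_surv_sketch(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    """Hypothetical stand-in illustrating the interface `loss_fn` must satisfy."""
    Y = Y.view(-1, 1)                                  # discrete time-bin label
    c = c.view(-1, 1).float()                          # censorship status (1 = censored)
    S_pad = torch.cat([torch.ones_like(c), S], dim=1)  # prepend S(-1) = 1
    # Uncensored cases contribute the hazard of their event bin plus survival up
    # to it; censored cases contribute survival beyond their last observed bin.
    uncensored = -(1 - c) * (torch.log(torch.gather(S_pad, 1, Y).clamp(min=eps))
                             + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored = -c * torch.log(torch.gather(S_pad, 1, Y + 1).clamp(min=eps))
    return ((1 - alpha) * (uncensored + censored) + alpha * uncensored).mean()
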
def train_loop_survival_coattn(epoch, model, loader, optimizer, n_classes, writer=None, loss_fn=None, reg_fn=None, lambda_reg=0., gc=16):
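    """Run one training epoch of the co-attention survival model (batch size 1,
    gradients accumulated over `gc` steps) and report the epoch-level censored
    concordance index."""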
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    train_loss_surv, train_loss = 0., 0.

    print('\n')
    # Per-sample buffers for the epoch-level c-index (the loader yields one case per batch).
    all_risk_scores = np.zeros((len(loader)))
    all_censorships = np.zeros((len(loader)))
    all_event_times = np.zeros((len(loader)))

    for batch_idx, (data_WSI, data_omic1, data_omic2, data_omic3, data_omic4, data_omic5, data_omic6, label, event_time, c) in enumerate(loader):

        data_WSI = data_WSI.to(device)
        data_omic1 = data_omic1.type(torch.FloatTensor).to(device)
        data_omic2 = data_omic2.type(torch.FloatTensor).to(device)
        data_omic3 = data_omic3.type(torch.FloatTensor).to(device)
        data_omic4 = data_omic4.type(torch.FloatTensor).to(device)
        data_omic5 = data_omic5.type(torch.FloatTensor).to(device)
        data_omic6 = data_omic6.type(torch.FloatTensor).to(device)
        label = label.type(torch.LongTensor).to(device)
        c = c.type(torch.FloatTensor).to(device)

        hazards, S, Y_hat, A = model(x_path=data_WSI, x_omic1=data_omic1, x_omic2=data_omic2, x_omic3=data_omic3, x_omic4=data_omic4, x_omic5=data_omic5, x_omic6=data_omic6)
        loss = loss_fn(hazards=hazards, S=S, Y=label, c=c)
        loss_value = loss.item()

        if reg_fn is None:
            loss_reg = 0
        else:
            loss_reg = reg_fn(model) * lambda_reg

        # The summed survival curve approximates expected survival time, so its
        # negation orders patients by risk for the c-index.
        risk = -torch.sum(S, dim=1).item()
        all_risk_scores[batch_idx] = risk
        all_censorships[batch_idx] = c.item()
        all_event_times[batch_idx] = event_time

        train_loss_surv += loss_value
        train_loss += loss_value + loss_reg

        if (batch_idx + 1) % 100 == 0:
            print('batch {}, loss: {:.4f}, label: {}, event_time: {:.4f}, risk: {:.4f}, bag_size: {}'.format(batch_idx, loss_value + loss_reg, label.item(), float(event_time), float(risk), data_WSI.size(0)))
        # Scale the loss for gradient accumulation over `gc` steps.
        loss = loss / gc + loss_reg
        loss.backward()

        if (batch_idx + 1) % gc == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Average losses over the epoch and compute the censored concordance index.
    train_loss_surv /= len(loader)
    train_loss /= len(loader)
    c_index = concordance_index_censored((1 - all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]
    print('Epoch: {}, train_loss_surv: {:.4f}, train_loss: {:.4f}, train_c_index: {:.4f}'.format(epoch, train_loss_surv, train_loss, c_index))

    if writer:
        writer.add_scalar('train/loss_surv', train_loss_surv, epoch)
        writer.add_scalar('train/loss', train_loss, epoch)
        writer.add_scalar('train/c_index', c_index, epoch)

def validate_survival_coattn(cur, epoch, model, loader, n_classes, early_stopping=None, monitor_cindex=None, writer=None, loss_fn=None, reg_fn=None, lambda_reg=0., results_dir=None):
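    """Evaluate one validation epoch; returns True if early stopping fired
    (after checkpointing via `early_stopping`), otherwise False."""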
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    val_loss_surv, val_loss = 0., 0.
    all_risk_scores = np.zeros((len(loader)))
    all_censorships = np.zeros((len(loader)))
    all_event_times = np.zeros((len(loader)))

    for batch_idx, (data_WSI, data_omic1, data_omic2, data_omic3, data_omic4, data_omic5, data_omic6, label, event_time, c) in enumerate(loader):

        data_WSI = data_WSI.to(device)
        data_omic1 = data_omic1.type(torch.FloatTensor).to(device)
        data_omic2 = data_omic2.type(torch.FloatTensor).to(device)
        data_omic3 = data_omic3.type(torch.FloatTensor).to(device)
        data_omic4 = data_omic4.type(torch.FloatTensor).to(device)
        data_omic5 = data_omic5.type(torch.FloatTensor).to(device)
        data_omic6 = data_omic6.type(torch.FloatTensor).to(device)
        label = label.type(torch.LongTensor).to(device)
        c = c.type(torch.FloatTensor).to(device)

        with torch.no_grad():
            hazards, S, Y_hat, A = model(x_path=data_WSI, x_omic1=data_omic1, x_omic2=data_omic2, x_omic3=data_omic3, x_omic4=data_omic4, x_omic5=data_omic5, x_omic6=data_omic6)

        # alpha=0 disables the extra up-weighting of uncensored cases at validation time.
        loss = loss_fn(hazards=hazards, S=S, Y=label, c=c, alpha=0)
        loss_value = loss.item()

        if reg_fn is None:
            loss_reg = 0
        else:
            loss_reg = reg_fn(model) * lambda_reg

        risk = -torch.sum(S, dim=1).item()
        all_risk_scores[batch_idx] = risk
        all_censorships[batch_idx] = c.item()
        all_event_times[batch_idx] = event_time

        val_loss_surv += loss_value
        val_loss += loss_value + loss_reg

    val_loss_surv /= len(loader)
    val_loss /= len(loader)
    c_index = concordance_index_censored((1 - all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]

    if writer:
        writer.add_scalar('val/loss_surv', val_loss_surv, epoch)
        writer.add_scalar('val/loss', val_loss, epoch)
        writer.add_scalar('val/c_index', c_index, epoch)

    if early_stopping:
        assert results_dir
        early_stopping(epoch, val_loss_surv, model, ckpt_name=os.path.join(results_dir, "s_{}_minloss_checkpoint.pt".format(cur)))

        if early_stopping.early_stop:
            print("Early stopping")
            return True

    return False

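
# `validate_survival_coattn` only requires `early_stopping` to be callable and
# to expose an `early_stop` flag. A minimal sketch of a compatible patience-based
# stopper (an assumption in the spirit of the CLAM/MCAT utilities, not this
# repository's actual class):
class _EarlyStoppingSketch:
    """Hypothetical min-validation-loss early stopper with patience."""

    def __init__(self, patience=20):
        self.patience = patience
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, epoch, val_loss, model, ckpt_name='checkpoint.pt'):
        if val_loss < self.best_loss:
            # Improvement: checkpoint the weights and reset the patience counter.
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), ckpt_name)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
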
def summary_survival_coattn(model, loader, n_classes):
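    """Run inference over `loader`, returning a per-slide results dict and the
    censored concordance index."""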
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0.

    all_risk_scores = np.zeros((len(loader)))
    all_censorships = np.zeros((len(loader)))
    all_event_times = np.zeros((len(loader)))

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}

    for batch_idx, (data_WSI, data_omic1, data_omic2, data_omic3, data_omic4, data_omic5, data_omic6, label, event_time, c) in enumerate(loader):

        data_WSI = data_WSI.to(device)
        data_omic1 = data_omic1.type(torch.FloatTensor).to(device)
        data_omic2 = data_omic2.type(torch.FloatTensor).to(device)
        data_omic3 = data_omic3.type(torch.FloatTensor).to(device)
        data_omic4 = data_omic4.type(torch.FloatTensor).to(device)
        data_omic5 = data_omic5.type(torch.FloatTensor).to(device)
        data_omic6 = data_omic6.type(torch.FloatTensor).to(device)
        label = label.type(torch.LongTensor).to(device)
        c = c.type(torch.FloatTensor).to(device)
        slide_id = slide_ids.iloc[batch_idx]

        with torch.no_grad():
            hazards, survival, Y_hat, A = model(x_path=data_WSI, x_omic1=data_omic1, x_omic2=data_omic2, x_omic3=data_omic3, x_omic4=data_omic4, x_omic5=data_omic5, x_omic6=data_omic6)

        # np.asscalar was removed in NumPy 1.23; .item() is the supported equivalent.
        risk = -torch.sum(survival, dim=1).item()
        event_time = event_time.item()
        c = c.item()
        all_risk_scores[batch_idx] = risk
        all_censorships[batch_idx] = c
        all_event_times[batch_idx] = event_time
        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'risk': risk, 'disc_label': label.item(), 'survival': event_time, 'censorship': c}})

    c_index = concordance_index_censored((1 - all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]
    return patient_results, c_index
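
# A minimal sketch of how the three loops above are typically wired together for
# one cross-validation fold. `model`, the loaders, `loss_fn`, and every default
# below are placeholder assumptions supplied by the caller, not part of this module.
def _example_fold_sketch(model, train_loader, val_loader, loss_fn,
                         n_classes=4, max_epochs=20, results_dir='./results'):
    """Hypothetical driver; names and defaults are assumptions."""
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    early_stopping = _EarlyStoppingSketch(patience=10)
    for epoch in range(max_epochs):
        train_loop_survival_coattn(epoch, model, train_loader, optimizer, n_classes, loss_fn=loss_fn)
        stop = validate_survival_coattn(0, epoch, model, val_loader, n_classes,
                                        early_stopping=early_stopping, loss_fn=loss_fn,
                                        results_dir=results_dir)
        if stop:
            break
    # Final per-slide risks and c-index on the held-out split.
    return summary_survival_coattn(model, val_loader, n_classes)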