|
a |
|
b/utils/coattn_train_utils.py |
|
|
1 |
import numpy as np |
|
|
2 |
import torch |
|
|
3 |
import pickle |
|
|
4 |
from utils.utils import * |
|
|
5 |
import os |
|
|
6 |
from collections import OrderedDict |
|
|
7 |
|
|
|
8 |
from argparse import Namespace |
|
|
9 |
from lifelines.utils import concordance_index |
|
|
10 |
from sksurv.metrics import concordance_index_censored |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def train_loop_survival_coattn(epoch, model, loader, optimizer, n_classes, writer=None, loss_fn=None, reg_fn=None, lambda_reg=0., gc=16):
    """Run one training epoch of the co-attention survival model.

    Args:
        epoch: Current epoch index (used for logging only).
        model: Co-attention survival model; called with x_path and x_omic1..x_omic6,
            returning (hazards, S, Y_hat, A).
        loader: DataLoader yielding (data_WSI, 6 omic tensors, label, event_time, c).
            Assumes batch_size == 1 (scalar `.item()` calls per batch) -- TODO confirm.
        optimizer: Optimizer, stepped every `gc` batches (gradient accumulation).
        n_classes: Number of discrete hazard intervals (unused here; kept for
            interface parity with the other train loops).
        writer: Optional TensorBoard-style writer for scalar logging.
        loss_fn: Survival loss, called as loss_fn(hazards=..., S=..., Y=..., c=...).
        reg_fn: Optional model regularizer; its output is scaled by `lambda_reg`.
        lambda_reg: Weight of the regularization term.
        gc: Number of gradient-accumulation steps between optimizer updates.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    train_loss_surv, train_loss = 0., 0.

    print('\n')
    all_risk_scores = np.zeros((len(loader)))
    all_censorships = np.zeros((len(loader)))
    all_event_times = np.zeros((len(loader)))

    for batch_idx, (data_WSI, data_omic1, data_omic2, data_omic3, data_omic4, data_omic5, data_omic6, label, event_time, c) in enumerate(loader):

        data_WSI = data_WSI.to(device)
        data_omic1 = data_omic1.type(torch.FloatTensor).to(device)
        data_omic2 = data_omic2.type(torch.FloatTensor).to(device)
        data_omic3 = data_omic3.type(torch.FloatTensor).to(device)
        data_omic4 = data_omic4.type(torch.FloatTensor).to(device)
        data_omic5 = data_omic5.type(torch.FloatTensor).to(device)
        data_omic6 = data_omic6.type(torch.FloatTensor).to(device)
        label = label.type(torch.LongTensor).to(device)
        c = c.type(torch.FloatTensor).to(device)

        hazards, S, Y_hat, A = model(x_path=data_WSI, x_omic1=data_omic1, x_omic2=data_omic2, x_omic3=data_omic3, x_omic4=data_omic4, x_omic5=data_omic5, x_omic6=data_omic6)
        loss = loss_fn(hazards=hazards, S=S, Y=label, c=c)
        loss_value = loss.item()

        if reg_fn is None:
            loss_reg = 0
        else:
            loss_reg = reg_fn(model) * lambda_reg

        # Risk score: negative sum of survival probabilities over intervals
        # (higher risk <=> shorter expected survival).
        risk = -torch.sum(S, dim=1).detach().cpu().numpy()
        all_risk_scores[batch_idx] = risk
        all_censorships[batch_idx] = c.item()
        all_event_times[batch_idx] = event_time

        train_loss_surv += loss_value
        train_loss += loss_value + loss_reg

        if (batch_idx + 1) % 100 == 0:
            # BUG FIX: the original message ended with a dangling 'bag_size:'
            # that had no matching format argument; the truncated field is removed.
            print('batch {}, loss: {:.4f}, label: {}, event_time: {:.4f}, risk: {:.4f}'.format(batch_idx, loss_value + loss_reg, label.item(), float(event_time), float(risk)))
        # Scale only the survival loss for gradient accumulation; the
        # regularization term is applied per batch (matches original behavior).
        loss = loss / gc + loss_reg
        loss.backward()

        if (batch_idx + 1) % gc == 0:
            optimizer.step()
            optimizer.zero_grad()

    # BUG FIX: flush the trailing partial accumulation window. Previously,
    # when len(loader) % gc != 0 the last batches' gradients were never
    # stepped and leaked into the next epoch's first update.
    if len(loader) % gc != 0:
        optimizer.step()
        optimizer.zero_grad()

    # calculate loss and error for epoch
    train_loss_surv /= len(loader)
    train_loss /= len(loader)
    # Censored concordance index: event indicator is 1 - censorship.
    c_index = concordance_index_censored((1-all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]
    print('Epoch: {}, train_loss_surv: {:.4f}, train_loss: {:.4f}, train_c_index: {:.4f}'.format(epoch, train_loss_surv, train_loss, c_index))

    if writer:
        writer.add_scalar('train/loss_surv', train_loss_surv, epoch)
        writer.add_scalar('train/loss', train_loss, epoch)
        writer.add_scalar('train/c_index', c_index, epoch)
|
|
|
72 |
|
|
|
73 |
def validate_survival_coattn(cur, epoch, model, loader, n_classes, early_stopping=None, monitor_cindex=None, writer=None, loss_fn=None, reg_fn=None, lambda_reg=0., results_dir=None):
    """Evaluate the co-attention survival model on a validation loader.

    Computes the mean survival loss and the censored concordance index over
    the loader (assumed batch_size == 1 -- TODO confirm), optionally logs to
    `writer`, and drives `early_stopping` on the validation survival loss.

    Returns:
        True if early stopping fired this epoch, otherwise False.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    num_samples = len(loader)
    val_loss_surv = 0.
    val_loss = 0.
    all_risk_scores = np.zeros(num_samples)
    all_censorships = np.zeros(num_samples)
    all_event_times = np.zeros(num_samples)

    for idx, batch in enumerate(loader):
        data_WSI, data_omic1, data_omic2, data_omic3, data_omic4, data_omic5, data_omic6, label, event_time, c = batch

        data_WSI = data_WSI.to(device)
        # Cast each omic modality to float32 and move it to the target device.
        omics = [t.type(torch.FloatTensor).to(device)
                 for t in (data_omic1, data_omic2, data_omic3,
                           data_omic4, data_omic5, data_omic6)]
        label = label.type(torch.LongTensor).to(device)
        c = c.type(torch.FloatTensor).to(device)

        with torch.no_grad():
            hazards, S, Y_hat, A = model(x_path=data_WSI, x_omic1=omics[0], x_omic2=omics[1], x_omic3=omics[2], x_omic4=omics[3], x_omic5=omics[4], x_omic6=omics[5])  # return hazards, S, Y_hat, A_raw, results_dict

        # alpha=0 disables the uncensored up-weighting term during validation.
        loss = loss_fn(hazards=hazards, S=S, Y=label, c=c, alpha=0)
        batch_loss = loss.item()

        reg_term = 0 if reg_fn is None else reg_fn(model) * lambda_reg

        # Risk = negative sum of survival probabilities (higher = riskier).
        all_risk_scores[idx] = -torch.sum(S, dim=1).cpu().numpy()
        all_censorships[idx] = c.cpu().numpy()
        all_event_times[idx] = event_time

        val_loss_surv += batch_loss
        val_loss += batch_loss + reg_term

    val_loss_surv /= num_samples
    val_loss /= num_samples
    # Event indicator for the concordance index is 1 - censorship.
    c_index = concordance_index_censored((1 - all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]

    if writer:
        writer.add_scalar('val/loss_surv', val_loss_surv, epoch)
        writer.add_scalar('val/loss', val_loss, epoch)
        writer.add_scalar('val/c-index', c_index, epoch)

    if early_stopping:
        assert results_dir
        early_stopping(epoch, val_loss_surv, model, ckpt_name=os.path.join(results_dir, "s_{}_minloss_checkpoint.pt".format(cur)))
        if early_stopping.early_stop:
            print("Early stopping")
            return True

    return False
|
|
131 |
|
|
|
132 |
|
|
|
133 |
def summary_survival_coattn(model, loader, n_classes):
    """Run inference over a test loader and summarize per-slide survival results.

    Args:
        model: Co-attention survival model; called with x_path and x_omic1..x_omic6,
            returning (hazards, survival, Y_hat, A).
        loader: DataLoader (assumed batch_size == 1 -- TODO confirm) whose
            `dataset.slide_data['slide_id']` indexes slides by batch position.
        n_classes: Number of discrete hazard intervals (unused here; kept for
            interface parity with the other summary functions).

    Returns:
        (patient_results, c_index): a dict keyed by slide_id with per-slide
        risk / label / survival time / censorship, and the censored
        concordance index over the whole loader.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    all_risk_scores = np.zeros((len(loader)))
    all_censorships = np.zeros((len(loader)))
    all_event_times = np.zeros((len(loader)))

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}

    for batch_idx, (data_WSI, data_omic1, data_omic2, data_omic3, data_omic4, data_omic5, data_omic6, label, event_time, c) in enumerate(loader):

        data_WSI = data_WSI.to(device)
        data_omic1 = data_omic1.type(torch.FloatTensor).to(device)
        data_omic2 = data_omic2.type(torch.FloatTensor).to(device)
        data_omic3 = data_omic3.type(torch.FloatTensor).to(device)
        data_omic4 = data_omic4.type(torch.FloatTensor).to(device)
        data_omic5 = data_omic5.type(torch.FloatTensor).to(device)
        data_omic6 = data_omic6.type(torch.FloatTensor).to(device)
        label = label.type(torch.LongTensor).to(device)
        c = c.type(torch.FloatTensor).to(device)
        slide_id = slide_ids.iloc[batch_idx]

        with torch.no_grad():
            hazards, survival, Y_hat, A = model(x_path=data_WSI, x_omic1=data_omic1, x_omic2=data_omic2, x_omic3=data_omic3, x_omic4=data_omic4, x_omic5=data_omic5, x_omic6=data_omic6)  # return hazards, S, Y_hat, A_raw, results_dict

        # BUG FIX: np.asscalar() was deprecated in NumPy 1.16 and removed in
        # NumPy 1.23, so the original code raises AttributeError on modern
        # NumPy. `.item()` is the documented drop-in replacement and works on
        # both torch tensors and numpy scalars.
        risk = (-torch.sum(survival, dim=1)).item()
        event_time = event_time.item()
        c = c.item()
        all_risk_scores[batch_idx] = risk
        all_censorships[batch_idx] = c
        all_event_times[batch_idx] = event_time
        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'risk': risk, 'disc_label': label.item(), 'survival': event_time, 'censorship': c}})

    # Event indicator for the concordance index is 1 - censorship.
    c_index = concordance_index_censored((1-all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]
    return patient_results, c_index