utils/utils.py

import collections
import math
import pickle
from itertools import islice

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import (DataLoader, RandomSampler, Sampler,
                              SequentialSampler, WeightedRandomSampler, sampler)
from torch.utils.data.dataloader import default_collate
from torchvision import transforms

import torch_geometric
from torch_geometric.data import Batch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SubsetSequentialSampler(Sampler):
    """Samples elements sequentially from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


def collate_MIL(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    label = torch.LongTensor([item[1] for item in batch])
    return [img, label]


def collate_features(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    coords = np.vstack([item[1] for item in batch])
    return [img, coords]


def collate_MIL_survival(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    omic = torch.cat([item[1] for item in batch], dim=0).type(torch.FloatTensor)
    label = torch.LongTensor([item[2] for item in batch])
    event_time = torch.FloatTensor([item[3] for item in batch])
    c = torch.FloatTensor([item[4] for item in batch])
    return [img, omic, label, event_time, c]


def collate_MIL_survival_cluster(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    cluster_ids = torch.cat([item[1] for item in batch], dim=0).type(torch.LongTensor)
    omic = torch.cat([item[2] for item in batch], dim=0).type(torch.FloatTensor)
    label = torch.LongTensor([item[3] for item in batch])
    event_time = np.array([item[4] for item in batch])
    c = torch.FloatTensor([item[5] for item in batch])
    return [img, cluster_ids, omic, label, event_time, c]


def collate_MIL_survival_sig(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    omic1 = torch.cat([item[1] for item in batch], dim=0).type(torch.FloatTensor)
    omic2 = torch.cat([item[2] for item in batch], dim=0).type(torch.FloatTensor)
    omic3 = torch.cat([item[3] for item in batch], dim=0).type(torch.FloatTensor)
    omic4 = torch.cat([item[4] for item in batch], dim=0).type(torch.FloatTensor)
    omic5 = torch.cat([item[5] for item in batch], dim=0).type(torch.FloatTensor)
    omic6 = torch.cat([item[6] for item in batch], dim=0).type(torch.FloatTensor)

    label = torch.LongTensor([item[7] for item in batch])
    event_time = np.array([item[8] for item in batch])
    c = torch.FloatTensor([item[9] for item in batch])
    return [img, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c]


def get_simple_loader(dataset, batch_size=1):
    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler.SequentialSampler(dataset), collate_fn=collate_MIL, **kwargs)
    return loader

def get_split_loader(split_dataset, training=False, testing=False, weighted=False, mode='coattn', batch_size=1):
    """
    return either the validation loader or training loader
    """
    if mode == 'coattn':
        collate = collate_MIL_survival_sig
    elif mode == 'cluster':
        collate = collate_MIL_survival_cluster
    else:
        collate = collate_MIL_survival

    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    if not testing:
        if training:
            if weighted:
                weights = make_weights_for_balanced_classes_split(split_dataset)
                loader = DataLoader(split_dataset, batch_size=batch_size, sampler=WeightedRandomSampler(weights, len(weights)), collate_fn=collate, **kwargs)
            else:
                loader = DataLoader(split_dataset, batch_size=batch_size, sampler=RandomSampler(split_dataset), collate_fn=collate, **kwargs)
        else:
            loader = DataLoader(split_dataset, batch_size=1, sampler=SequentialSampler(split_dataset), collate_fn=collate, **kwargs)

    else:
        # testing mode: iterate sequentially over a random 10% subset of the split
        ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset) * 0.1), replace=False)
        loader = DataLoader(split_dataset, batch_size=1, sampler=SubsetSequentialSampler(ids), collate_fn=collate, **kwargs)

    return loader
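
# Illustrative usage sketch (not part of the original file; train_split / val_split are
# hypothetical split datasets produced elsewhere in the pipeline):
#
#   train_loader = get_split_loader(train_split, training=True, weighted=True,
#                                   mode='coattn', batch_size=1)
#   val_loader = get_split_loader(val_split, mode='coattn')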

def get_optim(model, args):
    if args.opt == "adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg)
    else:
        raise NotImplementedError
    return optimizer


def print_network(net):
    num_params = 0
    num_params_train = 0
    print(net)

    for param in net.parameters():
        n = param.numel()
        num_params += n
        if param.requires_grad:
            num_params_train += n

    print('Total number of parameters: %d' % num_params)
    print('Total number of trainable parameters: %d' % num_params_train)


def generate_split(cls_ids, val_num, test_num, samples, n_splits=5,
                   seed=7, label_frac=1.0, custom_test_ids=None):
    indices = np.arange(samples).astype(int)

    if custom_test_ids is not None:
        indices = np.setdiff1d(indices, custom_test_ids)

    np.random.seed(seed)
    for i in range(n_splits):
        all_val_ids = []
        all_test_ids = []
        sampled_train_ids = []

        if custom_test_ids is not None:  # pre-built test split, do not need to sample
            all_test_ids.extend(custom_test_ids)

        for c in range(len(val_num)):
            possible_indices = np.intersect1d(cls_ids[c], indices)  # all indices of this class
            remaining_ids = possible_indices

            if val_num[c] > 0:
                val_ids = np.random.choice(possible_indices, val_num[c], replace=False)  # validation ids
                remaining_ids = np.setdiff1d(possible_indices, val_ids)  # indices of this class left after validation
                all_val_ids.extend(val_ids)

            if custom_test_ids is None and test_num[c] > 0:  # sample test split
                test_ids = np.random.choice(remaining_ids, test_num[c], replace=False)
                remaining_ids = np.setdiff1d(remaining_ids, test_ids)
                all_test_ids.extend(test_ids)

            if label_frac == 1:
                sampled_train_ids.extend(remaining_ids)

            else:
                sample_num = math.ceil(len(remaining_ids) * label_frac)
                slice_ids = np.arange(sample_num)
                sampled_train_ids.extend(remaining_ids[slice_ids])

        yield sorted(sampled_train_ids), sorted(all_val_ids), sorted(all_test_ids)
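
# Illustrative sketch of how generate_split pairs with nth() below (not part of the
# original file; labels, n_classes and the per-class counts are hypothetical):
#
#   cls_ids = [np.where(labels == c)[0] for c in range(n_classes)]
#   splits = generate_split(cls_ids, val_num=[10, 10], test_num=[10, 10],
#                           samples=len(labels), n_splits=5, seed=7)
#   train_ids, val_ids, test_ids = nth(splits, 0)   # take the first of the 5 splits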

def nth(iterator, n, default=None):
    if n is None:
        return collections.deque(iterator, maxlen=0)
    else:
        return next(islice(iterator, n, None), default)


def calculate_error(Y_hat, Y):
    error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()

    return error


def make_weights_for_balanced_classes_split(dataset):
    N = float(len(dataset))
    weight_per_class = [N / len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]
    weight = [0] * int(N)
    for idx in range(len(dataset)):
        y = dataset.getlabel(idx)
        weight[idx] = weight_per_class[y]

    return torch.DoubleTensor(weight)


def initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            m.bias.data.zero_()

        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def dfs_freeze(model):
    for name, child in model.named_children():
        for param in child.parameters():
            param.requires_grad = False
        dfs_freeze(child)


def dfs_unfreeze(model):
    for name, child in model.named_children():
        for param in child.parameters():
            param.requires_grad = True
        dfs_unfreeze(child)


# divide continuous time scale into k discrete bins in total, T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)}
# Y = T_discrete is the discrete event time:
# Y = 0 if T_cont \in (-inf, 0), Y = 1 if T_cont \in [0, a_1), Y = 2 if T_cont in [a_1, a_2), ..., Y = k if T_cont in [a_(k-1), inf)
# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X), t = 0,1,2,...,k
# S: survival function: P(Y > t | X)
# all patients are alive from (-inf, 0) by definition, so P(Y=0) = 0
# h(0) = 0 ---> do not need to model
# S(0) = P(Y > 0 | X) = 1 ----> do not need to model
'''
Summary: the neural network models the hazard probability function h(t) for t = 1, 2, ..., k,
corresponding to Y = 1, ..., k. h(t) is the probability that the patient dies in [0, a_1), [a_1, a_2), ..., [a_(k-1), inf)
'''
# def neg_likelihood_loss(hazards, Y, c):
#     batch_size = len(Y)
#     Y = Y.view(batch_size, 1)  # ground truth bin, 1,2,...,k
#     c = c.view(batch_size, 1).float()  # censorship status, 0 or 1
#     S = torch.cumprod(1 - hazards, dim=1)  # survival is cumulative product of 1 - hazards
#     # without padding, S(1) = S[0], h(1) = h[0]
#     S_padded = torch.cat([torch.ones_like(c), S], 1)  # S(0) = 1, all patients are alive from (-inf, 0) by definition
#     # after padding, S(0) = S[0], S(1) = S[1], etc, h(1) = h[0]
#     # h[y] = h(1)
#     # S[1] = S(1)
#     neg_l = - c * torch.log(torch.gather(S_padded, 1, Y)) - (1 - c) * (torch.log(torch.gather(S_padded, 1, Y-1)) + torch.log(hazards[:, Y-1]))
#     neg_l = neg_l.mean()
#     return neg_l

# divide continuous time scale into k discrete bins in total, T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)}
# Y = T_discrete is the discrete event time:
# Y = -1 if T_cont \in (-inf, 0), Y = 0 if T_cont \in [0, a_1), Y = 1 if T_cont in [a_1, a_2), ..., Y = k-1 if T_cont in [a_(k-1), inf)
# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X), t = -1, 0, 1, 2, ..., k-1
# S: survival function: P(Y > t | X)
# all patients are alive from (-inf, 0) by definition, so P(Y=-1) = 0
# h(-1) = 0 ---> do not need to model
# S(-1) = P(Y > -1 | X) = 1 ----> do not need to model
'''
Summary: the neural network models the hazard probability function h(t) for t = 0, 1, 2, ..., k-1,
corresponding to Y = 0, 1, ..., k-1. h(t) is the probability that the patient dies in [0, a_1), [a_1, a_2), ..., [a_(k-1), inf)
'''
def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1)  # ground truth bin, 0, 1, ..., k-1
    c = c.view(batch_size, 1).float()  # censorship status, 0 or 1
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)  # survival is cumulative product of 1 - hazards
    # without padding, S(0) = S[0], h(0) = h[0]
    S_padded = torch.cat([torch.ones_like(c), S], 1)  # S(-1) = 1, all patients are alive from (-inf, 0) by definition
    # after padding, S(0) = S[1], S(1) = S[2], etc; h(0) = h[0]
    # i.e. hazards[:, y] = h(y) and S_padded[:, y+1] = S(y)
    uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))
    neg_l = censored_loss + uncensored_loss
    loss = (1 - alpha) * neg_l + alpha * uncensored_loss
    loss = loss.mean()
    return loss
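
# Sanity-check sketch for nll_loss (added for illustration; the values are arbitrary).
# The uncensored term (c = 0) is -[log S(Y-1) + log h(Y)] and the censored term (c = 1)
# is -log S(Y), where S(t) = prod_{s <= t} (1 - h(s)) and S(-1) = 1; alpha down-weights
# the censored contribution in the final weighted sum.
#
#   hazards = torch.sigmoid(torch.randn(1, 4))   # h(0), ..., h(3) for k = 4 bins
#   S = torch.cumprod(1 - hazards, dim=1)         # S(0), ..., S(3)
#   loss = nll_loss(hazards, S, Y=torch.tensor([2]), c=torch.tensor([0.]))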

def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1)  # ground truth bin, 0, 1, ..., k-1
    c = c.view(batch_size, 1).float()  # censorship status, 0 or 1
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)  # survival is cumulative product of 1 - hazards
    # without padding, S(0) = S[0], h(0) = h[0]
    # after padding, S(0) = S[1], S(1) = S[2], etc; h(0) = h[0]
    # i.e. hazards[:, y] = h(y) and S_padded[:, y+1] = S(y)
    S_padded = torch.cat([torch.ones_like(c), S], 1)
    reg = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y) + eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    ce_l = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (1 - c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))
    loss = (1 - alpha) * ce_l + alpha * reg
    loss = loss.mean()
    return loss


# def nll_loss(hazards, Y, c, S=None, alpha=0.4, eps=1e-8):
#     batch_size = len(Y)
#     Y = Y.view(batch_size, 1)  # ground truth bin, 1,2,...,k
#     c = c.view(batch_size, 1).float()  # censorship status, 0 or 1
#     if S is None:
#         S = 1 - torch.cumsum(hazards, dim=1)  # survival approximated as 1 minus the cumulative sum of hazards
#     uncensored_loss = -(1 - c) * (torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
#     censored_loss = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps))
#     loss = censored_loss + uncensored_loss
#     loss = loss.mean()
#     return loss


class CrossEntropySurvLoss(object):
    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        if alpha is None:
            return ce_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return ce_loss(hazards, S, Y, c, alpha=alpha)


# loss_fn(hazards=hazards, S=S, Y=Y_hat, c=c, alpha=0)
class NLLSurvLoss_dep(object):
    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        if alpha is None:
            return nll_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return nll_loss(hazards, S, Y, c, alpha=alpha)
    # h_padded = torch.cat([torch.zeros_like(c), hazards], 1)
    # reg = - (1 - c) * (torch.log(torch.gather(hazards, 1, Y)) + torch.gather(torch.cumsum(torch.log(1-h_padded), dim=1), 1, Y))

class CoxSurvLoss(object):
    def __call__(self, hazards, S, c, **kwargs):
        # This calculation credit to Travers Ching https://github.com/traversc/cox-nnet
        # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
        current_batch_len = len(S)
        # R_mat[i, j] == 1 when sample j is in the risk set of sample i (S[j] >= S[i])
        R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)
        for i in range(current_batch_len):
            for j in range(current_batch_len):
                R_mat[i, j] = S[j] >= S[i]

        R_mat = torch.FloatTensor(R_mat).to(device)
        theta = hazards.reshape(-1)
        exp_theta = torch.exp(theta)
        loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))) * (1 - c))
        return loss_cox
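
# Reference note (added): the expression above is the negative Cox partial log-likelihood
# averaged over the batch (censored samples, c = 1, contribute zero),
#   L = -(1/N) * sum_{i: c_i = 0} [ theta_i - log sum_{j in R(i)} exp(theta_j) ],
# with predicted risk scores theta and risk sets R(i) encoded in R_mat.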

def l1_reg_all(model, reg_type=None):
    l1_reg = None

    for W in model.parameters():
        if l1_reg is None:
            l1_reg = torch.abs(W).sum()
        else:
            l1_reg = l1_reg + torch.abs(W).sum()  # torch.abs(W).sum() is equivalent to W.norm(1)
    return l1_reg


def l1_reg_modules(model, reg_type=None):
    l1_reg = 0

    l1_reg += l1_reg_all(model.fc_omic)
    l1_reg += l1_reg_all(model.mm)

    return l1_reg


def l1_reg_omic(model, reg_type=None):
    l1_reg = 0

    if hasattr(model, 'fc_omic'):
        l1_reg += l1_reg_all(model.fc_omic)
    else:
        l1_reg += l1_reg_all(model)

    return l1_reg

def get_custom_exp_code(args):
    r"""
    Updates the argparse.Namespace with a custom experiment code.

    Args:
        - args (Namespace)

    Returns:
        - args (Namespace)
    """
    exp_code = '_'.join(args.split_dir.split('_')[:2])
    dataset_path = 'datasets_csv'
    param_code = ''

    ### Model Type
    if args.model_type == 'porpoise_mmf':
        param_code += 'PorpoiseMMF'
    elif args.model_type == 'porpoise_amil':
        param_code += 'PorpoiseAMIL'
    elif args.model_type == 'max_net' or args.model_type == 'snn':
        param_code += 'SNN'
    elif args.model_type == 'amil':
        param_code += 'AMIL'
    elif args.model_type == 'deepset':
        param_code += 'DS'
    elif args.model_type == 'mi_fcn':
        param_code += 'MIFCN'
    elif args.model_type == 'mcat':
        param_code += 'MCAT'
    else:
        raise NotImplementedError

    ### Loss Function
    param_code += '_%s' % args.bag_loss
    if args.bag_loss in ['nll_surv']:
        param_code += '_a%s' % str(args.alpha_surv)

    ### Learning Rate
    if args.lr != 2e-4:
        param_code += '_lr%s' % format(args.lr, '.0e')

    ### L1-Regularization
    if args.reg_type != 'None':
        param_code += '_%sreg%s' % (args.reg_type, format(args.lambda_reg, '.0e'))

    if args.dropinput:
        param_code += '_drop%s' % str(int(args.dropinput * 100))

    param_code += '_%s' % args.which_splits.split("_")[0]

    ### Batch Size
    if args.batch_size != 1:
        param_code += '_b%s' % str(args.batch_size)

    ### Gradient Accumulation
    if args.gc != 1:
        param_code += '_gc%s' % str(args.gc)

    ### Applying Which Features
    if args.apply_sigfeats:
        param_code += '_sig'
        dataset_path += '_sig'
    elif args.apply_mutsig:
        param_code += '_mutsig'
        dataset_path += '_mutsig'

    ### Fusion Operation
    if args.fusion != "None":
        param_code += '_' + args.fusion

    ### Updating
    args.exp_code = exp_code + "_" + param_code
    args.param_code = param_code
    args.dataset_path = dataset_path

    return args