# models/model_set_mil.py

from collections import OrderedDict
from os.path import join
import pdb

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.model_utils import *

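# NOTE: SNN_Block, Attn_Net_Gated, and BilinearFusion used below are expected to be
# provided by models.model_utils via the wildcard import above.
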

################################
### Deep Sets Implementation ###
################################
class MIL_Sum_FC_surv(nn.Module):
    def __init__(self, omic_input_dim=None, fusion=None, size_arg="small", dropout=0.25, n_classes=4):
        r"""
        Deep Sets Implementation.

        Args:
            omic_input_dim (int): Dimension size of genomic features.
            fusion (str): Fusion method (Choices: concat, bilinear, or None)
            size_arg (str): Size of NN architecture (Choices: small or big)
            dropout (float): Dropout rate
            n_classes (int): Number of output classes
        """
        super(MIL_Sum_FC_surv, self).__init__()
        self.fusion = fusion
        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        self.size_dict_omic = {'small': [256, 256]}

        ### Deep Sets Architecture Construction
        size = self.size_dict_path[size_arg]
        self.phi = nn.Sequential(*[nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)])
        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])

        ### Constructing Genomic SNN
        if self.fusion is not None:
            hidden = [256, 256]
            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
            for i, _ in enumerate(hidden[1:]):
                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
            self.fc_omic = nn.Sequential(*fc_omic)

            if self.fusion == 'concat':
                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
            elif self.fusion == 'bilinear':
                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
            else:
                self.mm = None

        self.classifier = nn.Linear(size[2], n_classes)

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() >= 1:
            device_ids = list(range(torch.cuda.device_count()))
            self.phi = nn.DataParallel(self.phi, device_ids=device_ids).to('cuda:0')

        if self.fusion is not None:
            self.fc_omic = self.fc_omic.to(device)
            self.mm = self.mm.to(device)

        self.rho = self.rho.to(device)
        self.classifier = self.classifier.to(device)

    def forward(self, **kwargs):
        x_path = kwargs['x_path']

        h_path = self.phi(x_path).sum(axis=0)
        h_path = self.rho(h_path)

        if self.fusion is not None:
            x_omic = kwargs['x_omic']
            h_omic = self.fc_omic(x_omic).squeeze(dim=0)
            if self.fusion == 'bilinear':
                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
            elif self.fusion == 'concat':
                h = self.mm(torch.cat([h_path, h_omic], axis=0))
        else:
            h = h_path  # [256] vector

        logits = self.classifier(h).unsqueeze(0)  # logits needs to be a [1 x 4] vector
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        return hazards, S, Y_hat, None, None

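# Example usage (illustrative sketch, not part of the original training code):
# a bag of N patch embeddings of size 1024 is embedded by phi, sum-pooled into a
# single WSI-level vector, and mapped through rho before the survival classifier.
# The bag size below is an assumption for illustration only.
#
#   model = MIL_Sum_FC_surv(fusion=None, n_classes=4)
#   bag = torch.randn(100, 1024)                 # N = 100 patches
#   hazards, S, Y_hat, _, _ = model(x_path=bag)  # hazards, S: [1 x 4]
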

################################
# Attention MIL Implementation #
################################
class MIL_Attention_FC_surv(nn.Module):
    def __init__(self, omic_input_dim=None, fusion=None, size_arg="small", dropout=0.25, n_classes=4):
        r"""
        Attention MIL Implementation.

        Args:
            omic_input_dim (int): Dimension size of genomic features.
            fusion (str): Fusion method (Choices: concat, bilinear, or None)
            size_arg (str): Size of NN architecture (Choices: small or big)
            dropout (float): Dropout rate
            n_classes (int): Number of output classes
        """
        super(MIL_Attention_FC_surv, self).__init__()
        self.fusion = fusion
        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        self.size_dict_omic = {'small': [256, 256]}

        ### Attention MIL Architecture Construction
        size = self.size_dict_path[size_arg]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])

        ### Constructing Genomic SNN
        if self.fusion is not None:
            hidden = [256, 256]
            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
            for i, _ in enumerate(hidden[1:]):
                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
            self.fc_omic = nn.Sequential(*fc_omic)

            if self.fusion == 'concat':
                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
            elif self.fusion == 'bilinear':
                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
            else:
                self.mm = None

        self.classifier = nn.Linear(size[2], n_classes)

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() >= 1:
            device_ids = list(range(torch.cuda.device_count()))
            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')

        if self.fusion is not None:
            self.fc_omic = self.fc_omic.to(device)
            self.mm = self.mm.to(device)

        self.rho = self.rho.to(device)
        self.classifier = self.classifier.to(device)

    def forward(self, **kwargs):
        x_path = kwargs['x_path']

        A, h_path = self.attention_net(x_path)
        A = torch.transpose(A, 1, 0)
        A_raw = A
        A = F.softmax(A, dim=1)
        h_path = torch.mm(A, h_path)
        h_path = self.rho(h_path).squeeze()

        if self.fusion is not None:
            x_omic = kwargs['x_omic']
            h_omic = self.fc_omic(x_omic)
            if self.fusion == 'bilinear':
                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
            elif self.fusion == 'concat':
                h = self.mm(torch.cat([h_path, h_omic], axis=0))
        else:
            h = h_path  # [256] vector

        logits = self.classifier(h).unsqueeze(0)  # logits needs to be a [1 x 4] vector
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        return hazards, S, Y_hat, None, None

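# Example usage (illustrative sketch, not part of the original training code):
# gated attention scores A ([1 x N] after transpose and softmax) weight the N patch
# embeddings into a single WSI-level representation before the survival classifier.
# The bag size below is an assumption for illustration only.
#
#   model = MIL_Attention_FC_surv(fusion=None, n_classes=4)
#   bag = torch.randn(100, 1024)
#   hazards, S, Y_hat, _, _ = model(x_path=bag)  # hazards, S: [1 x 4]
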

######################################
# Deep Attention MISL Implementation #
######################################
class MIL_Cluster_FC_surv(nn.Module):
    def __init__(self, omic_input_dim=None, fusion=None, num_clusters=10, size_arg="small", dropout=0.25, n_classes=4):
        r"""
        Deep Attention MISL Implementation (cluster-based Attention MIL).

        Args:
            omic_input_dim (int): Dimension size of genomic features.
            fusion (str): Fusion method (Choices: concat, bilinear, or None)
            num_clusters (int): Number of phenotype clusters used to pool patch embeddings
            size_arg (str): Size of NN architecture (Choices: small or big)
            dropout (float): Dropout rate
            n_classes (int): Number of output classes
        """
        super(MIL_Cluster_FC_surv, self).__init__()
        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        self.size_dict_omic = {'small': [256, 256]}
        self.num_clusters = num_clusters
        self.fusion = fusion

        ### FC Cluster layers + Pooling
        size = self.size_dict_path[size_arg]
        phis = []
        for phenotype_i in range(num_clusters):
            phi = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout),
                   nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]
            phis.append(nn.Sequential(*phi))
        self.phis = nn.ModuleList(phis)
        self.pool1d = nn.AdaptiveAvgPool1d(1)

        ### WSI Attention MIL Construction
        fc = [nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]
        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])

        ### Genomic SNN Construction + Multimodal Fusion
        if self.fusion is not None:
            hidden = self.size_dict_omic['small']
            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
            for i, _ in enumerate(hidden[1:]):
                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
            self.fc_omic = nn.Sequential(*fc_omic)

            if self.fusion == 'concat':
                self.mm = nn.Sequential(*[nn.Linear(size[2]*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
            elif self.fusion == 'bilinear':
                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
            else:
                self.mm = None

        self.classifier = nn.Linear(size[2], n_classes)

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() >= 1:
            device_ids = list(range(torch.cuda.device_count()))
            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
        else:
            self.attention_net = self.attention_net.to(device)

        if self.fusion is not None:
            self.fc_omic = self.fc_omic.to(device)
            self.mm = self.mm.to(device)

        self.phis = self.phis.to(device)
        self.pool1d = self.pool1d.to(device)
        self.rho = self.rho.to(device)
        self.classifier = self.classifier.to(device)

    def forward(self, **kwargs):
        x_path = kwargs['x_path']
        cluster_id = kwargs['cluster_id'].detach().cpu().numpy()

        ### FC Cluster layers + Pooling
        h_cluster = []
        for i in range(self.num_clusters):
            h_cluster_i = self.phis[i](x_path[cluster_id == i])
            if h_cluster_i.shape[0] == 0:
                # Empty cluster: substitute a single zero embedding on the same device as the input
                h_cluster_i = torch.zeros((1, 512), device=x_path.device)
            h_cluster.append(self.pool1d(h_cluster_i.T.unsqueeze(0)).squeeze(2))
        h_cluster = torch.stack(h_cluster, dim=1).squeeze(0)

        ### Attention MIL
        A, h_path = self.attention_net(h_cluster)
        A = torch.transpose(A, 1, 0)
        A_raw = A
        A = F.softmax(A, dim=1)
        h_path = torch.mm(A, h_path)
        h_path = self.rho(h_path).squeeze()

        ### Attention MIL + Genomic Fusion
        if self.fusion is not None:
            x_omic = kwargs['x_omic']
            h_omic = self.fc_omic(x_omic)
            if self.fusion == 'bilinear':
                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
            elif self.fusion == 'concat':
                h = self.mm(torch.cat([h_path, h_omic], axis=0))
        else:
            h = h_path

        logits = self.classifier(h).unsqueeze(0)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        return hazards, S, Y_hat, None, None
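

# Minimal smoke test (illustrative sketch; the bag size, feature dimension, and
# random cluster assignments below are assumptions, not values from the original
# pipeline). hazards holds per-interval conditional hazards and S the discrete
# survival curve, S[:, k] = prod_{j <= k} (1 - hazards[:, j]).
if __name__ == "__main__":
    bag = torch.randn(100, 1024)                   # 100 patch embeddings of dim 1024

    model = MIL_Sum_FC_surv(fusion=None, n_classes=4)
    hazards, S, Y_hat, _, _ = model(x_path=bag)
    print('MIL_Sum_FC_surv:', hazards.shape, S.shape, Y_hat.shape)

    model = MIL_Attention_FC_surv(fusion=None, n_classes=4)
    hazards, S, Y_hat, _, _ = model(x_path=bag)
    print('MIL_Attention_FC_surv:', hazards.shape, S.shape, Y_hat.shape)

    model = MIL_Cluster_FC_surv(fusion=None, num_clusters=10, n_classes=4)
    cluster_id = torch.randint(0, 10, (100,))      # phenotype cluster id per patch
    hazards, S, Y_hat, _, _ = model(x_path=bag, cluster_id=cluster_id)
    print('MIL_Cluster_FC_surv:', hazards.shape, S.shape, Y_hat.shape)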