# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class AttnNet(nn.Module):
    # Adapted from https://github.com/mahmoodlab/CLAM/blob/master/models/model_clam.py
    # Lu, M.Y., Williamson, D.F.K., Chen, T.Y. et al. Data-efficient and weakly supervised
    # computational pathology on whole-slide images. Nat Biomed Eng 5, 555–570 (2021).
    # https://doi.org/10.1038/s41551-020-00682-w

    def __init__(self, L=1024, D=256, dropout=False, p_dropout_atn=0.25, n_classes=1):
        super(AttnNet, self).__init__()

        self.attention_a = [nn.Linear(L, D), nn.Tanh()]
        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]

        if dropout:
            self.attention_a.append(nn.Dropout(p_dropout_atn))
            self.attention_b.append(nn.Dropout(p_dropout_atn))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)  # N x n_classes
        return A
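    # Shape sketch (illustrative, using the defaults above): for a bag of N tile embeddings
    # x of shape (N, L), forward() returns unnormalised attention logits of shape
    # (N, n_classes); the caller is expected to transpose and softmax them over the N tiles.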

class Attn_Modality_Gated(nn.Module):
    # Adapted from https://github.com/mahmoodlab/PORPOISE
    def __init__(self, gate_h1, gate_h2, gate_h3, dim1_og, dim2_og, dim3_og, use_bilinear=[True, True, True], scale=[1, 1, 1], p_dropout_fc=0.25):
        super(Attn_Modality_Gated, self).__init__()

        self.gate_h1 = gate_h1  # bool
        self.gate_h2 = gate_h2  # bool
        self.gate_h3 = gate_h3  # bool
        self.use_bilinear = use_bilinear  # list of three bools

        # The gating can be computed on latent vectors of lower dimension (controlled by scale).
        dim1, dim2, dim3 = dim1_og // scale[0], dim2_og // scale[1], dim3_og // scale[2]

        # Attention gate for each modality.
        if self.gate_h1:
            self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
            self.linear_z1 = nn.Bilinear(dim1_og, dim2_og + dim3_og, dim1) if self.use_bilinear[0] else nn.Sequential(nn.Linear(dim1_og + dim2_og + dim3_og, dim1))
            self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
        else:
            self.linear_h1, self.linear_o1 = nn.Identity(), nn.Identity()

        if self.gate_h2:
            self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
            self.linear_z2 = nn.Bilinear(dim2_og, dim1_og + dim3_og, dim2) if self.use_bilinear[1] else nn.Sequential(nn.Linear(dim1_og + dim2_og + dim3_og, dim2))
            self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
        else:
            self.linear_h2, self.linear_o2 = nn.Identity(), nn.Identity()

        if self.gate_h3:
            self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
            self.linear_z3 = nn.Bilinear(dim3_og, dim1_og + dim2_og, dim3) if self.use_bilinear[2] else nn.Sequential(nn.Linear(dim1_og + dim2_og + dim3_og, dim3))
            self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
        else:
            self.linear_h3, self.linear_o3 = nn.Identity(), nn.Identity()

    def forward(self, x1, x2, x3):

        if self.gate_h1:
            h1 = self.linear_h1(x1)  # non-linear projection of x1 (breaks collinearity)
            z1 = self.linear_z1(x1, torch.cat([x2, x3], dim=-1)) if self.use_bilinear[0] else self.linear_z1(torch.cat((x1, x2, x3), dim=-1))  # context vector combining all three modalities
            o1 = self.linear_o1(nn.Sigmoid()(z1) * h1)  # gated update of the modality embedding
        else:
            h1 = self.linear_h1(x1)
            o1 = self.linear_o1(h1)

        if self.gate_h2:
            h2 = self.linear_h2(x2)
            z2 = self.linear_z2(x2, torch.cat([x1, x3], dim=-1)) if self.use_bilinear[1] else self.linear_z2(torch.cat((x1, x2, x3), dim=-1))
            o2 = self.linear_o2(nn.Sigmoid()(z2) * h2)
        else:
            h2 = self.linear_h2(x2)
            o2 = self.linear_o2(h2)

        if self.gate_h3:
            h3 = self.linear_h3(x3)
            z3 = self.linear_z3(x3, torch.cat([x1, x2], dim=-1)) if self.use_bilinear[2] else self.linear_z3(torch.cat((x1, x2, x3), dim=-1))
            o3 = self.linear_o3(nn.Sigmoid()(z3) * h3)
        else:
            h3 = self.linear_h3(x3)
            o3 = self.linear_o3(h3)

        return o1, o2, o3
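    # Gating sketch (illustrative): with the bilinear variant enabled, each gated modality k
    # is updated as o_k = Dropout(ReLU(W_o(sigmoid(z_k) * h_k))), where h_k = ReLU(W_h(x_k))
    # and z_k scores x_k against the concatenation of the other two modality embeddings.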

class FC_block(nn.Module):
    def __init__(self, dim_in, dim_out, act_layer=nn.ReLU, dropout=True, p_dropout_fc=0.25):
        super(FC_block, self).__init__()

        self.fc = nn.Linear(dim_in, dim_out)
        self.act = act_layer()
        self.drop = nn.Dropout(p_dropout_fc) if dropout else nn.Identity()

    def forward(self, x):
        x = self.fc(x)
        x = self.act(x)
        x = self.drop(x)
        return x

class Categorical_encoding(nn.Module):
    def __init__(self, taxonomy_in=3, embedding_dim=128, depth=1, act_fct='relu', dropout=True, p_dropout=0.25):
        super(Categorical_encoding, self).__init__()

        act_fcts = {'relu': nn.ReLU(),
                    'elu': nn.ELU(),
                    'tanh': nn.Tanh(),
                    'selu': nn.SELU(),
                    }
        dropout_module = nn.AlphaDropout(p_dropout) if act_fct == 'selu' else nn.Dropout(p_dropout)

        self.embedding = nn.Embedding(taxonomy_in, embedding_dim)

        fc_layers = []
        for d in range(depth):
            fc_layers.append(nn.Linear(embedding_dim // (2 ** d), embedding_dim // (2 ** (d + 1))))
            fc_layers.append(dropout_module if dropout else nn.Identity())
            fc_layers.append(act_fcts[act_fct])

        self.fc_layers = nn.Sequential(*fc_layers)

    def forward(self, x):
        x = self.embedding(x)
        x = self.fc_layers(x)
        return x
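    # Worked example (illustrative): with embedding_dim=16 and depth=1, an index tensor of
    # shape (1,) is embedded to (1, 16) and projected to (1, 8); each extra depth level
    # halves the output dimension again (16 -> 8 -> 4 -> ...).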

class HECTOR(nn.Module):
    def __init__(
        self,
        input_feature_size=1024,
        precompression_layer=True,
        feature_size_comp=512,
        feature_size_attn=256,
        postcompression_layer=True,
        feature_size_comp_post=128,
        dropout=True,
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,
        n_classes=2,
        input_stage_size=6,
        embedding_dim_stage=16,
        depth_dim_stage=1,
        act_fct_stage='elu',
        dropout_stage=True,
        p_dropout_stage=0.25,
        input_mol_size=4,
        embedding_dim_mol=16,
        depth_dim_mol=1,
        act_fct_mol='elu',
        dropout_mol=True,
        p_dropout_mol=0.25,
        fusion_type='kron',
        use_bilinear=[True, True, True],
        gate_hist=False,
        gate_stage=False,
        gate_mol=False,
        scale=[1, 1, 1],
    ):
        super(HECTOR, self).__init__()

        self.fusion_type = fusion_type
        self.input_stage_size = input_stage_size
        self.use_bilinear = use_bilinear
        self.gate_hist = gate_hist
        self.gate_stage = gate_stage
        self.gate_mol = gate_mol

        # Reduce dimension of H&E patch features.
        if precompression_layer:
            self.compression_layer = nn.Sequential(*[
                FC_block(input_feature_size, feature_size_comp * 4, p_dropout_fc=p_dropout_fc),
                FC_block(feature_size_comp * 4, feature_size_comp * 2, p_dropout_fc=p_dropout_fc),
                FC_block(feature_size_comp * 2, feature_size_comp, p_dropout_fc=p_dropout_fc),
            ])
            dim_post_compression = feature_size_comp
        else:
            self.compression_layer = nn.Identity()
            dim_post_compression = input_feature_size

        # Learnable embeddings of the categorical features.
        self.encoding_stage_net = Categorical_encoding(taxonomy_in=self.input_stage_size,
                                                       embedding_dim=embedding_dim_stage,
                                                       depth=depth_dim_stage,
                                                       act_fct=act_fct_stage,
                                                       dropout=dropout_stage,
                                                       p_dropout=p_dropout_stage)
        self.out_stage_size = embedding_dim_stage // (2 ** depth_dim_stage)

        self.encoding_mol_net = Categorical_encoding(taxonomy_in=input_mol_size,
                                                     embedding_dim=embedding_dim_mol,
                                                     depth=depth_dim_mol,
                                                     act_fct=act_fct_mol,
                                                     dropout=dropout_mol,
                                                     p_dropout=p_dropout_mol)
        h_mol_size_out = embedding_dim_mol // (2 ** depth_dim_mol)

        # For the survival task a single attention branch is used (n_classes=1).
        self.attention_survival_net = AttnNet(
            L=dim_post_compression,
            D=feature_size_attn,
            dropout=dropout,
            p_dropout_atn=p_dropout_atn,
            n_classes=1,
        )

        # Attention gate on each modality.
        self.attn_modalities = Attn_Modality_Gated(
            gate_h1=self.gate_hist,
            gate_h2=self.gate_stage,
            gate_h3=self.gate_mol,
            dim1_og=dim_post_compression,
            dim2_og=self.out_stage_size,
            dim3_og=h_mol_size_out,
            scale=scale,
            use_bilinear=self.use_bilinear,
        )

        # Post-compression layer for the H&E slide-level embedding before fusion.
        dim_post_compression = dim_post_compression // scale[0] if self.gate_hist else dim_post_compression
        self.post_compression_layer_he = FC_block(dim_post_compression, dim_post_compression // 2, p_dropout_fc=p_dropout_fc)
        dim_post_compression = dim_post_compression // 2

        # Input dimension of the post-fusion compression layer.
        dim1 = dim_post_compression
        dim2 = self.out_stage_size // scale[1] if self.gate_stage else self.out_stage_size
        dim3 = h_mol_size_out // scale[2] if self.gate_mol else h_mol_size_out
        if self.fusion_type == 'bilinear':
            head_size_in = (dim1 + 1) * (dim2 + 1) * (dim3 + 1)
        elif self.fusion_type == 'kron':
            head_size_in = dim1 * dim2 * dim3
        elif self.fusion_type == 'concat':
            head_size_in = dim1 + dim2 + dim3
        else:
            raise NotImplementedError(f'Fusion type {self.fusion_type} is not implemented')

        self.post_compression_layer = nn.Sequential(*[
            FC_block(head_size_in, feature_size_comp_post * 2, p_dropout_fc=p_dropout_fc),
            FC_block(feature_size_comp_post * 2, feature_size_comp_post, p_dropout_fc=p_dropout_fc),
        ])

        # Survival head.
        self.n_classes = n_classes
        self.classifier = nn.Linear(feature_size_comp_post, self.n_classes)

        # Init weights.
        self.apply(self._init_weights)
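    # Dimension sketch with the default arguments (illustrative): tile features
    # 1024 -> 2048 -> 1024 -> 512 (pre-compression), slide embedding 512 -> 256
    # (post-compression), stage and molecular embeddings 16 -> 8 each; 'kron' fusion
    # gives 256 * 8 * 8 = 16384, compressed to 256 -> 128, and the classifier maps
    # 128 -> n_classes=2.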

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward_attention(self, h):
        A_ = self.attention_survival_net(h)  # h has shape N_tiles x dim
        A_raw = torch.transpose(A_, 1, 0)  # K_attention_classes x N_tiles
        A = F.softmax(A_raw, dim=-1)  # normalize attention scores over the tiles
        return A_raw, A
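    # Shape note (illustrative): for h of shape (N_tiles, dim), both A_raw and A have shape
    # (1, N_tiles); the softmaxed A sums to 1 over the tiles and is used as pooling weights.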

    def forward_fusion(self, h1, h2, h3):

        if self.fusion_type == 'bilinear':
            # Append 1 to retain the unimodal embeddings in the fusion.
            h1 = torch.cat((h1, torch.ones(1, 1, dtype=torch.float, device=h1.device)), -1)
            h2 = torch.cat((h2, torch.ones(1, 1, dtype=torch.float, device=h2.device)), -1)
            h3 = torch.cat((h3, torch.ones(1, 1, dtype=torch.float, device=h3.device)), -1)
            return torch.kron(torch.kron(h1, h2), h3)

        elif self.fusion_type == 'kron':
            return torch.kron(torch.kron(h1, h2), h3)

        elif self.fusion_type == 'concat':
            return torch.cat([h1, h2, h3], dim=-1)

        else:
            raise NotImplementedError(f'Fusion type {self.fusion_type} is not implemented')
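    # Fusion size sketch (illustrative, with the default sizes): Kronecker fusion of row
    # vectors of sizes 256, 8 and 8 yields a 256 * 8 * 8 = 16384-dimensional multimodal
    # embedding; 'bilinear' first appends a constant 1 to each vector so unimodal and
    # pairwise interaction terms are retained, and 'concat' simply gives 256 + 8 + 8 = 272.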

    def forward_survival(self, logits):
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        # The model outputs the hazards with a sigmoid activation function.
        hazards = torch.sigmoid(logits)  # size [1, n_classes]; h(t|X) := P(T=t | T>=t, X)
        # S(t|X) := P(T>t|X) = prod_{s=1..t} (1 - h(s|X)), evaluated at each discrete
        # time point t (for t=1 the cumulative product has a single term).
        survival = torch.cumprod(1 - hazards, dim=1)  # size [1, n_classes]
        return hazards, survival, Y_hat
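    # Worked example (illustrative): hazards [0.1, 0.2, 0.4, 0.3] give a survival curve
    # cumprod(1 - hazards) = [0.9, 0.72, 0.432, 0.3024], which decreases monotonically
    # over the discrete time intervals.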

    def forward(self, h, stage, h_mol):

        # Compression of the H&E patch features.
        h = self.compression_layer(h)

        # Attention MIL and first-order pooling.
        A_raw, A = self.forward_attention(h)  # A has shape 1 x N_tiles.
        h_hist = A @ h  # torch.Size([1, dim_embedding]): attention-weighted sum over the tiles.

        # Stage learnable embedding.
        stage = self.encoding_stage_net(stage)

        # Molecular class learnable embedding.
        h_mol = self.encoding_mol_net(h_mol)

        # Attention gates on each modality.
        h_hist, stage, h_mol = self.attn_modalities(h_hist, stage, h_mol)

        # Post-compression of the H&E slide embedding.
        h_hist = self.post_compression_layer_he(h_hist)

        # Fusion.
        m = self.forward_fusion(h_hist, stage, h_mol)

        # Post-compression of the multimodal embedding.
        m = self.post_compression_layer(m)

        # Survival head.
        logits = self.classifier(m)
        hazards, survival, Y_hat = self.forward_survival(logits)

        return hazards, survival, Y_hat, A_raw, m
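

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, assuming the default hyper-parameters above
    # and integer-encoded stage / molecular class inputs).
    model = HECTOR()
    h = torch.randn(100, 1024)   # bag of 100 tile embeddings
    stage = torch.tensor([2])    # stage index in [0, input_stage_size)
    h_mol = torch.tensor([1])    # molecular class index in [0, input_mol_size)
    hazards, survival, Y_hat, A_raw, m = model(h, stage, h_mol)
    print(hazards.shape, survival.shape, Y_hat.shape, A_raw.shape, m.shape)
    # Expected shapes: (1, 2), (1, 2), (1, 1), (1, 100), (1, 128)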