Diff of /model.py [000000] .. [352cae]


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class AttnNet(nn.Module):
    # Adapted from https://github.com/mahmoodlab/CLAM/blob/master/models/model_clam.py
    # Lu, M.Y., Williamson, D.F.K., Chen, T.Y. et al. Data-efficient and weakly supervised computational pathology on whole-slide images. Nat Biomed Eng 5, 555–570 (2021). https://doi.org/10.1038/s41551-020-00682-w

    def __init__(self, L=1024, D=256, dropout=False, p_dropout_atn=0.25, n_classes=1):
        super(AttnNet, self).__init__()

        self.attention_a = [nn.Linear(L, D), nn.Tanh()]
        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]

        if dropout:
            self.attention_a.append(nn.Dropout(p_dropout_atn))
            self.attention_b.append(nn.Dropout(p_dropout_atn))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)  # N x n_classes
        return A

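# Usage sketch (illustrative, not part of the original file): for a bag of N tile
# embeddings of dimension L, AttnNet returns one unnormalized attention logit per
# tile and per attention class, e.g.
#   attn = AttnNet(L=1024, D=256, n_classes=1)
#   scores = attn(torch.randn(500, 1024))  # -> shape [500, 1]
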
class Attn_Modality_Gated(nn.Module):
    # Adapted from https://github.com/mahmoodlab/PORPOISE
    def __init__(self, gate_h1, gate_h2, gate_h3, dim1_og, dim2_og, dim3_og, use_bilinear=[True,True,True], scale=[1,1,1], p_dropout_fc=0.25):
        super(Attn_Modality_Gated, self).__init__()

        self.gate_h1 = gate_h1  # bool
        self.gate_h2 = gate_h2  # bool
        self.gate_h3 = gate_h3  # bool
        self.use_bilinear = use_bilinear  # list of three bools, one per modality

        # Attention can be performed on latent vectors of lower dimension.
        dim1, dim2, dim3 = dim1_og//scale[0], dim2_og//scale[1], dim3_og//scale[2]

        # Attention gate of each modality.
        if self.gate_h1:
            self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU())
            self.linear_z1 = nn.Bilinear(dim1_og, dim2_og+dim3_og, dim1) if self.use_bilinear[0] else nn.Sequential(nn.Linear(dim1_og+dim2_og+dim3_og, dim1))
            self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
        else:
            self.linear_h1, self.linear_o1 = nn.Identity(), nn.Identity()

        if self.gate_h2:
            self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU())
            self.linear_z2 = nn.Bilinear(dim2_og, dim1_og+dim3_og, dim2) if self.use_bilinear[1] else nn.Sequential(nn.Linear(dim1_og+dim2_og+dim3_og, dim2))
            self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
        else:
            self.linear_h2, self.linear_o2 = nn.Identity(), nn.Identity()

        if self.gate_h3:
            self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU())
            self.linear_z3 = nn.Bilinear(dim3_og, dim1_og+dim2_og, dim3) if self.use_bilinear[2] else nn.Sequential(nn.Linear(dim1_og+dim2_og+dim3_og, dim3))
            self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=p_dropout_fc))
        else:
            self.linear_h3, self.linear_o3 = nn.Identity(), nn.Identity()

    def forward(self, x1, x2, x3):

        if self.gate_h1:
            h1 = self.linear_h1(x1)  # non-linear projection of x1 (breaks collinearity)
            z1 = self.linear_z1(x1, torch.cat([x2,x3], dim=-1)) if self.use_bilinear[0] else self.linear_z1(torch.cat((x1, x2, x3), dim=-1))  # vector combining all modalities
            o1 = self.linear_o1(nn.Sigmoid()(z1)*h1)  # gate the modality embedding
        else:
            h1 = self.linear_h1(x1)
            o1 = self.linear_o1(h1)

        if self.gate_h2:
            h2 = self.linear_h2(x2)
            z2 = self.linear_z2(x2, torch.cat([x1,x3], dim=-1)) if self.use_bilinear[1] else self.linear_z2(torch.cat((x1, x2, x3), dim=-1))
            o2 = self.linear_o2(nn.Sigmoid()(z2)*h2)
        else:
            h2 = self.linear_h2(x2)
            o2 = self.linear_o2(h2)

        if self.gate_h3:
            h3 = self.linear_h3(x3)
            z3 = self.linear_z3(x3, torch.cat([x1,x2], dim=-1)) if self.use_bilinear[2] else self.linear_z3(torch.cat((x1, x2, x3), dim=-1))
            o3 = self.linear_o3(nn.Sigmoid()(z3)*h3)
        else:
            h3 = self.linear_h3(x3)
            o3 = self.linear_o3(h3)

        return o1, o2, o3

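# Usage sketch (illustrative, not part of the original file; dimensions are assumptions):
#   gate = Attn_Modality_Gated(gate_h1=True, gate_h2=True, gate_h3=True,
#                              dim1_og=512, dim2_og=8, dim3_og=8)
#   o1, o2, o3 = gate(torch.randn(1, 512), torch.randn(1, 8), torch.randn(1, 8))
#   # with scale=[1,1,1]: o1 -> [1, 512], o2 -> [1, 8], o3 -> [1, 8]
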
class FC_block(nn.Module):
    def __init__(self, dim_in, dim_out, act_layer=nn.ReLU, dropout=True, p_dropout_fc=0.25):
        super(FC_block, self).__init__()

        self.fc = nn.Linear(dim_in, dim_out)
        self.act = act_layer()
        self.drop = nn.Dropout(p_dropout_fc) if dropout else nn.Identity()

    def forward(self, x):
        x = self.fc(x)
        x = self.act(x)
        x = self.drop(x)
        return x

class Categorical_encoding(nn.Module):
    def __init__(self, taxonomy_in=3, embedding_dim=128, depth=1, act_fct='relu', dropout=True, p_dropout=0.25):
        super(Categorical_encoding, self).__init__()

        act_fcts = {
            'relu': nn.ReLU(),
            'elu': nn.ELU(),
            'tanh': nn.Tanh(),
            'selu': nn.SELU(),
        }
        dropout_module = nn.AlphaDropout(p_dropout) if act_fct=='selu' else nn.Dropout(p_dropout)

        self.embedding = nn.Embedding(taxonomy_in, embedding_dim)

        fc_layers = []
        for d in range(depth):
            fc_layers.append(nn.Linear(embedding_dim//(2**d), embedding_dim//(2**(d+1))))
            fc_layers.append(dropout_module if dropout else nn.Identity())
            fc_layers.append(act_fcts[act_fct])

        self.fc_layers = nn.Sequential(*fc_layers)

    def forward(self, x):
        x = self.embedding(x)
        x = self.fc_layers(x)
        return x

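# Usage sketch (illustrative, not part of the original file): with taxonomy_in=6,
# embedding_dim=16 and depth=1, an integer category index such as torch.tensor([2])
# is embedded to [1, 16] and then compressed to [1, 8] (16 // 2**1).
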
class HECTOR(nn.Module):
    def __init__(
        self,
        input_feature_size=1024,
        precompression_layer=True,
        feature_size_comp=512,
        feature_size_attn=256,
        postcompression_layer=True,
        feature_size_comp_post=128,
        dropout=True,
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,
        n_classes=2,
        input_stage_size=6,
        embedding_dim_stage=16,
        depth_dim_stage=1,
        act_fct_stage='elu',
        dropout_stage=True,
        p_dropout_stage=0.25,
        input_mol_size=4,
        embedding_dim_mol=16,
        depth_dim_mol=1,
        act_fct_mol='elu',
        dropout_mol=True,
        p_dropout_mol=0.25,
        fusion_type='kron',
        use_bilinear=[True,True,True],
        gate_hist=False,
        gate_stage=False,
        gate_mol=False,
        scale=[1,1,1],
    ):
        super(HECTOR, self).__init__()

        self.fusion_type = fusion_type
        self.input_stage_size = input_stage_size
        self.use_bilinear = use_bilinear
        self.gate_hist = gate_hist
        self.gate_stage = gate_stage
        self.gate_mol = gate_mol

        # Reduce dimension of H&E patch features.
        if precompression_layer:
            self.compression_layer = nn.Sequential(*[FC_block(input_feature_size, feature_size_comp*4, p_dropout_fc=p_dropout_fc),
                                                     FC_block(feature_size_comp*4, feature_size_comp*2, p_dropout_fc=p_dropout_fc),
                                                     FC_block(feature_size_comp*2, feature_size_comp, p_dropout_fc=p_dropout_fc),])

            dim_post_compression = feature_size_comp
        else:
            self.compression_layer = nn.Identity()
            dim_post_compression = input_feature_size

        # Get embeddings of categorical features.
        self.encoding_stage_net = Categorical_encoding(taxonomy_in=self.input_stage_size,
                                                       embedding_dim=embedding_dim_stage,
                                                       depth=depth_dim_stage,
                                                       act_fct=act_fct_stage,
                                                       dropout=dropout_stage,
                                                       p_dropout=p_dropout_stage)
        self.out_stage_size = embedding_dim_stage//(2**depth_dim_stage)

        self.encoding_mol_net = Categorical_encoding(taxonomy_in=input_mol_size,
                                                     embedding_dim=embedding_dim_mol,
                                                     depth=depth_dim_mol,
                                                     act_fct=act_fct_mol,
                                                     dropout=dropout_mol,
                                                     p_dropout=p_dropout_mol)
        h_mol_size_out = embedding_dim_mol//(2**depth_dim_mol)

        # For the survival task a single attention branch is used (n_classes=1).
        self.attention_survival_net = AttnNet(
            L=dim_post_compression,
            D=feature_size_attn,
            dropout=dropout,
            p_dropout_atn=p_dropout_atn,
            n_classes=1,)

        # Attention gate on each modality.
        self.attn_modalities = Attn_Modality_Gated(
            gate_h1=self.gate_hist,
            gate_h2=self.gate_stage,
            gate_h3=self.gate_mol,
            dim1_og=dim_post_compression,
            dim2_og=self.out_stage_size,
            dim3_og=h_mol_size_out,
            scale=scale,
            use_bilinear=self.use_bilinear)

        # Post-compression layer for H&E slide-level embedding before fusion.
        dim_post_compression = dim_post_compression//scale[0] if self.gate_hist else dim_post_compression
        self.post_compression_layer_he = FC_block(dim_post_compression, dim_post_compression//2, p_dropout_fc=p_dropout_fc)
        dim_post_compression = dim_post_compression//2

        # Size of the fused embedding depends on the fusion type.
        dim1 = dim_post_compression
        dim2 = self.out_stage_size//scale[1] if self.gate_stage else self.out_stage_size
        dim3 = h_mol_size_out//scale[2] if self.gate_mol else h_mol_size_out
        if self.fusion_type == 'bilinear':
            head_size_in = (dim1+1)*(dim2+1)*(dim3+1)
        elif self.fusion_type == 'kron':
            head_size_in = dim1*dim2*dim3
        elif self.fusion_type == 'concat':
            head_size_in = dim1+dim2+dim3
        else:
            raise NotImplementedError(f'Unknown fusion type: {self.fusion_type}')

        # Post-compression of the multimodal embedding.
        self.post_compression_layer = nn.Sequential(*[FC_block(head_size_in, feature_size_comp_post*2, p_dropout_fc=p_dropout_fc),
                                                      FC_block(feature_size_comp_post*2, feature_size_comp_post, p_dropout_fc=p_dropout_fc),])

        # Survival head.
        self.n_classes = n_classes
        self.classifier = nn.Linear(feature_size_comp_post, self.n_classes)

        # Init weights.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward_attention(self, h):
        A_ = self.attention_survival_net(h)  # h shape: N_tiles x dim
        A_raw = torch.transpose(A_, 1, 0)  # K_attention_classes x N_tiles
        A = F.softmax(A_raw, dim=-1)  # normalize attention scores over tiles
        return A_raw, A

    def forward_fusion(self, h1, h2, h3):

        if self.fusion_type == 'bilinear':
            # Append 1 to retain the unimodal embeddings in the fusion.
            h1 = torch.cat((h1, torch.ones(1, 1, dtype=torch.float, device=h1.device)), -1)
            h2 = torch.cat((h2, torch.ones(1, 1, dtype=torch.float, device=h2.device)), -1)
            h3 = torch.cat((h3, torch.ones(1, 1, dtype=torch.float, device=h3.device)), -1)

            return torch.kron(torch.kron(h1, h2), h3)

        elif self.fusion_type == 'kron':
            return torch.kron(torch.kron(h1, h2), h3)

        elif self.fusion_type == 'concat':
            return torch.cat([h1, h2, h3], dim=-1)
        else:
            raise NotImplementedError(f'Fusion type {self.fusion_type} is not implemented.')

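    # Shape note (illustrative, assuming HECTOR's default arguments): h_hist is
    # [1, 256] after post_compression_layer_he, and stage and h_mol are [1, 8] each,
    # so 'kron' fusion returns [1, 256*8*8] = [1, 16384], matching head_size_in.
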
    def forward_survival(self, logits):
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        # The model outputs hazards through a sigmoid activation.
        hazards = torch.sigmoid(logits)  # size [1, n_classes]; h(t|X) := P(T=t | T>=t, X)
        # S(t|X) := P(T>=t|X) = prod_{s=1..t} (1 - h(s|X)), evaluated at each discrete
        # time point t; for t=1 the cumulative product reduces to a single factor.
        survival = torch.cumprod(1 - hazards, dim=1)  # size [1, n_classes]

        return hazards, survival, Y_hat

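    # Worked example (illustrative, not from the original file): with two discrete
    # intervals and hazards [[0.2, 0.5]], survival = cumprod(1 - hazards, dim=1)
    # = [[0.8, 0.8*0.5]] = [[0.8, 0.4]], i.e. S(1) = 0.8 and S(2) = 0.4.
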
    def forward(self, h, stage, h_mol):

        # H&E embedding.
        h = self.compression_layer(h)

        # Attention MIL and first-order pooling.
        A_raw, A = self.forward_attention(h)  # A: 1 x N_tiles
        h_hist = A @ h  # torch.Size([1, dim_embedding]): attention-weighted sum over the N tiles

        # Stage learnable embedding.
        stage = self.encoding_stage_net(stage)

        # Compression of h_mol.
        h_mol = self.encoding_mol_net(h_mol)

        # Attention gates on each modality.
        h_hist, stage, h_mol = self.attn_modalities(h_hist, stage, h_mol)

        # Post-compression of the H&E slide embedding.
        h_hist = self.post_compression_layer_he(h_hist)

        # Fusion.
        m = self.forward_fusion(h_hist, stage, h_mol)

        # Post-compression of the multimodal embedding.
        m = self.post_compression_layer(m)

        # Survival head.
        logits = self.classifier(m)

        hazards, survival, Y_hat = self.forward_survival(logits)

        return hazards, survival, Y_hat, A_raw, m
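
if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original file).
    # Shapes are assumptions based on the defaults above: a bag of 500 tiles with
    # 1024-d features, an integer stage index in [0, 6) and an integer molecular
    # class index in [0, 4), each for a single slide.
    model = HECTOR()
    h = torch.randn(500, 1024)     # tile-level H&E features
    stage = torch.tensor([2])      # categorical stage index
    h_mol = torch.tensor([1])      # categorical molecular class index
    hazards, survival, Y_hat, A_raw, m = model(h, stage, h_mol)
    print(hazards.shape, survival.shape, Y_hat.shape, A_raw.shape, m.shape)
    # expected: [1, 2], [1, 2], [1, 1], [1, 500], [1, 128]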