# models/model_set_mil.py
from collections import OrderedDict
from os.path import join
import pdb

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.model_utils import *



################################
### Deep Sets Implementation ###
################################
class MIL_Sum_FC_surv(nn.Module):
    def __init__(self, omic_input_dim=None, fusion=None, size_arg="small", dropout=0.25, n_classes=4):
        r"""
        Deep Sets Implementation.

        Args:
            omic_input_dim (int): Dimension size of genomic features.
            fusion (str): Fusion method (Choices: concat, bilinear, or None)
            size_arg (str): Size of NN architecture (Choices: small or big)
            dropout (float): Dropout rate
            n_classes (int): Number of output classes
        """
        super(MIL_Sum_FC_surv, self).__init__()
        self.fusion = fusion
        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        self.size_dict_omic = {'small': [256, 256]}

        ### Deep Sets Architecture Construction
        size = self.size_dict_path[size_arg]
        self.phi = nn.Sequential(*[nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)])
        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])

        ### Constructing Genomic SNN
        if self.fusion is not None:
            hidden = [256, 256]
            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
            for i, _ in enumerate(hidden[1:]):
                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
            self.fc_omic = nn.Sequential(*fc_omic)

            if self.fusion == 'concat':
                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
            elif self.fusion == 'bilinear':
                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
            else:
                self.mm = None

        self.classifier = nn.Linear(size[2], n_classes)

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() >= 1:
            device_ids = list(range(torch.cuda.device_count()))
            self.phi = nn.DataParallel(self.phi, device_ids=device_ids).to('cuda:0')

        if self.fusion is not None:
            self.fc_omic = self.fc_omic.to(device)
            self.mm = self.mm.to(device)

        self.rho = self.rho.to(device)
        self.classifier = self.classifier.to(device)

    def forward(self, **kwargs):
        x_path = kwargs['x_path']

        ### Permutation-invariant sum pooling over the bag of instance embeddings
        h_path = self.phi(x_path).sum(dim=0)
        h_path = self.rho(h_path)

        if self.fusion is not None:
            x_omic = kwargs['x_omic']
            h_omic = self.fc_omic(x_omic).squeeze(dim=0)
            if self.fusion == 'bilinear':
                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
            elif self.fusion == 'concat':
                h = self.mm(torch.cat([h_path, h_omic], dim=0))
        else:
            h = h_path # [256] vector

        logits = self.classifier(h).unsqueeze(0) # logits needs to be a [1 x n_classes] vector
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        return hazards, S, Y_hat, None, None


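# Hedged usage sketch (added for illustration, not part of the original module): a
# minimal forward pass through MIL_Sum_FC_surv with fusion=None, assuming a bag of
# 1024-dim patch embeddings ("small" config above) and the default n_classes=4.
def _example_mil_sum_fc_surv():
    model = MIL_Sum_FC_surv(fusion=None, size_arg="small", n_classes=4)
    x_path = torch.randn(100, 1024)                  # bag of 100 patch features
    hazards, S, Y_hat, _, _ = model(x_path=x_path)   # hazards, S: [1, 4]; Y_hat: [1, 1]
    return hazards, S, Y_hat

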
################################
# Attention MIL Implementation #
################################
class MIL_Attention_FC_surv(nn.Module):
    def __init__(self, omic_input_dim=None, fusion=None, size_arg="small", dropout=0.25, n_classes=4):
        r"""
        Attention MIL Implementation.

        Args:
            omic_input_dim (int): Dimension size of genomic features.
            fusion (str): Fusion method (Choices: concat, bilinear, or None)
            size_arg (str): Size of NN architecture (Choices: small or big)
            dropout (float): Dropout rate
            n_classes (int): Number of output classes
        """
        super(MIL_Attention_FC_surv, self).__init__()
        self.fusion = fusion
        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        self.size_dict_omic = {'small': [256, 256]}

        ### Attention MIL Architecture Construction
        size = self.size_dict_path[size_arg]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])

        ### Constructing Genomic SNN
        if self.fusion is not None:
            hidden = [256, 256]
            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
            for i, _ in enumerate(hidden[1:]):
                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
            self.fc_omic = nn.Sequential(*fc_omic)

            if self.fusion == 'concat':
                self.mm = nn.Sequential(*[nn.Linear(256*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
            elif self.fusion == 'bilinear':
                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
            else:
                self.mm = None

        self.classifier = nn.Linear(size[2], n_classes)

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() >= 1:
            device_ids = list(range(torch.cuda.device_count()))
            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')

        if self.fusion is not None:
            self.fc_omic = self.fc_omic.to(device)
            self.mm = self.mm.to(device)

        self.rho = self.rho.to(device)
        self.classifier = self.classifier.to(device)

    def forward(self, **kwargs):
        x_path = kwargs['x_path']

        ### Gated attention pooling over the bag of instance embeddings
        A, h_path = self.attention_net(x_path)
        A = torch.transpose(A, 1, 0)
        A_raw = A
        A = F.softmax(A, dim=1)
        h_path = torch.mm(A, h_path)
        h_path = self.rho(h_path).squeeze()

        if self.fusion is not None:
            x_omic = kwargs['x_omic']
            h_omic = self.fc_omic(x_omic)
            if self.fusion == 'bilinear':
                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
            elif self.fusion == 'concat':
                h = self.mm(torch.cat([h_path, h_omic], dim=0))
        else:
            h = h_path # [256] vector

        logits = self.classifier(h).unsqueeze(0) # logits needs to be a [1 x n_classes] vector
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        return hazards, S, Y_hat, None, None


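# Hedged usage sketch (added for illustration, not part of the original module): running
# MIL_Attention_FC_surv with the 'concat' fusion branch. omic_input_dim=80 is an arbitrary
# illustrative value; SNN_Block and Attn_Net_Gated are assumed to come from models.model_utils.
def _example_mil_attention_fc_surv():
    model = MIL_Attention_FC_surv(omic_input_dim=80, fusion='concat', size_arg="small", n_classes=4)
    x_path = torch.randn(100, 1024)   # bag of 100 patch features
    x_omic = torch.randn(80)          # genomic feature vector
    hazards, S, Y_hat, _, _ = model(x_path=x_path, x_omic=x_omic)
    return hazards, S, Y_hat

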
######################################
# Deep Attention MISL Implementation #
######################################
class MIL_Cluster_FC_surv(nn.Module):
    def __init__(self, omic_input_dim=None, fusion=None, num_clusters=10, size_arg="small", dropout=0.25, n_classes=4):
        r"""
        Deep Attention MISL Implementation.

        Args:
            omic_input_dim (int): Dimension size of genomic features.
            fusion (str): Fusion method (Choices: concat, bilinear, or None)
            num_clusters (int): Number of phenotype clusters
            size_arg (str): Size of NN architecture (Choices: small or big)
            dropout (float): Dropout rate
            n_classes (int): Number of output classes
        """
        super(MIL_Cluster_FC_surv, self).__init__()
        self.size_dict_path = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        self.size_dict_omic = {'small': [256, 256]}
        self.num_clusters = num_clusters
        self.fusion = fusion

        ### FC Cluster layers + Pooling
        size = self.size_dict_path[size_arg]
        phis = []
        for phenotype_i in range(num_clusters):
            phi = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout),
                   nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]
            phis.append(nn.Sequential(*phi))
        self.phis = nn.ModuleList(phis)
        self.pool1d = nn.AdaptiveAvgPool1d(1)

        ### WSI Attention MIL Construction
        fc = [nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]
        attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])

        ### Genomic SNN Construction + Multimodal Fusion
        if self.fusion is not None:
            hidden = self.size_dict_omic['small']
            fc_omic = [SNN_Block(dim1=omic_input_dim, dim2=hidden[0])]
            for i, _ in enumerate(hidden[1:]):
                fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))
            self.fc_omic = nn.Sequential(*fc_omic)

            if self.fusion == 'concat':
                self.mm = nn.Sequential(*[nn.Linear(size[2]*2, size[2]), nn.ReLU(), nn.Linear(size[2], size[2]), nn.ReLU()])
            elif self.fusion == 'bilinear':
                self.mm = BilinearFusion(dim1=256, dim2=256, scale_dim1=8, scale_dim2=8, mmhid=256)
            else:
                self.mm = None

        self.classifier = nn.Linear(size[2], n_classes)

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() >= 1:
            device_ids = list(range(torch.cuda.device_count()))
            self.attention_net = nn.DataParallel(self.attention_net, device_ids=device_ids).to('cuda:0')
        else:
            self.attention_net = self.attention_net.to(device)

        if self.fusion is not None:
            self.fc_omic = self.fc_omic.to(device)
            self.mm = self.mm.to(device)

        self.phis = self.phis.to(device)
        self.pool1d = self.pool1d.to(device)
        self.rho = self.rho.to(device)
        self.classifier = self.classifier.to(device)

    def forward(self, **kwargs):
        x_path = kwargs['x_path']
        cluster_id = kwargs['cluster_id'].detach().cpu().numpy()

        ### FC Cluster layers + Pooling
        h_cluster = []
        for i in range(self.num_clusters):
            h_cluster_i = self.phis[i](x_path[cluster_id==i])
            if h_cluster_i.shape[0] == 0:
                # Empty cluster: substitute a zero embedding on the same device as the input
                h_cluster_i = torch.zeros((1, 512), device=x_path.device)
            h_cluster.append(self.pool1d(h_cluster_i.T.unsqueeze(0)).squeeze(2))
        h_cluster = torch.stack(h_cluster, dim=1).squeeze(0)

        ### Attention MIL
        A, h_path = self.attention_net(h_cluster)
        A = torch.transpose(A, 1, 0)
        A_raw = A
        A = F.softmax(A, dim=1)
        h_path = torch.mm(A, h_path)
        h_path = self.rho(h_path).squeeze()

        ### Attention MIL + Genomic Fusion
        if self.fusion is not None:
            x_omic = kwargs['x_omic']
            h_omic = self.fc_omic(x_omic)
            if self.fusion == 'bilinear':
                h = self.mm(h_path.unsqueeze(dim=0), h_omic.unsqueeze(dim=0)).squeeze()
            elif self.fusion == 'concat':
                h = self.mm(torch.cat([h_path, h_omic], dim=0))
        else:
            h = h_path

        logits = self.classifier(h).unsqueeze(0)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)

        return hazards, S, Y_hat, None, None
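

# Hedged usage sketch (added for illustration, not part of the original module):
# MIL_Cluster_FC_surv additionally expects a per-patch phenotype cluster assignment.
# The bag size (200) and cluster count (10) below are illustrative only.
def _example_mil_cluster_fc_surv():
    model = MIL_Cluster_FC_surv(fusion=None, num_clusters=10, size_arg="small", n_classes=4)
    x_path = torch.randn(200, 1024)              # bag of 200 patch features
    cluster_id = torch.randint(0, 10, (200,))    # phenotype cluster index per patch
    hazards, S, Y_hat, _, _ = model(x_path=x_path, cluster_id=cluster_id)
    return hazards, S, Y_hat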