# im4MEC.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# Adapted from https://github.com/AIRMEC/im4MEC


class Attn_Net_Gated(nn.Module):
    """Gated attention network: one unnormalized attention score per tile and per class."""

    def __init__(self, L=1024, D=256, dropout=False, p_dropout_atn=0.25, n_classes=1):
        super(Attn_Net_Gated, self).__init__()

        # Two parallel branches: a tanh branch and a sigmoid branch that acts as a gate.
        self.attention_a = [nn.Linear(L, D), nn.Tanh()]
        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]

        if dropout:
            self.attention_a.append(nn.Dropout(p_dropout_atn))
            self.attention_b.append(nn.Dropout(p_dropout_atn))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)             # element-wise gating of the tanh branch
        A = self.attention_c(A)  # N x n_classes
        return A

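# Background note (added for context): Attn_Net_Gated implements the gated attention
# of Ilse et al., "Attention-based Deep Multiple Instance Learning" (2018). Ignoring
# bias terms and dropout, the unnormalized score of tile embedding h_k for attention
# branch c (one branch per class in Im4MEC below) is
#     a_{k,c} = w_c^T ( tanh(V h_k) * sigmoid(U h_k) ),
# where V and U are the linear layers inside attention_a / attention_b, * is the
# element-wise product, and w_c is row c of attention_c.
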
class Im4MEC(nn.Module):
    def __init__(
        self,
        input_feature_size=1024,
        precompression_layer=True,
        feature_size_comp=512,
        feature_size_attn=256,
        dropout=True,
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,
        n_classes=4,
    ):
        super(Im4MEC, self).__init__()

        self.n_classes = n_classes

        # Optional MLP that compresses the tile features before attention pooling.
        if precompression_layer:
            self.compression_layer = nn.Sequential(
                nn.Linear(input_feature_size, feature_size_comp * 4),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc),
                nn.Linear(feature_size_comp * 4, feature_size_comp * 2),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc),
                nn.Linear(feature_size_comp * 2, feature_size_comp),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc),
            )
            dim_post_compression = feature_size_comp
        else:
            self.compression_layer = nn.Identity()
            dim_post_compression = input_feature_size

        # Gated attention network with one attention branch per class.
        self.attention_net = Attn_Net_Gated(
            L=dim_post_compression,
            D=feature_size_attn,
            dropout=dropout,
            p_dropout_atn=p_dropout_atn,
            n_classes=self.n_classes,
        )

        # Classification head: one linear classifier per class, applied to that
        # class's attention-pooled slide embedding.
        self.classifiers = nn.ModuleList(
            [nn.Linear(dim_post_compression, 1) for _ in range(self.n_classes)]
        )

        # Init weights.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward_attention(self, h):
        A_ = self.attention_net(h)         # h: N_tiles x dim -> A_: N_tiles x n_classes
        A_raw = torch.transpose(A_, 1, 0)  # n_classes x N_tiles
        A = F.softmax(A_raw, dim=-1)       # normalize attention scores over the tiles
        return A_raw, A

    def forward(self, h):
        h = self.compression_layer(h)

        # Attention MIL: pool the tile embeddings into one slide-level embedding per class.
        A_raw, A = self.forward_attention(h)  # n_classes x N_tiles
        M = A @ h  # n_classes x dim_post_compression; row c is the attention-weighted sum of tile embeddings

        logits = torch.empty(1, self.n_classes).float().to(h.device)
        for c in range(self.n_classes):
            logits[0, c] = self.classifiers[c](M[c])
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)

        return logits, Y_prob, Y_hat, A_raw, M
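

# Minimal usage sketch (illustrative only; the bag size of 500 tiles and the
# 1024-dim tile features below are assumptions, not values fixed by the model):
if __name__ == "__main__":
    model = Im4MEC(input_feature_size=1024, n_classes=4)
    model.eval()
    with torch.no_grad():
        bag = torch.randn(500, 1024)  # one slide = a bag of 500 tile feature vectors
        logits, Y_prob, Y_hat, A_raw, M = model(bag)
    print(logits.shape)  # torch.Size([1, 4])
    print(Y_prob.shape)  # torch.Size([1, 4])
    print(Y_hat.item())  # index of the predicted class
    print(A_raw.shape)   # torch.Size([4, 500]) -- per-class attention over tiles
    print(M.shape)       # torch.Size([4, 512]) -- per-class slide embeddings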