Diff of /im4MEC.py [000000] .. [352cae]

--- a
+++ b/im4MEC.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Adapted from https://github.com/AIRMEC/im4MEC
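+# Attention-based multiple-instance learning (MIL) model: it takes a bag of
+# pre-extracted tile-level feature vectors from a whole-slide image and
+# predicts a slide-level class via gated attention pooling.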
+
+class Attn_Net_Gated(nn.Module):
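+    """Gated attention network (Ilse et al., 2018).
+
+    Produces one attention score per tile and per class:
+    a = tanh(W_a x), b = sigmoid(W_b x), score = W_c (a * b).
+    """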
+
+    def __init__(self, L=1024, D=256, dropout=False, p_dropout_atn=0.25, n_classes=1):
+        super(Attn_Net_Gated, self).__init__()
+
+        self.attention_a = [nn.Linear(L, D), nn.Tanh()]
+
+        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]
+
+        if dropout:
+            self.attention_a.append(nn.Dropout(p_dropout_atn))
+            self.attention_b.append(nn.Dropout(p_dropout_atn))
+
+        self.attention_a = nn.Sequential(*self.attention_a)
+        self.attention_b = nn.Sequential(*self.attention_b)
+        self.attention_c = nn.Linear(D, n_classes)
+
+    def forward(self, x):
+        a = self.attention_a(x)
+        b = self.attention_b(x)
+        A = a.mul(b)
+        A = self.attention_c(A)  # N x n_classes
+        return A
+
+
+class Im4MEC(nn.Module):
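+    """Attention-based MIL classifier for bags of tile-level features.
+
+    Optionally compresses the input features with a small MLP, pools the tiles
+    with one gated attention branch per class, and scores each class with its
+    own linear head. forward() returns (logits, Y_prob, Y_hat, A_raw, M).
+    """
+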
+    def __init__(
+        self,
+        input_feature_size=1024,
+        precompression_layer=True,
+        feature_size_comp=512,
+        feature_size_attn=256,
+        dropout=True,
+        p_dropout_fc=0.25,
+        p_dropout_atn=0.25,
+        n_classes=4,
+    ):
+        super(Im4MEC, self).__init__()
+
+        self.n_classes = n_classes
+
+        if precompression_layer:
+            # Compress the tile features with a 3-layer MLP:
+            # input -> 4*feature_size_comp -> 2*feature_size_comp -> feature_size_comp.
+            self.compression_layer = nn.Sequential(
+                nn.Linear(input_feature_size, feature_size_comp * 4),
+                nn.ReLU(),
+                nn.Dropout(p_dropout_fc),
+                nn.Linear(feature_size_comp * 4, feature_size_comp * 2),
+                nn.ReLU(),
+                nn.Dropout(p_dropout_fc),
+                nn.Linear(feature_size_comp * 2, feature_size_comp),
+                nn.ReLU(),
+                nn.Dropout(p_dropout_fc),
+            )
+
+            dim_post_compression = feature_size_comp
+        else:
+            self.compression_layer = nn.Identity()
+            dim_post_compression = input_feature_size
+
+        self.attention_net = Attn_Net_Gated(
+            L=dim_post_compression,
+            D=feature_size_attn,
+            dropout=dropout,
+            p_dropout_atn=p_dropout_atn,
+            n_classes=self.n_classes)
+
+        # Classification head: one independent linear classifier per attention branch.
+        self.classifiers = nn.ModuleList(
+            [nn.Linear(dim_post_compression, 1) for _ in range(self.n_classes)]
+        )
+
+        # Init weights.
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            nn.init.xavier_normal_(module.weight)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward_attention(self, h):
+        A_ = self.attention_net(h)  # h: [N_tiles, dim] -> A_: [N_tiles, n_classes]
+        A_raw = torch.transpose(A_, 1, 0)  # [n_classes, N_tiles]
+        A = F.softmax(A_raw, dim=-1)  # normalize attention scores over the tiles
+        return A_raw, A
+
+    def forward(self, h):
+        h = self.compression_layer(h)
+
+        # Attention-MIL pooling: one attention branch per class.
+        A_raw, A = self.forward_attention(h)  # A: [n_classes, N_tiles]
+        M = A @ h  # [n_classes, dim]: per-class attention-weighted average of the tile embeddings
+
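+        # Score each class with its own linear head applied to its own pooled embedding M[c].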
+        logits = torch.empty(1, self.n_classes).float().to(h.device)
+        for c in range(self.n_classes):
+            logits[0, c] = self.classifiers[c](M[c])
+        Y_hat = torch.topk(logits, 1, dim=1)[1]
+        Y_prob = F.softmax(logits, dim=1)
+
+        return logits, Y_prob, Y_hat, A_raw, M
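+
+
+# Minimal usage sketch (not part of the upstream im4MEC code): feed one bag of
+# tile-level feature vectors of shape [N_tiles, input_feature_size] through the
+# model. The tensor sizes below are illustrative only.
+if __name__ == "__main__":
+    model = Im4MEC(input_feature_size=1024, n_classes=4)
+    model.eval()
+    with torch.no_grad():
+        bag = torch.randn(500, 1024)  # e.g. 500 tiles from one whole-slide image
+        logits, Y_prob, Y_hat, A_raw, M = model(bag)
+    print(logits.shape, Y_prob.shape, Y_hat.item(), A_raw.shape, M.shape)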