diff --git a/im4MEC.py b/im4MEC.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# Adapted from https://github.com/AIRMEC/im4MEC

class Attn_Net_Gated(nn.Module):
    # Gated attention: a tanh feature branch modulated elementwise by a
    # sigmoid gate, followed by a linear layer that emits one unnormalized
    # attention score per class for every tile.

    def __init__(self, L=1024, D=256, dropout=False, p_dropout_atn=0.25, n_classes=1):
        super(Attn_Net_Gated, self).__init__()

        self.attention_a = [nn.Linear(L, D), nn.Tanh()]
        self.attention_b = [nn.Linear(L, D), nn.Sigmoid()]

        if dropout:
            self.attention_a.append(nn.Dropout(p_dropout_atn))
            self.attention_b.append(nn.Dropout(p_dropout_atn))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)  # N x D, tanh features
        b = self.attention_b(x)  # N x D, sigmoid gates in (0, 1)
        A = a.mul(b)             # gated features
        A = self.attention_c(A)  # N x n_classes
        return A
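
# Shape sketch for the gated attention module (illustrative; the bag of
# 100 tiles is arbitrary, the other sizes are the defaults used below):
#
#   attn = Attn_Net_Gated(L=1024, D=256, dropout=True, n_classes=4)
#   scores = attn(torch.randn(100, 1024))  # -> (100, 4), one score per tile per class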


class Im4MEC(nn.Module):
    def __init__(
        self,
        input_feature_size=1024,
        precompression_layer=True,
        feature_size_comp=512,
        feature_size_attn=256,
        dropout=True,
        p_dropout_fc=0.25,
        p_dropout_atn=0.25,
        n_classes=4,
    ):
        super(Im4MEC, self).__init__()

        self.n_classes = n_classes

        if precompression_layer:
            # Funnel of fully-connected layers; with the defaults
            # (input_feature_size=1024, feature_size_comp=512) each tile
            # embedding is mapped 1024 -> 2048 -> 1024 -> 512.
            self.compression_layer = nn.Sequential(
                nn.Linear(input_feature_size, feature_size_comp * 4),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc),
                nn.Linear(feature_size_comp * 4, feature_size_comp * 2),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc),
                nn.Linear(feature_size_comp * 2, feature_size_comp),
                nn.ReLU(),
                nn.Dropout(p_dropout_fc),
            )
            dim_post_compression = feature_size_comp
        else:
            self.compression_layer = nn.Identity()
            dim_post_compression = input_feature_size

        self.attention_net = Attn_Net_Gated(
            L=dim_post_compression,
            D=feature_size_attn,
            dropout=dropout,
            p_dropout_atn=p_dropout_atn,
            n_classes=self.n_classes,
        )

        # Classification head: one independent linear head per class, each
        # scoring its own attention-pooled slide embedding.
        self.classifiers = nn.ModuleList(
            [nn.Linear(dim_post_compression, 1) for _ in range(self.n_classes)]
        )

        # Init weights.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward_attention(self, h):
        A_ = self.attention_net(h)         # h: N_tiles x dim -> A_: N_tiles x n_classes
        A_raw = torch.transpose(A_, 1, 0)  # n_classes x N_tiles
        A = F.softmax(A_raw, dim=-1)       # normalize attention scores over the tiles
        return A_raw, A
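
    # Each row A[c] of the normalized attention is a distribution over the
    # N tiles (it sums to 1), so M[c] = A[c] @ h in forward() below is a
    # convex combination of tile embeddings under class-c attention.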

    def forward(self, h):
        h = self.compression_layer(h)

        # Attention MIL: pool the N tile embeddings into one slide-level
        # embedding per class.
        A_raw, A = self.forward_attention(h)  # A: n_classes x N_tiles
        M = A @ h  # n_classes x dim_post_compression, M[c] = sum_i A[c, i] * h[i]

        # Score each class-specific slide embedding with its own linear head.
        logits = torch.empty(1, self.n_classes).float().to(h.device)
        for c in range(self.n_classes):
            logits[0, c] = self.classifiers[c](M[c])
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)

        return logits, Y_prob, Y_hat, A_raw, M
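

# Minimal smoke test (illustrative sketch, not part of the upstream repo;
# the bag size of 100 tiles is arbitrary, the other sizes are the defaults).
if __name__ == "__main__":
    model = Im4MEC()
    model.eval()
    with torch.no_grad():
        h = torch.randn(100, 1024)  # N_tiles x input_feature_size
        logits, Y_prob, Y_hat, A_raw, M = model(h)
    # Expected shapes: (1, 4), (1, 4), (1, 1), (4, 100), (4, 512).
    print(logits.shape, Y_prob.shape, Y_hat.shape, A_raw.shape, M.shape)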