#  biovil_t/transformer.py
#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------

import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional, Set, Tuple

import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp, trunc_normal_


def torch_int_div(tensor1, tensor2):
    """
    A function that performs integer division across different versions of PyTorch.
    """
    return torch.div(tensor1, tensor2, rounding_mode="floor")


@dataclass
class MultiHeadAttentionOutput:
    mha_output: torch.Tensor
    attention: Optional[torch.Tensor] = None


class VisionTransformerPooler(nn.Module):
    """
    :param input_dim: Input feature dimension (i.e., channels in old CNN terminology)
    :param grid_shape: Shape of the grid of patches per image
    :param num_heads: Number of self-attention heads within the MHA block
    :param num_blocks: Number of blocks per attention layer
    :param norm_layer: Normalisation layer

    `self.type_embed`: Used to characterise prior and current scans, and to
                       create permutation variance across modalities/series.
    """

    def __init__(self,
                 input_dim: int,
                 grid_shape: Tuple[int, int],
                 num_heads: int = 8,
                 num_blocks: int = 3,
                 norm_layer: Any = partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()

        block_kwargs = dict(dim=input_dim, num_heads=num_heads, mlp_ratio=1., drop=0.10, attn_drop=0.10,
                            drop_path=0.25, act_layer=nn.GELU, norm_layer=norm_layer)
        self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
        self.norm_post = norm_layer(input_dim)
        self.grid_shape = grid_shape
        self.num_patches = grid_shape[0] * grid_shape[1]
        self.num_blocks = num_blocks

        # Temporal positional embeddings
        num_series: int = 2
        self.type_embed = nn.Parameter(torch.zeros(num_series, 1, input_dim))
        trunc_normal_(self.type_embed, std=.02)

        # Positional embeddings 1 x L x C (L: Sequence length, C: Feature dimension)
        self.pos_drop = nn.Dropout(p=0.10)
        pos_embed_class = SinePositionEmbedding(embedding_dim=input_dim // 2, normalize=True)
        pos_embed = pos_embed_class(mask=torch.ones([1, grid_shape[0], grid_shape[1]]))  # 1 x L x C
        self.register_buffer("pos_embed", pos_embed, persistent=False)

        # Initialisation
        self.apply(self._init_weights)

    def no_weight_decay(self) -> Set[str]:
        return {'type_embed'}

    def forward(self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, C, H, W = current_image.shape
        assert H == self.grid_shape[0] and W == self.grid_shape[1], "Input and grid shapes do not match"

        # Flatten patch embeddings to have shape (B x L x C), L = H * W
        if previous_image is not None:
            assert previous_image.shape == current_image.shape, "current_image and previous_image shapes do not match"
            previous_image = previous_image.view(B, C, H * W).transpose(1, 2)
        current_image = current_image.view(B, C, H * W).transpose(1, 2)
        pos_embed = self.pos_embed.repeat(B, 1, 1)  # type: ignore

        # Final token activations (B x 2L x C)
        token_features = self.forward_after_reshape(x=current_image, pos_embed=pos_embed, x_previous=previous_image)

        # Extract the patch features of current image
        cur_img_token_id = 0
        current_token_features = token_features[:, cur_img_token_id:self.num_patches+cur_img_token_id]
        current_patch_features = current_token_features.transpose(1, 2).view(B, C, H, W)

        return current_patch_features

    def forward_after_reshape(self,
                              x: torch.Tensor,
                              pos_embed: torch.Tensor,
                              x_previous: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, L, _ = x.shape  # Batch, Sequence length, Feature dimension

        # Positional and type embeddings
        type_embed = self.type_embed[0].expand(B, L, -1)
        if x_previous is not None:
            x = torch.cat((x, x_previous), dim=1)
            pos_embed = torch.cat((pos_embed, pos_embed), dim=1)
            prev_type_embed = self.type_embed[1].expand(B, L, -1)
            type_embed = torch.cat((type_embed, prev_type_embed), dim=1)

        # Add positional and type embeddings (used in query and key matching)
        pos_and_type_embed = pos_embed + type_embed

        # Positional dropout
        x = self.pos_drop(x)

        # Multihead attention followed by MLP
        for block in self.blocks:
            x = block(x=x, pos_and_type_embed=pos_and_type_embed)
        x = self.norm_post(x)

        return x

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
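

def _example_pooler_usage() -> None:
    """
    Illustrative usage sketch only (not part of the original API). It shows the expected
    input/output shapes when fusing the current patch embeddings with those of a prior scan.
    The feature dimension (128) and grid shape (14, 14) below are arbitrary assumptions,
    not values fixed by this module.
    """
    pooler = VisionTransformerPooler(input_dim=128, grid_shape=(14, 14), num_heads=8, num_blocks=3)
    pooler.eval()  # Disable dropout / stochastic depth for a deterministic pass

    current = torch.rand(2, 128, 14, 14)   # B x C x H x W patch embeddings of the current scan
    previous = torch.rand(2, 128, 14, 14)  # Patch embeddings of the prior scan (optional)

    with torch.no_grad():
        fused = pooler(current_image=current, previous_image=previous)

    # Only the tokens of the current image are returned, reshaped back to B x C x H x W
    assert fused.shape == (2, 128, 14, 14)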


class MultiHeadAttentionLayer(nn.Module):
    """
    Multi-head self-attention module

    The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
        - Defines a custom `MultiHeadAttentionLayer` which is not restricted to `self-attention` but can be
            generalised to arbitrary (query, key, value) input tuples. This feature can be valuable when
            processing more than 2 scans at a time.
        - The `self-attention` specific use case can still be invoked by calling the `forward_as_mhsa` method.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.) -> None:
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, f"The embedding dim ({dim}) must be divisible by the number of heads ({num_heads})"
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.return_attention = False

        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, k: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> MultiHeadAttentionOutput:
        B, N, C = v.shape
        assert C % self.num_heads == 0, \
            f"The embedding dim ({C}) must be divisible by the number of heads ({self.num_heads})"

        w_q = self.proj_q(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        w_k = self.proj_k(k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        w_v = self.proj_v(v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (w_q @ w_k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        o = (attn @ w_v).transpose(1, 2).reshape(B, N, C)
        o = self.proj(o)
        o = self.proj_drop(o)

        attention_output = attn if self.return_attention else None

        return MultiHeadAttentionOutput(mha_output=o, attention=attention_output)

    def forward_as_mhsa(self, input: torch.Tensor) -> MultiHeadAttentionOutput:
        return self(k=input, q=input, v=input)
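

def _example_mhsa_usage() -> None:
    """
    Illustrative usage sketch only (not part of the original API). It runs the layer as plain
    self-attention via `forward_as_mhsa` and also retrieves the attention weights. The token
    count (16) and embedding dimension (64) are arbitrary assumptions.
    """
    mha = MultiHeadAttentionLayer(dim=64, num_heads=8)
    mha.return_attention = True  # Also return the B x num_heads x N x N attention map

    tokens = torch.rand(2, 16, 64)  # B x N x C token embeddings
    output = mha.forward_as_mhsa(tokens)

    assert output.mha_output.shape == (2, 16, 64)
    assert output.attention is not None and output.attention.shape == (2, 8, 16, 16)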


class Block(nn.Module):
    """
    Encapsulates multi-layer perceptron and multi-head self-attention modules into a block.

    The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
        - This implementation uses spatio-temporal positional embeddings instead of 2D positional embeddings only,
            and they are taken into account within the forward pass of each ViT block.
        - Utilises the custom `MultiHeadAttentionLayer`, which is not restricted to `self-attention` and can be
            generalised to arbitrary (query, key, value) tuples. This can be valuable when processing more than 2 scans.

    Positional and type embeddings are handled in a similar fashion to the DETR object localisation paper
    (https://alcinos.github.io/detr_page/), where a fixed set of sine/cosine positional embeddings is added
    to the Q and K tensors.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 1., qkv_bias: bool = False, drop: float = 0.,
                 attn_drop: float = 0., drop_path: float = 0., act_layer: Callable = nn.GELU,
                 norm_layer: Callable = nn.LayerNorm) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttentionLayer(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def with_pos_and_type_embed(self, tensor: torch.Tensor, emb: Optional[torch.Tensor]) -> torch.Tensor:
        # Add positional and type embeddings to key and query tensors
        return tensor if emb is None else tensor + emb

    def forward(self, x: torch.Tensor, pos_and_type_embed: Optional[torch.Tensor]) -> torch.Tensor:
        x_with_emb = self.with_pos_and_type_embed(self.norm1(x), emb=pos_and_type_embed)
        x = x + self.drop_path(self.attn.forward_as_mhsa(x_with_emb).mha_output)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
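

def _example_block_usage() -> None:
    """
    Illustrative usage sketch only (not part of the original API). It applies a single block to a
    token sequence, passing fixed sine/cosine positional embeddings that are added to the
    normalised tokens before self-attention. The shapes below are arbitrary assumptions.
    """
    block = Block(dim=64, num_heads=8)

    tokens = torch.rand(2, 49, 64)  # B x L x C tokens from a 7 x 7 patch grid
    pos_embed = SinePositionEmbedding(embedding_dim=32, normalize=True)(mask=torch.ones([1, 7, 7]))
    pos_embed = pos_embed.repeat(2, 1, 1)  # 1 x L x C -> B x L x C

    with torch.no_grad():
        out = block(x=tokens, pos_and_type_embed=pos_embed)

    assert out.shape == (2, 49, 64)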


class SinePositionEmbedding:
    """
    This is a more standard version of the position embedding, very similar to the one used in the
    "Attention Is All You Need" paper, generalised to work on images.
    """

    def __init__(self,
                 embedding_dim: int = 64,
                 temperature: int = 10000,
                 normalize: bool = False,
                 scale: Optional[float] = None) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def __call__(self, mask: torch.Tensor) -> torch.Tensor:
        assert mask is not None, "No pixel mask provided"
        B, H, W = mask.shape
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
        dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)

        return pos
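

def _example_position_embedding_usage() -> None:
    """
    Illustrative usage sketch only (not part of the original API). It builds fixed sine/cosine
    positional embeddings for a grid of patches; note that the output feature dimension is
    2 * embedding_dim (sine/cosine pairs for both the y and x axes). The 14 x 14 grid is an
    arbitrary assumption.
    """
    pos_embed_fn = SinePositionEmbedding(embedding_dim=64, normalize=True)
    mask = torch.ones([1, 14, 14])  # B x H x W mask marking valid grid positions
    pos_embed = pos_embed_fn(mask=mask)

    assert pos_embed.shape == (1, 14 * 14, 128)  # B x (H * W) x (2 * embedding_dim)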