|
a |
|
b/biovil_t/transformer.py |
|
|
1 |
# ------------------------------------------------------------------------------------------- |
|
|
2 |
# Copyright (c) Microsoft Corporation. All rights reserved. |
|
|
3 |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. |
|
|
4 |
# ------------------------------------------------------------------------------------------- |
|
|
5 |
|
|
|
6 |
import math |
|
|
7 |
from dataclasses import dataclass |
|
|
8 |
from functools import partial |
|
|
9 |
from typing import Any, Callable, Optional, Set, Tuple |
|
|
10 |
|
|
|
11 |
import torch |
|
|
12 |
import torch.nn as nn |
|
|
13 |
from timm.models.layers import DropPath, Mlp, trunc_normal_ |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def torch_int_div(tensor1, tensor2):
    """
    Element-wise floor division of ``tensor1`` by ``tensor2``.

    Thin wrapper around ``torch.div`` with ``rounding_mode="floor"``, i.e. results are rounded towards
    negative infinity (matching Python's ``//``), e.g. ``-7 // 2 == -4``.
    """
    quotient = torch.div(tensor1, tensor2, rounding_mode="floor")
    return quotient
|
|
21 |
|
|
|
22 |
@dataclass
class MultiHeadAttentionOutput:
    """
    Result container returned by ``MultiHeadAttentionLayer.forward``.
    """

    # Attention output after the output projection and its dropout.
    mha_output: torch.Tensor
    # Attention weights; stays None unless the producing layer has `return_attention` set.
    attention: Optional[torch.Tensor] = None
|
|
26 |
|
|
|
27 |
|
|
|
28 |
class VisionTransformerPooler(nn.Module):
    """
    Transformer that refines the patch features of a current image, optionally attending jointly to the
    patch features of a previous image, and returns the refined current-image patch features.

    :param input_dim: Input feature dimension (i.e., channels in old CNN terminology). Must be divisible by 4,
        because the sine positional embedding produces ``2 * (input_dim // 2)`` features from sin/cos pairs.
    :param grid_shape: Shape of the grid of patches per image
    :param num_heads: Number of self-attention heads within the MHA block
    :param num_blocks: Number of transformer blocks stacked in the pooler
    :param norm_layer: Normalisation layer factory (called with the feature dimension)

    `self.type_embed`: Learnable per-series embedding (index 0: current image, index 1: previous image).
    It lets attention distinguish current and prior scans, i.e. it breaks the permutation invariance
    between the two series.
    """

    def __init__(self,
                 input_dim: int,
                 grid_shape: Tuple[int, int],
                 num_heads: int = 8,
                 num_blocks: int = 3,
                 norm_layer: Any = partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        # SinePositionEmbedding yields 2 * (input_dim // 2) features and needs input_dim // 2 to be even,
        # so the positional embedding only matches the token dimension when input_dim is a multiple of 4.
        if input_dim % 4 != 0:
            raise ValueError(f"input_dim ({input_dim}) must be divisible by 4 to match the sine positional embedding")

        block_kwargs = dict(dim=input_dim, num_heads=num_heads, mlp_ratio=1., drop=0.10, attn_drop=0.10,
                            drop_path=0.25, act_layer=nn.GELU, norm_layer=norm_layer)
        self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
        self.norm_post = norm_layer(input_dim)
        self.grid_shape = grid_shape
        self.num_patches = grid_shape[0] * grid_shape[1]
        self.num_blocks = num_blocks

        # Temporal (series type) embeddings: index 0 tags current-image tokens, index 1 previous-image tokens
        num_series: int = 2
        self.type_embed = nn.Parameter(torch.zeros(num_series, 1, input_dim))
        trunc_normal_(self.type_embed, std=.02)

        # Fixed sine positional embeddings 1 x L x C (L: Sequence length, C: Feature dimension)
        self.pos_drop = nn.Dropout(p=0.10)
        pos_embed_class = SinePositionEmbedding(embedding_dim=input_dim // 2, normalize=True)
        pos_embed = pos_embed_class(mask=torch.ones([1, grid_shape[0], grid_shape[1]]))  # 1 x L x C
        # Non-persistent: recomputed at construction time, so it is not part of the state dict
        self.register_buffer("pos_embed", pos_embed, persistent=False)

        # Initialisation
        self.apply(self._init_weights)

    def no_weight_decay(self) -> Set[str]:
        """Names of parameters that should be excluded from weight decay."""
        return {'type_embed'}

    def forward(self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        :param current_image: Patch features of the current image, shape (B, C, H, W) with (H, W) == grid_shape.
        :param previous_image: Optional patch features of a previous image, same shape as ``current_image``.
        :return: Refined patch features of the current image, shape (B, C, H, W).
        """
        B, C, H, W = current_image.shape
        assert H == self.grid_shape[0] and W == self.grid_shape[1], "Input and grid shapes do not match"

        # Flatten patch embeddings to have shape (B x L x C), L = H * W.
        # `flatten` (unlike `view`) also accepts non-contiguous inputs, copying only when it has to.
        if previous_image is not None:
            assert previous_image.shape == current_image.shape, "current_image and previous_image shapes do not match"
            previous_image = previous_image.flatten(2).transpose(1, 2)
        current_image = current_image.flatten(2).transpose(1, 2)
        # `expand` broadcasts the cached embedding over the batch without materialising B copies
        pos_embed = self.pos_embed.expand(B, -1, -1)  # type: ignore

        # Final token activations (B x L x C, or B x 2L x C when a previous image is given)
        token_features = self.forward_after_reshape(x=current_image, pos_embed=pos_embed, x_previous=previous_image)

        # Current-image tokens occupy the first num_patches positions (forward_after_reshape appends previous ones)
        current_token_features = token_features[:, :self.num_patches]
        current_patch_features = current_token_features.transpose(1, 2).reshape(B, C, H, W)

        return current_patch_features

    def forward_after_reshape(self,
                              x: torch.Tensor,
                              pos_embed: torch.Tensor,
                              x_previous: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Run the transformer blocks on already-flattened token sequences.

        :param x: Current-image tokens, shape (B, L, C).
        :param pos_embed: Positional embeddings for one image, shape (B, L, C).
        :param x_previous: Optional previous-image tokens, shape (B, L, C); appended after ``x``.
        :return: Post-normalised token features, shape (B, L, C), or (B, 2L, C) when ``x_previous`` is given.
        """
        B, L, _ = x.shape  # Batch, Sequence length, Feature dimension

        # Positional and type embeddings
        type_embed = self.type_embed[0].expand(B, L, -1)
        if x_previous is not None:
            x = torch.cat((x, x_previous), dim=1)
            # Both images share the same spatial grid, hence the same positional embedding
            pos_embed = torch.cat((pos_embed, pos_embed), dim=1)
            prev_type_embed = self.type_embed[1].expand(B, L, -1)
            type_embed = torch.cat((type_embed, prev_type_embed), dim=1)

        # Combined embedding; it is not added to x here but injected by every block into its attention input
        pos_and_type_embed = pos_embed + type_embed

        # Positional dropout
        x = self.pos_drop(x)

        # Multihead attention followed by MLP
        for block in self.blocks:
            x = block(x=x, pos_and_type_embed=pos_and_type_embed)
        x = self.norm_post(x)

        return x

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal (std 0.02) Linear weights with zero bias; LayerNorm with unit weight and zero bias."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
|
|
129 |
|
|
|
130 |
|
|
|
131 |
class MultiHeadAttentionLayer(nn.Module):
    """
    Multi-head attention module

    The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
        - Defines a custom `MultiHeadAttentionLayer` which does not only apply `self-attention` but can be
          generalised to arbitrary (query, key, value) input tuples; queries may have a different sequence
          length than keys/values. This feature can be valuable to process more than 2 scans at a time.
        - The `self-attention` use-case can still be invoked by calling the `forward_as_mhsa` method.

    :param dim: Embedding dimension of queries, keys and values; must be divisible by ``num_heads``.
    :param num_heads: Number of attention heads.
    :param qkv_bias: Whether the query/key/value projections have a bias term.
    :param attn_drop: Dropout probability applied to the attention weights.
    :param proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.) -> None:
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, f"The embedding dim ({dim}) must be divisible by the number of heads ({num_heads})"
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # When set to True, `forward` also returns the attention weights
        self.return_attention = False

        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, k: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> MultiHeadAttentionOutput:
        """
        Scaled dot-product multi-head attention.

        :param k: Keys, shape (B, N_kv, C).
        :param q: Queries, shape (B, N_q, C); N_q may differ from N_kv.
        :param v: Values, shape (B, N_kv, C).
        :return: Output of shape (B, N_q, C). The attention weights, shape (B, num_heads, N_q, N_kv), are
            included only when ``self.return_attention`` is True.
        """
        B, N_kv, C = v.shape
        N_q = q.shape[1]
        assert C % self.num_heads == 0, \
            f"The embedding dim ({C}) must be divisible by the number of heads ({self.num_heads})"
        head_dim = C // self.num_heads

        # Project and split into heads: (B, num_heads, sequence length, head_dim).
        # Query and key/value lengths are taken separately so cross-attention with N_q != N_kv works.
        w_q = self.proj_q(q).reshape(B, N_q, self.num_heads, head_dim).permute(0, 2, 1, 3)
        w_k = self.proj_k(k).reshape(B, k.shape[1], self.num_heads, head_dim).permute(0, 2, 1, 3)
        w_v = self.proj_v(v).reshape(B, N_kv, self.num_heads, head_dim).permute(0, 2, 1, 3)

        attn = (w_q @ w_k.transpose(-2, -1)) * self.scale  # (B, num_heads, N_q, N_kv)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, N_q, C)
        o = (attn @ w_v).transpose(1, 2).reshape(B, N_q, C)
        o = self.proj(o)
        o = self.proj_drop(o)

        attention_output = attn if self.return_attention else None

        return MultiHeadAttentionOutput(mha_output=o, attention=attention_output)

    def forward_as_mhsa(self, input: torch.Tensor) -> MultiHeadAttentionOutput:
        """Self-attention: use ``input`` as keys, queries and values."""
        return self(k=input, q=input, v=input)
|
|
186 |
|
|
|
187 |
|
|
|
188 |
class Block(nn.Module):
    """
    Transformer block: multi-head self attention followed by a multi-layer perceptron. Each sub-layer sits in a
    pre-normalised residual branch with DropPath (stochastic depth).

    The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:
        - Spatio-temporal (positional + type) embeddings are used instead of 2D positional embeddings only,
          and they are injected within the forward pass of every block.
        - The attention is the custom `MultiHeadAttentionLayer`, which is not limited to `self-attention` and can
          be generalised to arbitrary (query, key, value) tuples. This can be valuable to process more than 2 scans.

    Embeddings are added in the spirit of the DETR object localisation paper https://alcinos.github.io/detr_page/,
    where a fixed set of sine/cos positional embeddings is summed with the attention inputs.
    NOTE: here the normalised input plus embedding is passed to `forward_as_mhsa`, so the sum serves as query,
    key *and* value; DETR adds positional embeddings to queries and keys only.

    :param dim: Token feature dimension.
    :param num_heads: Number of attention heads.
    :param mlp_ratio: Hidden width of the MLP relative to ``dim`` (``int(dim * mlp_ratio)``).
    :param qkv_bias: Whether the query/key/value projections have a bias term.
    :param drop: Dropout used after the attention output projection and inside the MLP.
    :param attn_drop: Dropout applied to the attention weights.
    :param drop_path: DropPath rate for both residual branches; 0 disables it.
    :param act_layer: Activation layer used in the MLP.
    :param norm_layer: Factory for the two normalisation layers.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 1., qkv_bias: bool = False, drop: float = 0.,
                 attn_drop: float = 0., drop_path: float = 0., act_layer: Callable = nn.GELU,
                 norm_layer: Callable = nn.LayerNorm) -> None:
        super().__init__()
        # Keep this creation order: it fixes the submodule registration order, which determines parameter
        # order and the traversal order of `Module.apply`-based initialisation.
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttentionLayer(dim=dim,
                                            num_heads=num_heads,
                                            qkv_bias=qkv_bias,
                                            attn_drop=attn_drop,
                                            proj_drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)

    def with_pos_and_type_embed(self, tensor: torch.Tensor, emb: Optional[torch.Tensor]) -> torch.Tensor:
        """Return ``tensor + emb``, or ``tensor`` unchanged when no embedding is provided."""
        if emb is None:
            return tensor
        return tensor + emb

    def forward(self, x: torch.Tensor, pos_and_type_embed: Optional[torch.Tensor]) -> torch.Tensor:
        """
        :param x: Token features.
        :param pos_and_type_embed: Optional embedding added to the normalised tokens before attention.
        :return: Updated token features, same shape as ``x``.
        """
        # Attention branch: the embedded tensor is used as query, key and value alike
        attn_input = self.with_pos_and_type_embed(self.norm1(x), emb=pos_and_type_embed)
        attn_result = self.attn.forward_as_mhsa(attn_input).mha_output
        x = x + self.drop_path(attn_result)

        # MLP branch
        mlp_result = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_result)
|
|
225 |
|
|
|
226 |
|
|
|
227 |
class SinePositionEmbedding:
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all
    you need paper, generalized to work on images.

    For each spatial axis, ``embedding_dim`` features are produced by interleaving sin (even indices) and cos
    (odd indices) of the position divided by geometrically spaced frequencies. The y- and x-features are
    concatenated, giving ``2 * embedding_dim`` features per location.

    :param embedding_dim: Number of features per spatial axis; must be even (sin/cos pairs).
    :param temperature: Base of the geometric frequency progression.
    :param normalize: If True, positions along each axis are divided by the last cumulative value on that
        axis and multiplied by ``scale``.
    :param scale: Scale for normalised positions (defaults to 2*pi); may only be given with ``normalize=True``.
    :raises ValueError: If ``scale`` is given without ``normalize``, or ``embedding_dim`` is odd.
    """

    def __init__(self,
                 embedding_dim: int = 64,
                 temperature: int = 10000,
                 normalize: bool = False,
                 scale: Optional[float] = None) -> None:
        super().__init__()
        if embedding_dim % 2 != 0:
            raise ValueError(f"embedding_dim ({embedding_dim}) must be even to form sin/cos pairs")
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def __call__(self, mask: torch.Tensor) -> torch.Tensor:
        """
        :param mask: Tensor of shape (B, H, W); positions are the cumulative sums of the mask along each axis.
        :return: Float32 tensor of shape (B, H * W, 2 * embedding_dim) on the same device as ``mask``.
        """
        assert mask is not None, "No pixel mask provided"
        B, H, W = mask.shape
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # The small epsilon guards against division by zero for all-zero rows/columns
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

        # Frequencies are created on the mask's device so non-CPU masks do not cause a device mismatch
        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=mask.device)
        # Consecutive (sin, cos) feature pairs share the same frequency
        dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)

        return pos