|
a |
|
b/biovil_t/encoder.py |
|
|
1 |
# ------------------------------------------------------------------------------------------- |
|
|
2 |
# Copyright (c) Microsoft Corporation. All rights reserved. |
|
|
3 |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. |
|
|
4 |
# ------------------------------------------------------------------------------------------- |
|
|
5 |
|
|
|
6 |
from __future__ import annotations |
|
|
7 |
|
|
|
8 |
from contextlib import contextmanager |
|
|
9 |
from typing import Any, Generator, Optional, Sequence, Tuple, Union |
|
|
10 |
|
|
|
11 |
import torch |
|
|
12 |
import torch.nn as nn |
|
|
13 |
from health_multimodal.common.device import get_module_device |
|
|
14 |
from timm.models.layers import trunc_normal_ |
|
|
15 |
|
|
|
16 |
from .resnet import resnet18, resnet50 |
|
|
17 |
from .transformer import VisionTransformerPooler |
|
|
18 |
from .types import ImageEncoderType |
|
|
19 |
|
|
|
20 |
DEFAULT_DILATION_VALUES_FOR_RESNET = (False, False, True) |
|
|
21 |
ImageEncoderOutputType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] |
|
|
22 |
|
|
|
23 |
|
|
|
24 |
class ImageEncoder(nn.Module): |
|
|
25 |
"""Image encoder trunk module for the ``ImageModel`` class. |
|
|
26 |
|
|
|
27 |
:param img_encoder_type : Type of image encoder model to use, either ``"resnet18_multi_image"`` or |
|
|
28 |
``"resnet50_multi_image"``. |
|
|
29 |
""" |
|
|
30 |
|
|
|
31 |
def __init__(self, img_encoder_type: str): |
|
|
32 |
super().__init__() |
|
|
33 |
self.img_encoder_type = img_encoder_type |
|
|
34 |
self.encoder = self._create_encoder() |
|
|
35 |
|
|
|
36 |
def _create_encoder(self, **kwargs: Any) -> nn.Module: |
|
|
37 |
if self.img_encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET18_MULTI_IMAGE]: |
|
|
38 |
encoder_class = resnet18 |
|
|
39 |
elif self.img_encoder_type in [ImageEncoderType.RESNET50, ImageEncoderType.RESNET50_MULTI_IMAGE]: |
|
|
40 |
encoder_class = resnet50 |
|
|
41 |
else: |
|
|
42 |
supported = ImageEncoderType.get_members(multi_image_encoders_only=False) |
|
|
43 |
raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}") |
|
|
44 |
|
|
|
45 |
encoder = encoder_class(pretrained=True, **kwargs) |
|
|
46 |
|
|
|
47 |
return encoder |
|
|
48 |
|
|
|
49 |
def forward(self, |
|
|
50 |
current_image: torch.Tensor, |
|
|
51 |
return_patch_embeddings: bool = False) -> ImageEncoderOutputType: |
|
|
52 |
"""Get image global and patch embeddings""" |
|
|
53 |
|
|
|
54 |
patch_emb = self.encoder(current_image) |
|
|
55 |
avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_emb, (1, 1)), 1) |
|
|
56 |
if return_patch_embeddings: |
|
|
57 |
return patch_emb, avg_pooled_emb |
|
|
58 |
|
|
|
59 |
return avg_pooled_emb |
|
|
60 |
|
|
|
61 |
def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None: |
|
|
62 |
"""Workaround for enabling dilated convolutions after model initialization. |
|
|
63 |
|
|
|
64 |
:param replace_stride_with_dilation: Replace the 2x2 standard convolution stride with a dilated convolution |
|
|
65 |
in each layer in the last three blocks of ResNet architecture. |
|
|
66 |
""" |
|
|
67 |
if self.img_encoder_type == ImageEncoderType.RESNET18: |
|
|
68 |
# resnet18 uses BasicBlock implementation, which does not support dilated convolutions. |
|
|
69 |
raise NotImplementedError("resnet18 does not support dilated convolutions") |
|
|
70 |
|
|
|
71 |
if replace_stride_with_dilation is None: |
|
|
72 |
replace_stride_with_dilation = DEFAULT_DILATION_VALUES_FOR_RESNET |
|
|
73 |
|
|
|
74 |
device = next(self.encoder.parameters()).device |
|
|
75 |
new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device) |
|
|
76 |
|
|
|
77 |
if self.encoder.training: |
|
|
78 |
new_encoder.train() |
|
|
79 |
else: |
|
|
80 |
new_encoder.eval() |
|
|
81 |
|
|
|
82 |
new_encoder.load_state_dict(self.encoder.state_dict()) |
|
|
83 |
self.encoder = new_encoder |
|
|
84 |
|
|
|
85 |
|
|
|
86 |
class MultiImageEncoder(ImageEncoder): |
|
|
87 |
"""Multi-image encoder trunk module for the ``ImageModel`` class. |
|
|
88 |
It can be used to encode multiple images into combined latent representation. |
|
|
89 |
Currently it only supports two input images but can be extended to support more in future. |
|
|
90 |
|
|
|
91 |
:param img_encoder_type: Type of image encoder model to use: either ``"resnet18"`` or ``"resnet50"``. |
|
|
92 |
""" |
|
|
93 |
|
|
|
94 |
def __init__(self, img_encoder_type: str): |
|
|
95 |
super().__init__(img_encoder_type) |
|
|
96 |
|
|
|
97 |
output_dim = 256 # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff] |
|
|
98 |
grid_shape = (14, 14) # Spatial dimensions of patch grid. |
|
|
99 |
|
|
|
100 |
backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self)) |
|
|
101 |
|
|
|
102 |
self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim, |
|
|
103 |
kernel_size=1, stride=1, padding=0, bias=False) |
|
|
104 |
self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=grid_shape) |
|
|
105 |
|
|
|
106 |
# Missing image embedding |
|
|
107 |
self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1)) |
|
|
108 |
trunc_normal_(self.missing_previous_emb, std=.02) |
|
|
109 |
|
|
|
110 |
def forward(self, # type: ignore[override] |
|
|
111 |
current_image: torch.Tensor, |
|
|
112 |
previous_image: Optional[torch.Tensor] = None, |
|
|
113 |
return_patch_embeddings: bool = False) -> ImageEncoderOutputType: |
|
|
114 |
|
|
|
115 |
batch_size = current_image.shape[0] |
|
|
116 |
|
|
|
117 |
if previous_image is not None: |
|
|
118 |
assert current_image.shape == previous_image.shape |
|
|
119 |
x = torch.cat([current_image, previous_image], dim=0) |
|
|
120 |
x = super().forward(x, return_patch_embeddings=True)[0] |
|
|
121 |
x = self.backbone_to_vit(x) |
|
|
122 |
patch_x, patch_x_previous = x[:batch_size], x[batch_size:] |
|
|
123 |
diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_x_previous) |
|
|
124 |
else: |
|
|
125 |
x = super().forward(current_image, return_patch_embeddings=True)[0] |
|
|
126 |
patch_x = self.backbone_to_vit(x) |
|
|
127 |
B, _, W, H = patch_x.shape |
|
|
128 |
diff_x = self.missing_previous_emb.repeat(B, 1, W, H) |
|
|
129 |
|
|
|
130 |
patch_fused = torch.cat([patch_x, diff_x], dim=1) |
|
|
131 |
avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_fused, (1, 1)), 1) |
|
|
132 |
|
|
|
133 |
if return_patch_embeddings: |
|
|
134 |
return patch_fused, avg_pooled_emb |
|
|
135 |
|
|
|
136 |
return avg_pooled_emb |
|
|
137 |
|
|
|
138 |
def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None: |
|
|
139 |
raise NotImplementedError |
|
|
140 |
|
|
|
141 |
|
|
|
142 |
@torch.no_grad() |
|
|
143 |
def get_encoder_output_dim(module: torch.nn.Module, device: torch.device) -> int: |
|
|
144 |
"""Calculate the output dimension of an encoder by making a single forward pass. |
|
|
145 |
|
|
|
146 |
:param module: Encoder module. |
|
|
147 |
:param device: Compute device to use. |
|
|
148 |
""" |
|
|
149 |
# Target device |
|
|
150 |
assert isinstance(device, torch.device) |
|
|
151 |
|
|
|
152 |
x = torch.rand((1, 3, 448, 448)).to(device) |
|
|
153 |
|
|
|
154 |
# Extract the number of output feature dimensions |
|
|
155 |
with restore_training_mode(module): |
|
|
156 |
module.eval() |
|
|
157 |
representations = module(x) |
|
|
158 |
return representations.shape[1] |
|
|
159 |
|
|
|
160 |
|
|
|
161 |
@contextmanager |
|
|
162 |
def restore_training_mode(module: nn.Module) -> Generator[None, None, None]: |
|
|
163 |
"""Restore the training mode of a module after some operation. |
|
|
164 |
|
|
|
165 |
:param module: PyTorch module. |
|
|
166 |
""" |
|
|
167 |
training_mode = module.training |
|
|
168 |
yield |
|
|
169 |
module.train(mode=training_mode) |
|
|
170 |
|
|
|
171 |
|
|
|
172 |
def get_encoder_from_type(img_encoder_type: str) -> ImageEncoder: |
|
|
173 |
"""Returns the encoder class for the given encoder type. |
|
|
174 |
|
|
|
175 |
:param img_encoder_type: Encoder type. {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE} |
|
|
176 |
""" |
|
|
177 |
if img_encoder_type in ImageEncoderType.get_members(multi_image_encoders_only=True): |
|
|
178 |
return MultiImageEncoder(img_encoder_type=img_encoder_type) |
|
|
179 |
else: |
|
|
180 |
return ImageEncoder(img_encoder_type=img_encoder_type) |