Diff of /biovil_t/encoder.py [000000] .. [4abb48]

Switch to unified view

a b/biovil_t/encoder.py
1
#  -------------------------------------------------------------------------------------------
2
#  Copyright (c) Microsoft Corporation. All rights reserved.
3
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
#  -------------------------------------------------------------------------------------------
5
6
from __future__ import annotations
7
8
from contextlib import contextmanager
9
from typing import Any, Generator, Optional, Sequence, Tuple, Union
10
11
import torch
12
import torch.nn as nn
13
from health_multimodal.common.device import get_module_device
14
from timm.models.layers import trunc_normal_
15
16
from .resnet import resnet18, resnet50
17
from .transformer import VisionTransformerPooler
18
from .types import ImageEncoderType
19
20
DEFAULT_DILATION_VALUES_FOR_RESNET = (False, False, True)
21
ImageEncoderOutputType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
22
23
24
class ImageEncoder(nn.Module):
25
    """Image encoder trunk module for the ``ImageModel`` class.
26
27
    :param img_encoder_type : Type of image encoder model to use, either ``"resnet18_multi_image"`` or
28
                              ``"resnet50_multi_image"``.
29
    """
30
31
    def __init__(self, img_encoder_type: str):
32
        super().__init__()
33
        self.img_encoder_type = img_encoder_type
34
        self.encoder = self._create_encoder()
35
36
    def _create_encoder(self, **kwargs: Any) -> nn.Module:
37
        if self.img_encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET18_MULTI_IMAGE]:
38
            encoder_class = resnet18
39
        elif self.img_encoder_type in [ImageEncoderType.RESNET50, ImageEncoderType.RESNET50_MULTI_IMAGE]:
40
            encoder_class = resnet50
41
        else:
42
            supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
43
            raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")
44
45
        encoder = encoder_class(pretrained=True, **kwargs)
46
47
        return encoder
48
49
    def forward(self,
50
                current_image: torch.Tensor,
51
                return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
52
        """Get image global and patch embeddings"""
53
54
        patch_emb = self.encoder(current_image)
55
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)
56
        if return_patch_embeddings:
57
            return patch_emb, avg_pooled_emb
58
59
        return avg_pooled_emb
60
61
    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
62
        """Workaround for enabling dilated convolutions after model initialization.
63
64
        :param replace_stride_with_dilation: Replace the 2x2 standard convolution stride with a dilated convolution
65
                                             in each layer in the last three blocks of ResNet architecture.
66
        """
67
        if self.img_encoder_type == ImageEncoderType.RESNET18:
68
            # resnet18 uses BasicBlock implementation, which does not support dilated convolutions.
69
            raise NotImplementedError("resnet18 does not support dilated convolutions")
70
71
        if replace_stride_with_dilation is None:
72
            replace_stride_with_dilation = DEFAULT_DILATION_VALUES_FOR_RESNET
73
74
        device = next(self.encoder.parameters()).device
75
        new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)
76
77
        if self.encoder.training:
78
            new_encoder.train()
79
        else:
80
            new_encoder.eval()
81
82
        new_encoder.load_state_dict(self.encoder.state_dict())
83
        self.encoder = new_encoder
84
85
86
class MultiImageEncoder(ImageEncoder):
87
    """Multi-image encoder trunk module for the ``ImageModel`` class.
88
    It can be used to encode multiple images into combined latent representation.
89
    Currently it only supports two input images but can be extended to support more in future.
90
91
    :param img_encoder_type: Type of image encoder model to use: either ``"resnet18"`` or ``"resnet50"``.
92
    """
93
94
    def __init__(self, img_encoder_type: str):
95
        super().__init__(img_encoder_type)
96
97
        output_dim = 256  # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff]
98
        grid_shape = (14, 14)  # Spatial dimensions of patch grid.
99
100
        backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))
101
102
        self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
103
                                         kernel_size=1, stride=1, padding=0, bias=False)
104
        self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=grid_shape)
105
106
        # Missing image embedding
107
        self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
108
        trunc_normal_(self.missing_previous_emb, std=.02)
109
110
    def forward(self,  # type: ignore[override]
111
                current_image: torch.Tensor,
112
                previous_image: Optional[torch.Tensor] = None,
113
                return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
114
115
        batch_size = current_image.shape[0]
116
117
        if previous_image is not None:
118
            assert current_image.shape == previous_image.shape
119
            x = torch.cat([current_image, previous_image], dim=0)
120
            x = super().forward(x, return_patch_embeddings=True)[0]
121
            x = self.backbone_to_vit(x)
122
            patch_x, patch_x_previous = x[:batch_size], x[batch_size:]
123
            diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_x_previous)
124
        else:
125
            x = super().forward(current_image, return_patch_embeddings=True)[0]
126
            patch_x = self.backbone_to_vit(x)
127
            B, _, W, H = patch_x.shape
128
            diff_x = self.missing_previous_emb.repeat(B, 1, W, H)
129
130
        patch_fused = torch.cat([patch_x, diff_x], dim=1)
131
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)
132
133
        if return_patch_embeddings:
134
            return patch_fused, avg_pooled_emb
135
136
        return avg_pooled_emb
137
138
    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
139
        raise NotImplementedError
140
141
142
@torch.no_grad()
143
def get_encoder_output_dim(module: torch.nn.Module, device: torch.device) -> int:
144
    """Calculate the output dimension of an encoder by making a single forward pass.
145
146
    :param module: Encoder module.
147
    :param device: Compute device to use.
148
    """
149
    # Target device
150
    assert isinstance(device, torch.device)
151
152
    x = torch.rand((1, 3, 448, 448)).to(device)
153
154
    # Extract the number of output feature dimensions
155
    with restore_training_mode(module):
156
        module.eval()
157
        representations = module(x)
158
    return representations.shape[1]
159
160
161
@contextmanager
162
def restore_training_mode(module: nn.Module) -> Generator[None, None, None]:
163
    """Restore the training mode of a module after some operation.
164
165
    :param module: PyTorch module.
166
    """
167
    training_mode = module.training
168
    yield
169
    module.train(mode=training_mode)
170
171
172
def get_encoder_from_type(img_encoder_type: str) -> ImageEncoder:
173
    """Returns the encoder class for the given encoder type.
174
175
    :param img_encoder_type: Encoder type. {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}
176
    """
177
    if img_encoder_type in ImageEncoderType.get_members(multi_image_encoders_only=True):
178
        return MultiImageEncoder(img_encoder_type=img_encoder_type)
179
    else:
180
        return ImageEncoder(img_encoder_type=img_encoder_type)