--- /dev/null
+++ b/biovil_t/encoder.py
@@ -0,0 +1,180 @@
+# -------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# -------------------------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+from health_multimodal.common.device import get_module_device
+from timm.models.layers import trunc_normal_
+
+from .resnet import resnet18, resnet50
+from .transformer import VisionTransformerPooler
+from .types import ImageEncoderType
+
+DEFAULT_DILATION_VALUES_FOR_RESNET = (False, False, True)
+ImageEncoderOutputType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+
+
+class ImageEncoder(nn.Module):
+    """Image encoder trunk module for the ``ImageModel`` class.
+
+    :param img_encoder_type: Type of image encoder model to use, either ``"resnet18"`` or ``"resnet50"``.
+    """
+
+    def __init__(self, img_encoder_type: str):
+        super().__init__()
+        self.img_encoder_type = img_encoder_type
+        self.encoder = self._create_encoder()
+
+    def _create_encoder(self, **kwargs: Any) -> nn.Module:
+        if self.img_encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET18_MULTI_IMAGE]:
+            encoder_class = resnet18
+        elif self.img_encoder_type in [ImageEncoderType.RESNET50, ImageEncoderType.RESNET50_MULTI_IMAGE]:
+            encoder_class = resnet50
+        else:
+            supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
+            raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")
+
+        encoder = encoder_class(pretrained=True, **kwargs)
+
+        return encoder
+
+    def forward(self,
+                current_image: torch.Tensor,
+                return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
+        """Get image global and patch embeddings."""
+        patch_emb = self.encoder(current_image)
+        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)
+        if return_patch_embeddings:
+            return patch_emb, avg_pooled_emb
+
+        return avg_pooled_emb
+
+    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
+        """Workaround for enabling dilated convolutions after model initialization.
+
+        :param replace_stride_with_dilation: Replace the 2x2 standard convolution stride with a dilated convolution
+            in each layer of the last three blocks of the ResNet architecture.
+        """
+        if self.img_encoder_type == ImageEncoderType.RESNET18:
+            # resnet18 uses the BasicBlock implementation, which does not support dilated convolutions.
+            raise NotImplementedError("resnet18 does not support dilated convolutions")
+
+        if replace_stride_with_dilation is None:
+            replace_stride_with_dilation = DEFAULT_DILATION_VALUES_FOR_RESNET
+
+        device = next(self.encoder.parameters()).device
+        new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)
+
+        if self.encoder.training:
+            new_encoder.train()
+        else:
+            new_encoder.eval()
+
+        new_encoder.load_state_dict(self.encoder.state_dict())
+        self.encoder = new_encoder
+
+
+class MultiImageEncoder(ImageEncoder):
+    """Multi-image encoder trunk module for the ``ImageModel`` class.
+
+    It can be used to encode multiple images into a combined latent representation.
+    Currently it only supports two input images, but it can be extended to support more in the future.
+
+    :param img_encoder_type: Type of image encoder model to use: either ``"resnet18_multi_image"`` or
+        ``"resnet50_multi_image"``.
+    """
+
+    def __init__(self, img_encoder_type: str):
+        super().__init__(img_encoder_type)
+
+        output_dim = 256  # The aggregate feature dim of the encoder is `2 * output_dim`, i.e. [f_static, f_diff]
+        grid_shape = (14, 14)  # Spatial dimensions of the patch grid.
+
+        backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))
+
+        self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
+                                         kernel_size=1, stride=1, padding=0, bias=False)
+        self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=grid_shape)
+
+        # Learnable embedding used in place of the previous image when it is missing.
+        self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
+        trunc_normal_(self.missing_previous_emb, std=0.02)
+
+    def forward(self,  # type: ignore[override]
+                current_image: torch.Tensor,
+                previous_image: Optional[torch.Tensor] = None,
+                return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
+
+        batch_size = current_image.shape[0]
+
+        if previous_image is not None:
+            assert current_image.shape == previous_image.shape
+            # Encode both images in a single backbone pass, then split the batch back apart.
+            x = torch.cat([current_image, previous_image], dim=0)
+            x = super().forward(x, return_patch_embeddings=True)[0]
+            x = self.backbone_to_vit(x)
+            patch_x, patch_x_previous = x[:batch_size], x[batch_size:]
+            diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_x_previous)
+        else:
+            x = super().forward(current_image, return_patch_embeddings=True)[0]
+            patch_x = self.backbone_to_vit(x)
+            B, _, H, W = patch_x.shape
+            diff_x = self.missing_previous_emb.repeat(B, 1, H, W)
+
+        patch_fused = torch.cat([patch_x, diff_x], dim=1)
+        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)
+
+        if return_patch_embeddings:
+            return patch_fused, avg_pooled_emb
+
+        return avg_pooled_emb
+
+    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
+        raise NotImplementedError
+
+
+@torch.no_grad()
+def get_encoder_output_dim(module: torch.nn.Module, device: torch.device) -> int:
+    """Calculate the output dimension of an encoder by making a single forward pass.
+
+    :param module: Encoder module.
+    :param device: Compute device to use.
+    """
+    assert isinstance(device, torch.device)
+
+    # Dummy input on the target device.
+    x = torch.rand((1, 3, 448, 448)).to(device)
+
+    # Extract the number of output feature channels from a single eval-mode forward pass.
+    with restore_training_mode(module):
+        module.eval()
+        representations = module(x)
+    return representations.shape[1]
+
+
+@contextmanager
+def restore_training_mode(module: nn.Module) -> Generator[None, None, None]:
+    """Restore the training mode of a module after some operation.
+
+    :param module: PyTorch module.
+    """
+    training_mode = module.training
+    try:
+        yield
+    finally:
+        # Restore the original mode even if the wrapped operation raises.
+        module.train(mode=training_mode)
+
+
+def get_encoder_from_type(img_encoder_type: str) -> ImageEncoder:
+    """Returns an encoder instance for the given encoder type.
+
+    :param img_encoder_type: Encoder type, one of {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}.
+    """
+    if img_encoder_type in ImageEncoderType.get_members(multi_image_encoders_only=True):
+        return MultiImageEncoder(img_encoder_type=img_encoder_type)
+    else:
+        return ImageEncoder(img_encoder_type=img_encoder_type)
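
A minimal usage sketch for context, not part of the diff. It assumes the module lands as `biovil_t/encoder.py` in an importable package, that `ImageEncoderType.RESNET50_MULTI_IMAGE` compares equal to the string `"resnet50_multi_image"`, and that the pretrained ResNet weights can be downloaded; the shapes in the comments assume the default non-dilated ResNet-50 on 448x448 inputs.

```python
import torch

from biovil_t.encoder import get_encoder_from_type

# "resnet50_multi_image" routes to MultiImageEncoder; constructing it downloads
# pretrained ResNet-50 weights and runs one dummy forward pass to size the 1x1 conv.
encoder = get_encoder_from_type(img_encoder_type="resnet50_multi_image")
encoder.eval()

current_image = torch.rand(2, 3, 448, 448)   # batch of current images
previous_image = torch.rand(2, 3, 448, 448)  # matching batch of prior images

with torch.no_grad():
    # With a prior image, both images share one backbone pass and the ViT pooler
    # produces the difference features; channels are 2 * 256 = 512 ([f_static, f_diff]).
    patch_emb, global_emb = encoder(current_image,
                                    previous_image=previous_image,
                                    return_patch_embeddings=True)
    # Without a prior image, the learned missing_previous_emb fills the difference branch.
    global_only = encoder(current_image)

print(patch_emb.shape, global_emb.shape, global_only.shape)
# e.g. torch.Size([2, 512, 14, 14]) torch.Size([2, 512]) torch.Size([2, 512])
```

Note that the output width is the same whether or not a prior image is supplied, which is what lets downstream heads stay agnostic to single- vs. multi-image inputs.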