Diff of /biovil_t/model.py [000000] .. [4abb48]


--- /dev/null
+++ b/biovil_t/model.py
#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from health_multimodal.common.device import get_module_device

from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder
from .modules import MLP, MultiTaskModel
from .types import ImageModelOutput


class BaseImageModel(nn.Module, ABC):
    """Abstract class for image models."""

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput:
        raise NotImplementedError

    @abstractmethod
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        raise NotImplementedError


class ImageModel(BaseImageModel):
    """Image encoder module."""

    def __init__(self,
                 img_encoder_type: str,
                 joint_feature_size: int,
                 freeze_encoder: bool = False,
                 pretrained_model_path: Optional[Union[str, Path]] = None,
                 **downstream_classifier_kwargs: Any):
        super().__init__()

        # Initialise the encoder, projector, and optional downstream classifier
        self.encoder = get_encoder_from_type(img_encoder_type)
        self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
        self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
                             hidden_dim=joint_feature_size, use_1x1_convs=True)
        self.downstream_classifier_kwargs = downstream_classifier_kwargs
        self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None

        # Initialise the mode of the modules
        self.freeze_encoder = freeze_encoder
        self.train()

        if pretrained_model_path is not None:
            if not isinstance(pretrained_model_path, (str, Path)):
                raise TypeError(f"Expected a string or Path, got {type(pretrained_model_path)}")
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
            # Drop the projector weights so the projection head keeps its fresh initialisation
            for k in list(state_dict.keys()):
                if k.startswith("projector"):
                    state_dict.pop(k)

            self.load_state_dict(state_dict, strict=False)

    def train(self, mode: bool = True) -> Any:
        """Switch the model between training and evaluation modes."""
        super().train(mode=mode)
        if self.freeze_encoder:
            # Keep frozen modules in evaluation mode regardless of the requested mode
            self.encoder.train(mode=False)
            self.projector.train(mode=False)
        return self

    def forward(self, x: torch.Tensor) -> ImageModelOutput:  # type: ignore[override]
        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
        return self.forward_post_encoder(patch_x, pooled_x)

    def forward_post_encoder(self, patch_x: torch.Tensor, pooled_x: torch.Tensor) -> ImageModelOutput:
        with torch.set_grad_enabled(not self.freeze_encoder):
            projected_patch_embeddings = self.projector(patch_x)
            # The global embedding is the spatial average of the projected patch embeddings
            projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))

        logits = self.classifier(pooled_x) if self.classifier else None
        return ImageModelOutput(img_embedding=pooled_x,
                                patch_embeddings=patch_x,
                                class_logits=logits,
                                projected_patch_embeddings=projected_patch_embeddings,
                                projected_global_embedding=projected_global_embedding)

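    # Shape sketch for forward_post_encoder (illustrative: H' and W' denote the encoder's
    # patch-grid size, which depends on the chosen encoder and input resolution):
    #   patch_x:                    [B, feature_size, H', W']
    #   projected_patch_embeddings: [B, joint_feature_size, H', W']
    #   projected_global_embedding: [B, joint_feature_size]  (spatial mean over H', W')
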
    def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
        """Create the classification module for the downstream task."""
        downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
        return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)

    @torch.no_grad()
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        """Get patch-wise projected embeddings from the CNN model.

        :param input_img: Input image tensor of shape [B, C, H, W].
        :param normalize: If ``True``, the embeddings are L2-normalised along the feature dimension.
        :return: Tensor of projected embeddings with shape [batch, n_patches_h, n_patches_w, feature_size].
        """
        assert not self.training, "This function is only implemented for evaluation mode"
        outputs = self.forward(input_img)
        projected_embeddings = outputs.projected_patch_embeddings.detach()  # type: ignore
        if normalize:
            projected_embeddings = F.normalize(projected_embeddings, dim=1)
        projected_embeddings = projected_embeddings.permute([0, 2, 3, 1])  # B D H W -> B H W D (D: features)
        return projected_embeddings

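# A minimal usage sketch for ImageModel (assumptions for illustration only: the encoder type
# name "resnet50" and the 448x448 input resolution are not fixed by this module):
#
#     model = ImageModel(img_encoder_type="resnet50", joint_feature_size=128, freeze_encoder=True)
#     model.eval()
#     image_batch = torch.randn(2, 3, 448, 448)  # [B, C, H, W]
#     embeddings = model.get_patchwise_projected_embeddings(image_batch, normalize=True)
#     # embeddings.shape -> [2, n_patches_h, n_patches_w, 128]
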
class MultiImageModel(ImageModel):
    """Image encoder module for a current image and an optional previous image."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"

    def forward(self,  # type: ignore[override]
                current_image: torch.Tensor,
                previous_image: Optional[torch.Tensor] = None) -> ImageModelOutput:

        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(current_image=current_image,
                                             previous_image=previous_image,
                                             return_patch_embeddings=True)
        return self.forward_post_encoder(patch_x, pooled_x)
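

# A minimal usage sketch for the temporal model (assumption: "resnet50_multi_image" stands in
# for any encoder type that get_encoder_from_type resolves to a MultiImageEncoder):
#
#     model = MultiImageModel(img_encoder_type="resnet50_multi_image", joint_feature_size=128)
#     model.eval()
#     current = torch.randn(1, 3, 448, 448)
#     previous = torch.randn(1, 3, 448, 448)  # previous_image may also be None
#     output = model(current_image=current, previous_image=previous)
#     # output.projected_global_embedding.shape -> [1, 128]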