|
a |
|
b/biovil_t/model.py |
|
|
1 |
# ------------------------------------------------------------------------------------------- |
|
|
2 |
# Copyright (c) Microsoft Corporation. All rights reserved. |
|
|
3 |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. |
|
|
4 |
# ------------------------------------------------------------------------------------------- |
|
|
5 |
|
|
|
6 |
from __future__ import annotations |
|
|
7 |
|
|
|
8 |
from abc import ABC, abstractmethod |
|
|
9 |
from pathlib import Path |
|
|
10 |
from typing import Any, Optional, Union |
|
|
11 |
|
|
|
12 |
import torch |
|
|
13 |
import torch.nn as nn |
|
|
14 |
import torch.nn.functional as F |
|
|
15 |
from health_multimodal.common.device import get_module_device |
|
|
16 |
|
|
|
17 |
from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder |
|
|
18 |
from .modules import MLP, MultiTaskModel |
|
|
19 |
from .types import ImageModelOutput |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
class BaseImageModel(nn.Module, ABC): |
|
|
23 |
"""Abstract class for image models.""" |
|
|
24 |
@abstractmethod |
|
|
25 |
def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput: |
|
|
26 |
raise NotImplementedError |
|
|
27 |
|
|
|
28 |
@abstractmethod |
|
|
29 |
def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor: |
|
|
30 |
raise NotImplementedError |
|
|
31 |
|
|
|
32 |
|
|
|
33 |
class ImageModel(BaseImageModel): |
|
|
34 |
"""Image encoder module""" |
|
|
35 |
|
|
|
36 |
def __init__(self, |
|
|
37 |
img_encoder_type: str, |
|
|
38 |
joint_feature_size: int, |
|
|
39 |
freeze_encoder: bool = False, |
|
|
40 |
pretrained_model_path: Optional[Union[str, Path]] = None, |
|
|
41 |
**downstream_classifier_kwargs: Any): |
|
|
42 |
super().__init__() |
|
|
43 |
|
|
|
44 |
# Initiate encoder, projector, and classifier |
|
|
45 |
self.encoder = get_encoder_from_type(img_encoder_type) |
|
|
46 |
self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder)) |
|
|
47 |
self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size, |
|
|
48 |
hidden_dim=joint_feature_size, use_1x1_convs=True) |
|
|
49 |
self.downstream_classifier_kwargs = downstream_classifier_kwargs |
|
|
50 |
self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None |
|
|
51 |
|
|
|
52 |
# Initialise the mode of modules |
|
|
53 |
self.freeze_encoder = freeze_encoder |
|
|
54 |
self.train() |
|
|
55 |
|
|
|
56 |
if pretrained_model_path is not None: |
|
|
57 |
if not isinstance(pretrained_model_path, (str, Path)): |
|
|
58 |
raise TypeError(f"Expected a string or Path, got {type(pretrained_model_path)}") |
|
|
59 |
state_dict = torch.load(pretrained_model_path, map_location="cpu") |
|
|
60 |
# drop projector |
|
|
61 |
for k in list(state_dict.keys()): |
|
|
62 |
if k.startswith("projector"): |
|
|
63 |
state_dict.pop(k) |
|
|
64 |
|
|
|
65 |
self.load_state_dict(state_dict, strict=False) |
|
|
66 |
|
|
|
67 |
|
|
|
68 |
def train(self, mode: bool = True) -> Any: |
|
|
69 |
"""Switch the model between training and evaluation modes.""" |
|
|
70 |
super().train(mode=mode) |
|
|
71 |
if self.freeze_encoder: |
|
|
72 |
self.encoder.train(mode=False) |
|
|
73 |
self.projector.train(mode=False) |
|
|
74 |
return self |
|
|
75 |
|
|
|
76 |
def forward(self, x: torch.Tensor) -> ImageModelOutput: # type: ignore[override] |
|
|
77 |
with torch.set_grad_enabled(not self.freeze_encoder): |
|
|
78 |
patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True) |
|
|
79 |
return self.forward_post_encoder(patch_x, pooled_x) |
|
|
80 |
|
|
|
81 |
def forward_post_encoder(self, patch_x: torch.Tensor, pooled_x: torch.Tensor) -> ImageModelOutput: |
|
|
82 |
with torch.set_grad_enabled(not self.freeze_encoder): |
|
|
83 |
projected_patch_embeddings = self.projector(patch_x) |
|
|
84 |
projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3)) |
|
|
85 |
|
|
|
86 |
logits = self.classifier(pooled_x) if self.classifier else None |
|
|
87 |
return ImageModelOutput(img_embedding=pooled_x, |
|
|
88 |
patch_embeddings=patch_x, |
|
|
89 |
class_logits=logits, |
|
|
90 |
projected_patch_embeddings=projected_patch_embeddings, |
|
|
91 |
projected_global_embedding=projected_global_embedding) |
|
|
92 |
|
|
|
93 |
def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel: |
|
|
94 |
"""Create the classification module for the downstream task.""" |
|
|
95 |
downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs |
|
|
96 |
return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs) |
|
|
97 |
|
|
|
98 |
@torch.no_grad() |
|
|
99 |
def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor: |
|
|
100 |
"""Get patch-wise projected embeddings from the CNN model. |
|
|
101 |
|
|
|
102 |
:param input_img: input tensor image [B, C, H, W]. |
|
|
103 |
:param normalize: If ``True``, the embeddings are L2-normalized. |
|
|
104 |
:returns projected_embeddings: tensor of embeddings in shape [batch, n_patches_h, n_patches_w, feature_size]. |
|
|
105 |
""" |
|
|
106 |
assert not self.training, "This function is only implemented for evaluation mode" |
|
|
107 |
outputs = self.forward(input_img) |
|
|
108 |
projected_embeddings = outputs.projected_patch_embeddings.detach() # type: ignore |
|
|
109 |
if normalize: |
|
|
110 |
projected_embeddings = F.normalize(projected_embeddings, dim=1) |
|
|
111 |
projected_embeddings = projected_embeddings.permute([0, 2, 3, 1]) # B D H W -> B H W D (D: Features) |
|
|
112 |
return projected_embeddings |
|
|
113 |
|
|
|
114 |
|
|
|
115 |
class MultiImageModel(ImageModel): |
|
|
116 |
def __init__(self, **kwargs: Any) -> None: |
|
|
117 |
super().__init__(**kwargs) |
|
|
118 |
assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder" |
|
|
119 |
|
|
|
120 |
def forward(self, # type: ignore[override] |
|
|
121 |
current_image: torch.Tensor, |
|
|
122 |
previous_image: Optional[torch.Tensor] = None) -> ImageModelOutput: |
|
|
123 |
|
|
|
124 |
with torch.set_grad_enabled(not self.freeze_encoder): |
|
|
125 |
patch_x, pooled_x = self.encoder(current_image=current_image, |
|
|
126 |
previous_image=previous_image, |
|
|
127 |
return_patch_embeddings=True) |
|
|
128 |
return self.forward_post_encoder(patch_x, pooled_x) |