Diff of /biovil_t/resnet.py [000000] .. [4abb48]

Switch to unified view

a b/biovil_t/resnet.py
1
#  -------------------------------------------------------------------------------------------
2
#  Copyright (c) Microsoft Corporation. All rights reserved.
3
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
#  -------------------------------------------------------------------------------------------
5
6
from typing import Any, List, Tuple, Type, Union
7
8
import torch
9
from torch.hub import load_state_dict_from_url
10
from torchvision.models.resnet import model_urls, ResNet, BasicBlock, Bottleneck
11
12
TypeSkipConnections = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
13
14
15
class ResNetHIML(ResNet):
16
    """Wrapper class of the original torchvision ResNet model.
17
18
    The forward function is updated to return the penultimate layer
19
    activations, which are required to obtain image patch embeddings.
20
    """
21
22
    def __init__(self, **kwargs: Any) -> None:
23
        super().__init__(**kwargs)
24
25
    def forward(self, x: torch.Tensor,
26
                return_intermediate_layers: bool = False) -> Union[torch.Tensor, TypeSkipConnections]:
27
        """ResNetHIML forward pass. Optionally returns intermediate layers using the
28
        ``return_intermediate_layers`` argument.
29
30
        :param return_intermediate_layers: If ``True``, return layers x0-x4 as a tuple,
31
            otherwise return x4 only.
32
        """
33
34
        x0 = self.conv1(x)
35
        x0 = self.bn1(x0)
36
        x0 = self.relu(x0)
37
        x0 = self.maxpool(x0)
38
39
        x1 = self.layer1(x0)
40
        x2 = self.layer2(x1)
41
        x3 = self.layer3(x2)
42
        x4 = self.layer4(x3)
43
44
        if return_intermediate_layers:
45
            return x0, x1, x2, x3, x4
46
        else:
47
            return x4
48
49
50
def _resnet(arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int],
51
            pretrained: bool, progress: bool, **kwargs: Any) -> ResNetHIML:
52
    """Instantiate a custom :class:`ResNet` model.
53
54
    Adapted from :mod:`torchvision.models.resnet`.
55
    """
56
    model = ResNetHIML(block=block, layers=layers, **kwargs)
57
    if pretrained:
58
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
59
        model.load_state_dict(state_dict)
60
    return model
61
62
63
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
64
    r"""ResNet-18 model from
65
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
66
67
    :param pretrained: If ``True``, returns a model pre-trained on ImageNet.
68
    :param progress: If ``True``, displays a progress bar of the download to ``stderr``.
69
    """
70
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
71
72
73
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNetHIML:
74
    r"""ResNet-50 model from
75
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
76
77
    :param pretrained: If ``True``, returns a model pre-trained on ImageNet
78
    :param progress: If ``True``, displays a progress bar of the download to ``stderr``.
79
    """
80
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)