Switch to unified view

a b/reproducibility/embedders/mudipath.py
1
import os
2
import re
3
import sys
4
from abc import abstractmethod
5
from reproducibility.utils.cacher import cache_hit_or_miss, cache_numpy_object
6
from reproducibility.embedders.internal_datasets import CLIPImageDataset
7
from torch.utils.data import DataLoader
8
from torch.utils import model_zoo
9
from torchvision.models.resnet import ResNet, model_urls as resnet_urls, BasicBlock, Bottleneck
10
from torchvision.models.densenet import DenseNet, model_urls as densenet_urls
11
import torch.nn.functional as F
12
import numpy as np
13
from torch import nn
14
15
class FeaturesInterface(object):
    """Mixin for networks that can report how many features they output.

    Note: the class does not use ``ABCMeta``, so the ``abstractmethod``
    marker below is informational only; instantiation is not blocked and the
    base implementation simply returns ``None``.
    """

    @abstractmethod
    def n_features(self):
        """Return the feature count; subclasses are expected to override this."""
        return None
19
20
import torch
21
from torch.hub import download_url_to_file
22
23
try:
24
    from requests.utils import urlparse
25
    from requests import get as urlopen
26
    requests_available = True
27
except ImportError:
28
    requests_available = False
29
    from urllib.request import urlopen
30
    from urllib.parse import urlparse
31
try:
32
    from tqdm import tqdm
33
except ImportError:
34
    tqdm = None  # defined below
35
36
37
def _remove_prefix(s, prefix):
38
    if s.startswith(prefix):
39
        s = s[len(prefix):]
40
    return s
41
42
43
def clean_state_dict(state_dict, prefix, filter=None):
    """Build a new dict from ``state_dict`` with filtered, prefix-stripped keys.

    Args:
        state_dict (dict): Mapping whose keys are strings.
        prefix (str): Prefix removed from the start of each kept key.
        filter (callable|None): Predicate called with the ORIGINAL key (before
            the prefix is stripped); entries for which it returns a falsy
            value are dropped. ``None`` keeps every entry.

    Returns:
        dict: New mapping; values are the same objects as in ``state_dict``.
    """
    keep = filter if filter is not None else (lambda *args: True)
    cleaned = {}
    for key, value in state_dict.items():
        if keep(key):
            cleaned[_remove_prefix(key, prefix)] = value
    return cleaned
47
48
49
def load_dox_url(url, filename, model_dir=None, map_location=None, progress=True):
    r"""Load a checkpoint from a local cache, downloading it first if missing.

    Unlike a plain URL loader, the cache file name is given explicitly through
    ``filename`` instead of being derived from ``url`` (the mtdp download URLs
    all end in ``/download``).

    Args:
        url (str): Address the checkpoint is downloaded from on a cache miss.
        filename (str): Name of the cached file inside ``model_dir``.
        model_dir (str|None): Cache directory. When ``None``, the value of
            ``$TORCH_MODEL_ZOO`` is used, falling back to
            ``$TORCH_HOME/models`` (``TORCH_HOME`` defaults to ``~/.torch``).
        map_location: Forwarded unchanged to ``torch.load``.
        progress (bool): Forwarded to ``download_url_to_file``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    # exist_ok avoids the check-then-create race when several processes start
    # at once and all find the cache directory missing.
    os.makedirs(model_dir, exist_ok=True)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        sys.stderr.flush()
        download_url_to_file(url, cached_file, None, progress=progress)
    # NOTE(review): torch.load unpickles the file; only point this function at
    # URLs / cache directories whose content is trusted.
    return torch.load(cached_file, map_location=map_location)
63
64
65
66
# ResNet checkpoints used for the "mtdp" pre-training option, keyed by
# architecture name. Each value is (download URL, local cache file name).
MTDRN_URLS = {
    "resnet50": (
        "https://dox.uliege.be/index.php/s/kvABLtVuMxW8iJy/download",
        "resnet50-mh-best-191205-141200.pth",
    ),
}
69
70
71
class NoHeadResNet(ResNet, FeaturesInterface):
    """ResNet whose forward pass stops after ``avgpool``.

    The output of ``avgpool`` is returned as-is: it is neither flattened nor
    passed through any further layer.
    """

    def forward(self, x):
        """Run the stem, the four residual stages and the average pooling."""
        stages = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for stage in stages:
            x = stage(x)
        return x

    def n_features(self):
        """Return ``num_features`` of the last child of the final ``layer4`` block that has one."""
        final_block = self.layer4[-1]
        with_num_features = [child for child in final_block.children() if hasattr(child, 'num_features')]
        return with_num_features[-1].num_features
87
88
89
def build_resnet(download_dir, pretrained=None, arch="resnet50", model_class=NoHeadResNet, **kwargs):
    """Construct a ResNet model, optionally loading pre-trained weights.

    Args:
        download_dir (str): Directory handed to ``load_dox_url`` as ``model_dir``
                            when ``pretrained`` is "mtdp".
        pretrained (str|None): If "imagenet", loads the ImageNet weights. If "mtdp", loads the weights
                               pre-trained in multi-task on digital pathology data. Any non-string value
                               (e.g. None) leaves the randomly initialised weights.
        arch (str): Type of resnet (among: resnet18, resnet34, resnet50, resnet101 and resnet152)
        model_class (nn.Module): Actual resnet module class
        **kwargs: Extra keyword arguments forwarded to ``model_class``.

    Returns:
        The constructed ``model_class`` instance.

    Raises:
        KeyError: If ``arch`` is not one of the architectures listed above.
        ValueError: If ``pretrained`` is an unknown string, or is "mtdp" and no
                    multi-task checkpoint is registered for ``arch``.
    """
    # Residual block type and number of blocks per stage for each architecture.
    architectures = {
        "resnet18": (BasicBlock, [2, 2, 2, 2]),
        "resnet34": (BasicBlock, [3, 4, 6, 3]),
        "resnet50": (Bottleneck, [3, 4, 6, 3]),
        "resnet101": (Bottleneck, [3, 4, 23, 3]),
        "resnet152": (Bottleneck, [3, 8, 36, 3]),
    }
    block_type, blocks_per_stage = architectures[arch]
    model = model_class(block_type, blocks_per_stage, **kwargs)

    if not isinstance(pretrained, str):
        return model

    if pretrained == "imagenet":
        weights = model_zoo.load_url(resnet_urls[arch])  # default imagenet
    elif pretrained == "mtdp":
        if arch not in MTDRN_URLS:
            raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch))
        checkpoint_url, checkpoint_name = MTDRN_URLS[arch]
        weights = load_dox_url(checkpoint_url, checkpoint_name, model_dir=download_dir, map_location="cpu")
        # Drop the "heads." entries and strip the "features." prefix from the rest.
        weights = clean_state_dict(weights, prefix="features.", filter=lambda k: not k.startswith("heads."))
    else:
        raise ValueError("Unknown pre-training source")

    model.load_state_dict(weights)
    return model
119
120
# DenseNet checkpoints used for the "mtdp" pre-training option, keyed by
# architecture name. Each value is (download URL, local cache file name).
MTDP_URLS = {
    "densenet121": (
        "https://dox.uliege.be/index.php/s/G72InP4xmJvOrVp/download",
        "densenet121-mh-best-191205-141200.pth",
    ),
}
123
124
125
class NoHeadDenseNet(DenseNet, FeaturesInterface):
    """DenseNet that returns pooled feature maps instead of class scores."""

    def forward(self, x):
        """Apply ``self.features`` and adaptively average-pool the result to 1x1."""
        feature_maps = self.features(x)
        return F.adaptive_avg_pool2d(feature_maps, (1, 1))

    def n_features(self):
        """Return ``num_features`` of the last module in ``self.features``."""
        last_module = self.features[-1]
        return last_module.num_features
131
132
133
def build_densenet(download_dir, pretrained=False, arch="densenet121", model_class=NoHeadDenseNet, **kwargs):
    r"""Densenet-XXX model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        download_dir (str): Directory handed to ``load_dox_url`` as ``model_dir``
                            when ``pretrained`` is "mtdp".
        pretrained (str|bool|None): If "imagenet", loads the ImageNet weights. If "mtdp", loads the weights
                            pre-trained in multi-task on digital pathology data. Any non-string value
                            (e.g. False or None) leaves the randomly initialised weights.
        arch (str): Type of densenet (among: densenet121, densenet169, densenet201 and densenet161)
        model_class (nn.Module): Actual densenet module class
        **kwargs: Extra keyword arguments forwarded to ``model_class``.

    Returns:
        The constructed ``model_class`` instance.

    Raises:
        KeyError: If ``arch`` is not one of the architectures listed above.
        ValueError: If ``pretrained`` is an unknown string, or is "mtdp" and no
                    multi-task checkpoint is registered for ``arch``.
    """
    # Constructor arguments for each supported architecture.
    architectures = {
        "densenet121": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 24, 16)},
        "densenet169": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 32, 32)},
        "densenet201": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 48, 32)},
        "densenet161": {"num_init_features": 96, "growth_rate": 48, "block_config": (6, 12, 36, 24)},
    }
    model = model_class(**(architectures[arch]), **kwargs)

    if not isinstance(pretrained, str):
        return model

    if pretrained == "imagenet":
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        legacy_key = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        weights = model_zoo.load_url(densenet_urls[arch])
        # Iterate over a snapshot of the keys because the dict is edited in the loop.
        for old_key in list(weights.keys()):
            found = legacy_key.match(old_key)
            if found is None:
                continue
            # Joining the two groups drops the dot between e.g. 'norm' and '1'.
            weights[found.group(1) + found.group(2)] = weights.pop(old_key)
    elif pretrained == "mtdp":
        if arch not in MTDP_URLS:
            raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch))
        checkpoint_url, checkpoint_name = MTDP_URLS[arch]
        weights = load_dox_url(checkpoint_url, checkpoint_name, model_dir=download_dir, map_location="cpu")
        # Drop the "heads." entries and strip the "features." prefix from the rest.
        weights = clean_state_dict(weights, prefix="features.", filter=lambda k: not k.startswith("heads."))
    else:
        raise ValueError("Unknown pre-training source")

    model.load_state_dict(weights)
    return model
174
175
176
class ResNetBottom(nn.Module):
    """Wrap a model, keeping every direct child module except the last one."""

    def __init__(self, original_model):
        """Chain all children of ``original_model`` but the last into ``self.features``."""
        super().__init__()
        kept_children = list(original_model.children())[:-1]
        self.features = nn.Sequential(*kept_children)

    def forward(self, x):
        """Apply ``self.features`` and flatten every dimension after the first."""
        return torch.flatten(self.features(x), 1)
185
186
187
class DenseNetEmbedder:
    """Compute (and cache) image embeddings with a feature-extraction model."""

    def __init__(self, model, preprocess, name, backbone):
        """Store the model, its preprocessing callable and the cache identifiers.

        Args:
            model: Callable network applied to each image batch.
            preprocess: Preprocessing object handed to ``CLIPImageDataset``.
            name (str): Used as the first part of the cache key.
            backbone: Passed as the second argument to the cache helpers.
        """
        self.model = model
        self.preprocess = preprocess
        self.name = name
        self.backbone = backbone

    def image_embedder(self, list_of_images, device="cuda", num_workers=1, batch_size=32, additional_cache_name=""):
        """Return embeddings for ``list_of_images``, using the cache when possible.

        Args:
            list_of_images: Images handed to ``embed_images`` on a cache miss.
            device: Device each batch is moved to.
            num_workers (int): Number of ``DataLoader`` workers.
            batch_size (int): ``DataLoader`` batch size.
            additional_cache_name (str): Name of the validation dataset
                (e.g., Kather_7K); appended to the cache key.

        Returns:
            The cached object if there is one, otherwise the freshly computed
            embeddings (which are written to the cache before returning).
        """
        cache_key = self.name + "img" + additional_cache_name
        cached = cache_hit_or_miss(cache_key, self.backbone)
        if cached is not None:
            return cached

        embeddings = self.embed_images(list_of_images, device=device, num_workers=num_workers, batch_size=batch_size)
        cache_numpy_object(embeddings, cache_key, self.backbone)
        return embeddings

    @staticmethod
    def _squeeze_keep_batch(embeddings):
        """Drop size-1 dimensions from ``embeddings`` except the leading batch one.

        A bare ``squeeze()`` would also remove the batch dimension when a batch
        holds a single image, which breaks the later concatenation.
        """
        kept_dims = [dim for dim in embeddings.shape[1:] if dim != 1]
        return embeddings.reshape(embeddings.shape[0], *kept_dims)

    def embed_images(self, list_of_images, device="cuda", num_workers=1, batch_size=32):
        """Run the model over all images and return the stacked embeddings.

        Args:
            list_of_images: Images wrapped in a ``CLIPImageDataset`` together
                with ``self.preprocess``.
            device: Device each batch is moved to before the forward pass.
            num_workers (int): Number of ``DataLoader`` workers.
            batch_size (int): ``DataLoader`` batch size.

        Returns:
            numpy.ndarray: Per-batch embeddings concatenated along the first
            axis, one row per image.
        """
        dataset = CLIPImageDataset(list_of_images, self.preprocess)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

        # tqdm is an optional dependency: the module sets it to None when the
        # import fails, so fall back to iterating the loader directly.
        batches = dataloader if tqdm is None else tqdm(dataloader)

        all_embs = []
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            for batch_X in batches:
                batch_X = batch_X.to(device)
                embeddings = self._squeeze_keep_batch(self.model(batch_X).detach().float())
                all_embs.append(embeddings.cpu().numpy())
        return np.concatenate(all_embs)
216
217