b/reproducibility/embedders/mudipath.py

import os
import re
import sys
from abc import abstractmethod
from reproducibility.utils.cacher import cache_hit_or_miss, cache_numpy_object
from reproducibility.embedders.internal_datasets import CLIPImageDataset
from torch.utils.data import DataLoader
from torch.utils import model_zoo
from torchvision.models.resnet import ResNet, model_urls as resnet_urls, BasicBlock, Bottleneck
from torchvision.models.densenet import DenseNet, model_urls as densenet_urls
import torch.nn.functional as F
import numpy as np
from torch import nn


class FeaturesInterface(object):
    @abstractmethod
    def n_features(self):
        pass


import torch
from torch.hub import download_url_to_file

try:
    from requests.utils import urlparse
    from requests import get as urlopen
    requests_available = True
except ImportError:
    requests_available = False
    from urllib.request import urlopen
    from urllib.parse import urlparse
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional: fall back to a no-op wrapper so iteration works without a progress bar
    def tqdm(iterable, *args, **kwargs):
        return iterable


def _remove_prefix(s, prefix):
    if s.startswith(prefix):
        s = s[len(prefix):]
    return s


def clean_state_dict(state_dict, prefix, filter=None):
    if filter is None:
        filter = lambda *args: True
    return {_remove_prefix(k, prefix): v for k, v in state_dict.items() if filter(k)}
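# Illustrative sketch (comment only, not executed): how clean_state_dict is used to strip
# the "features." prefix from a multi-head mtdp checkpoint while dropping its classification
# heads. The key names shown here are hypothetical.
#
#   raw = {"features.conv1.weight": w, "heads.0.fc.weight": h}
#   clean_state_dict(raw, prefix="features.", filter=lambda k: not k.startswith("heads."))
#   # -> {"conv1.weight": w}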


def load_dox_url(url, filename, model_dir=None, map_location=None, progress=True):
    r"""Download (if not already cached) and load a checkpoint stored in the file format
    used by the mtdp pre-trained models.
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        sys.stderr.flush()
        download_url_to_file(url, cached_file, None, progress=progress)
    return torch.load(cached_file, map_location=map_location)
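# Illustrative sketch (comment only, not executed): fetching the multi-task ResNet-50
# checkpoint listed in MTDRN_URLS below; the cache directory "/tmp/mtdp" is hypothetical.
#
#   url, filename = MTDRN_URLS["resnet50"]
#   state_dict = load_dox_url(url, filename, model_dir="/tmp/mtdp", map_location="cpu")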


MTDRN_URLS = {
    "resnet50": ("https://dox.uliege.be/index.php/s/kvABLtVuMxW8iJy/download", "resnet50-mh-best-191205-141200.pth")
}


class NoHeadResNet(ResNet, FeaturesInterface):
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        return x

    def n_features(self):
        return [b for b in list(self.layer4[-1].children()) if hasattr(b, 'num_features')][-1].num_features


def build_resnet(download_dir, pretrained=None, arch="resnet50", model_class=NoHeadResNet, **kwargs):
    """Constructs a ResNet model.
    Args:
        arch (str): Type of resnet (among: resnet18, resnet34, resnet50, resnet101 and resnet152)
        pretrained (str|None): If "imagenet", returns a model pre-trained on ImageNet. If "mtdp", returns a model
            pre-trained in multi-task on digital pathology data. Otherwise (None), random weights.
        model_class (nn.Module): Actual resnet module class
    """
    params = {
        "resnet18": [BasicBlock, [2, 2, 2, 2]],
        "resnet34": [BasicBlock, [3, 4, 6, 3]],
        "resnet50": [Bottleneck, [3, 4, 6, 3]],
        "resnet101": [Bottleneck, [3, 4, 23, 3]],
        "resnet152": [Bottleneck, [3, 8, 36, 3]]
    }
    model = model_class(*params[arch], **kwargs)
    if isinstance(pretrained, str):
        if pretrained == "imagenet":
            url = resnet_urls[arch]  # default imagenet
            state_dict = model_zoo.load_url(url)
        elif pretrained == "mtdp":
            if arch not in MTDRN_URLS:
                raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch))
            url, filename = MTDRN_URLS[arch]
            state_dict = load_dox_url(url, filename, model_dir=download_dir, map_location="cpu")
            state_dict = clean_state_dict(state_dict, prefix="features.", filter=lambda k: not k.startswith("heads."))
        else:
            raise ValueError("Unknown pre-training source")
        model.load_state_dict(state_dict)
    return model
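# Illustrative sketch (comment only, not executed): building the multi-task pathology
# ResNet-50 backbone; the download directory is hypothetical, any writable path works.
#
#   backbone = build_resnet(download_dir="/tmp/mtdp", pretrained="mtdp", arch="resnet50")
#   backbone.eval()
#   backbone.n_features()  # 2048 for resnet50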


MTDP_URLS = {
    "densenet121": ("https://dox.uliege.be/index.php/s/G72InP4xmJvOrVp/download", "densenet121-mh-best-191205-141200.pth")
}


class NoHeadDenseNet(DenseNet, FeaturesInterface):
    def forward(self, x):
        return F.adaptive_avg_pool2d(self.features(x), (1, 1))

    def n_features(self):
        return self.features[-1].num_features


def build_densenet(download_dir, pretrained=False, arch="densenet121", model_class=NoHeadDenseNet, **kwargs):
    r"""Densenet-XXX model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    Args:
        arch (str): Type of densenet (among: densenet121, densenet169, densenet201 and densenet161)
        pretrained (str|None): If "imagenet", returns a model pre-trained on ImageNet. If "mtdp", returns a model
            pre-trained in multi-task on digital pathology data. Otherwise (None), random weights.
        model_class (nn.Module): Actual densenet module class
    """
    params = {
        "densenet121": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 24, 16)},
        "densenet169": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 32, 32)},
        "densenet201": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 48, 32)},
        "densenet161": {"num_init_features": 96, "growth_rate": 48, "block_config": (6, 12, 36, 24)}
    }
    model = model_class(**(params[arch]), **kwargs)
    if isinstance(pretrained, str):
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        if pretrained == "imagenet":
            pattern = re.compile(
                r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
            state_dict = model_zoo.load_url(densenet_urls[arch])
            for key in list(state_dict.keys()):
                res = pattern.match(key)
                if res:
                    new_key = res.group(1) + res.group(2)
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
        elif pretrained == "mtdp":
            if arch not in MTDP_URLS:
                raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch))
            url, filename = MTDP_URLS[arch]
            state_dict = load_dox_url(url, filename, model_dir=download_dir, map_location="cpu")
            state_dict = clean_state_dict(state_dict, prefix="features.", filter=lambda k: not k.startswith("heads."))
        else:
            raise ValueError("Unknown pre-training source")
        model.load_state_dict(state_dict)
    return model
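# Illustrative sketch (comment only, not executed): building the multi-task pathology
# DenseNet-121 backbone; the download directory is hypothetical.
#
#   backbone = build_densenet(download_dir="/tmp/mtdp", pretrained="mtdp", arch="densenet121")
#   backbone.eval()
#   backbone.n_features()  # 1024 for densenet121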


class ResNetBottom(nn.Module):
    def __init__(self, original_model):
        super(ResNetBottom, self).__init__()
        self.features = nn.Sequential(*list(original_model.children())[:-1])

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        return x


class DenseNetEmbedder:
    def __init__(self, model, preprocess, name, backbone):
        self.model = model
        self.preprocess = preprocess
        self.name = name
        self.backbone = backbone

    def image_embedder(self, list_of_images, device="cuda", num_workers=1, batch_size=32, additional_cache_name=""):
        # additional_cache_name: name of the validation dataset (e.g., Kather_7K)
        hit_or_miss = cache_hit_or_miss(self.name + "img" + additional_cache_name, self.backbone)

        if hit_or_miss is not None:
            return hit_or_miss
        else:
            hit = self.embed_images(list_of_images, device=device, num_workers=num_workers, batch_size=batch_size)
            cache_numpy_object(hit, self.name + "img" + additional_cache_name, self.backbone)
            return hit

    def embed_images(self, list_of_images, device="cuda", num_workers=1, batch_size=32):
        dataset = CLIPImageDataset(list_of_images, self.preprocess)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

        all_embs = []
        for batch_X in tqdm(dataloader):
            batch_X = batch_X.to(device)
            embeddings = self.model(batch_X).detach().float().squeeze()
            embeddings = embeddings.detach().cpu().numpy()
            all_embs.append(embeddings)
        return np.concatenate(all_embs)
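# Illustrative sketch (comment only, not executed): wrapping a backbone in the embedder.
# The preprocessing transform, cache identifiers and image paths below are hypothetical.
#
#   from torchvision import transforms
#   preprocess = transforms.Compose([
#       transforms.Resize(224),
#       transforms.CenterCrop(224),
#       transforms.ToTensor(),
#   ])
#   backbone = build_densenet(download_dir="/tmp/mtdp", pretrained="mtdp").to("cuda").eval()
#   embedder = DenseNetEmbedder(backbone, preprocess, name="mudipath", backbone="densenet121")
#   embs = embedder.image_embedder(["img_0.png", "img_1.png"], device="cuda", batch_size=2)
#   # embs: numpy array of shape (n_images, n_features)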