[7829e6]: / reproducibility / embedders / mudipath.py

Download this file

218 lines (182 with data), 8.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import os
import re
import sys
from abc import abstractmethod
from reproducibility.utils.cacher import cache_hit_or_miss, cache_numpy_object
from reproducibility.embedders.internal_datasets import CLIPImageDataset
from torch.utils.data import DataLoader
from torch.utils import model_zoo
from torchvision.models.resnet import ResNet, model_urls as resnet_urls, BasicBlock, Bottleneck
from torchvision.models.densenet import DenseNet, model_urls as densenet_urls
import torch.nn.functional as F
import numpy as np
from torch import nn
class FeaturesInterface(object):
    """Interface for backbones that can report the width of their output embedding."""
    @abstractmethod
    def n_features(self):
        # Subclasses return the number of channels/features produced by forward().
        pass
import torch
from torch.hub import download_url_to_file
try:
from requests.utils import urlparse
from requests import get as urlopen
requests_available = True
except ImportError:
requests_available = False
from urllib.request import urlopen
from urllib.parse import urlparse
try:
from tqdm import tqdm
except ImportError:
tqdm = None # defined below
def _remove_prefix(s, prefix):
if s.startswith(prefix):
s = s[len(prefix):]
return s
def clean_state_dict(state_dict, prefix, filter=None):
    """Strip `prefix` from every key of `state_dict`, keeping only entries accepted by `filter`.

    NOTE: `filter` shadows the builtin, but the name is part of the keyword
    interface used by callers and is therefore preserved.
    """
    keep = filter
    if keep is None:
        keep = lambda *args: True
    cleaned = {}
    for key, value in state_dict.items():
        if keep(key):
            cleaned[_remove_prefix(key, prefix)] = value
    return cleaned
def load_dox_url(url, filename, model_dir=None, map_location=None, progress=True):
    r"""Download (if not already cached) and deserialize a checkpoint hosted on dox.uliege.be.

    Adapted to fit the file format of mtdp pre-trained models.

    Args:
        url (str): Download URL of the checkpoint.
        filename (str): Name under which the checkpoint is cached locally.
        model_dir (str|None): Cache directory. Defaults to $TORCH_MODEL_ZOO,
            falling back to $TORCH_HOME/models (default ~/.torch/models).
        map_location: Forwarded to `torch.load`.
        progress (bool): Whether to show a download progress bar.

    Returns:
        Whatever `torch.load` deserializes from the cached file (typically a state dict).
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    # exist_ok=True avoids the check-then-create race of the original
    # `if not exists: makedirs` when several processes start concurrently.
    os.makedirs(model_dir, exist_ok=True)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        sys.stderr.flush()
        download_url_to_file(url, cached_file, None, progress=progress)
    return torch.load(cached_file, map_location=map_location)
# (download URL, cached filename) of the multi-task digital-pathology ResNet checkpoints.
MTDRN_URLS = {
    "resnet50": ("https://dox.uliege.be/index.php/s/kvABLtVuMxW8iJy/download", "resnet50-mh-best-191205-141200.pth")
}
class NoHeadResNet(ResNet, FeaturesInterface):
    """ResNet trunk without the classification head: forward stops after the average pool."""

    def forward(self, x):
        # Stem: conv -> bn -> relu -> maxpool.
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # The four residual stages, in order.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return self.avgpool(out)

    def n_features(self):
        """Width of the produced feature map: `num_features` of the last tracked module in layer4."""
        tracked = [m for m in list(self.layer4[-1].children()) if hasattr(m, 'num_features')]
        return tracked[-1].num_features
def build_resnet(download_dir, pretrained=None, arch="resnet50", model_class=NoHeadResNet, **kwargs):
    """Constructs a ResNet model of the requested architecture.

    Args:
        download_dir (str): Directory used to cache downloaded "mtdp" weights.
        pretrained (str|None): If "imagenet", returns a model pre-trained on ImageNet. If "mtdp" returns a model
            pre-trained in multi-task on digital pathology data. Otherwise (None), random weights.
        arch (str): Type of resnet (among: resnet18, resnet34, resnet50, resnet101 and resnet152)
        model_class (nn.Module): Actual resnet module class
        **kwargs: Forwarded to `model_class`.

    Returns:
        The constructed (and optionally pre-trained) model.

    Raises:
        ValueError: If `pretrained` is a string other than "imagenet"/"mtdp",
            or if no "mtdp" weights exist for `arch`.
        KeyError: If `arch` is not one of the supported architectures.
    """
    # Residual block type and per-stage block counts for each supported depth.
    params = {
        "resnet18": [BasicBlock, [2, 2, 2, 2]],
        "resnet34": [BasicBlock, [3, 4, 6, 3]],
        "resnet50": [Bottleneck, [3, 4, 6, 3]],
        "resnet101": [Bottleneck, [3, 4, 23, 3]],
        "resnet152": [Bottleneck, [3, 8, 36, 3]]
    }
    model = model_class(*params[arch], **kwargs)
    if isinstance(pretrained, str):
        if pretrained == "imagenet":
            url = resnet_urls[arch]  # default imagenet
            state_dict = model_zoo.load_url(url)
        elif pretrained == "mtdp":
            if arch not in MTDRN_URLS:
                raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch))
            url, filename = MTDRN_URLS[arch]
            state_dict = load_dox_url(url, filename, model_dir=download_dir, map_location="cpu")
            # Checkpoint stores the trunk under "features." and task heads under
            # "heads."; keep only the trunk, with the prefix stripped.
            state_dict = clean_state_dict(state_dict, prefix="features.", filter=lambda k: not k.startswith("heads."))
        else:
            raise ValueError("Unknown pre-training source")
        model.load_state_dict(state_dict)
    return model
# (download URL, cached filename) of the multi-task digital-pathology DenseNet checkpoints.
MTDP_URLS = {
    "densenet121": ("https://dox.uliege.be/index.php/s/G72InP4xmJvOrVp/download", "densenet121-mh-best-191205-141200.pth")
}
class NoHeadDenseNet(DenseNet, FeaturesInterface):
    """DenseNet trunk without the classifier: returns globally average-pooled feature maps."""

    def forward(self, x):
        feature_map = self.features(x)
        # Collapse the spatial dimensions to 1x1.
        return F.adaptive_avg_pool2d(feature_map, (1, 1))

    def n_features(self):
        """Channel count of the last module in the `features` sequential."""
        return self.features[-1].num_features
def build_densenet(download_dir, pretrained=None, arch="densenet121", model_class=NoHeadDenseNet, **kwargs):
    r"""Densenet-XXX model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        download_dir (str): Directory used to cache downloaded "mtdp" weights.
        pretrained (str|None): If "imagenet", returns a model pre-trained on ImageNet. If "mtdp" returns a model pre-trained
            in multi-task on digital pathology data. Otherwise (None), random weights.
            (Default normalized to None for consistency with `build_resnet`;
            any non-string value behaves identically.)
        arch (str): Type of densenet (among: densenet121, densenet169, densenet201 and densenet161)
        model_class (nn.Module): Actual densenet module class
        **kwargs: Forwarded to `model_class`.

    Returns:
        The constructed (and optionally pre-trained) model.

    Raises:
        ValueError: If `pretrained` is a string other than "imagenet"/"mtdp",
            or if no "mtdp" weights exist for `arch`.
        KeyError: If `arch` is not one of the supported architectures.
    """
    params = {
        "densenet121": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 24, 16)},
        "densenet169": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 32, 32)},
        "densenet201": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 48, 32)},
        "densenet161": {"num_init_features": 96, "growth_rate": 48, "block_config": (6, 12, 36, 24)}
    }
    model = model_class(**(params[arch]), **kwargs)
    if isinstance(pretrained, str):
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        if pretrained == "imagenet":
            pattern = re.compile(
                r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
            state_dict = model_zoo.load_url(densenet_urls[arch])
            for key in list(state_dict.keys()):
                res = pattern.match(key)
                if res:
                    new_key = res.group(1) + res.group(2)
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
        elif pretrained == "mtdp":
            if arch not in MTDP_URLS:
                raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch))
            url, filename = MTDP_URLS[arch]
            state_dict = load_dox_url(url, filename, model_dir=download_dir, map_location="cpu")
            # Keep only the trunk ("features.") weights; drop the task heads.
            state_dict = clean_state_dict(state_dict, prefix="features.", filter=lambda k: not k.startswith("heads."))
        else:
            raise ValueError("Unknown pre-training source")
        model.load_state_dict(state_dict)
    return model
class ResNetBottom(nn.Module):
    """Every child module of `original_model` except the last, followed by a flatten.

    Typically used to strip the classification head off a torchvision ResNet.
    """

    def __init__(self, original_model):
        super(ResNetBottom, self).__init__()
        children = list(original_model.children())
        self.features = nn.Sequential(*children[:-1])

    def forward(self, x):
        out = self.features(x)
        return torch.flatten(out, 1)
class DenseNetEmbedder:
    """Embeds images with a headless backbone model, caching the resulting arrays on disk."""

    def __init__(self, model, preprocess, name, backbone):
        self.model = model            # feature extractor; assumed to emit (B, C, 1, 1) or (B, C) — TODO confirm
        self.preprocess = preprocess  # transform applied per image by CLIPImageDataset
        self.name = name              # cache key prefix
        self.backbone = backbone      # cache namespace

    def image_embedder(self, list_of_images, device="cuda", num_workers=1, batch_size=32, additional_cache_name=""):
        """Return cached embeddings when present; otherwise compute, cache and return them.

        additional_cache_name: name of the validation dataset (e.g., Kather_7K)
        """
        hit_or_miss = cache_hit_or_miss(self.name + "img" + additional_cache_name, self.backbone)
        if hit_or_miss is not None:
            return hit_or_miss
        hit = self.embed_images(list_of_images, device=device, num_workers=num_workers, batch_size=batch_size)
        cache_numpy_object(hit, self.name + "img" + additional_cache_name, self.backbone)
        return hit

    def embed_images(self, list_of_images, device="cuda", num_workers=1, batch_size=32):
        """Compute an (N, C) numpy array of embeddings for the given images."""
        dataset = CLIPImageDataset(list_of_images, self.preprocess)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        # tqdm is an optional dependency (import falls back to None): only wrap
        # the loader when it is actually available, instead of crashing.
        iterator = dataloader if tqdm is None else tqdm(dataloader)
        all_embs = []
        for batch_X in iterator:
            batch_X = batch_X.to(device)
            embeddings = self.model(batch_X).detach().float()
            # flatten(1) keeps the batch dimension; the previous squeeze() also
            # dropped it for a final batch of size 1, breaking np.concatenate.
            embeddings = embeddings.flatten(1).cpu().numpy()
            all_embs.append(embeddings)
        return np.concatenate(all_embs)