Switch to side-by-side view

--- /dev/null
+++ b/reproducibility/embedders/mudipath.py
@@ -0,0 +1,217 @@
+import os
+import re
+import sys
+from abc import abstractmethod
+from reproducibility.utils.cacher import cache_hit_or_miss, cache_numpy_object
+from reproducibility.embedders.internal_datasets import CLIPImageDataset
+from torch.utils.data import DataLoader
+from torch.utils import model_zoo
+from torchvision.models.resnet import ResNet, model_urls as resnet_urls, BasicBlock, Bottleneck
+from torchvision.models.densenet import DenseNet, model_urls as densenet_urls
+import torch.nn.functional as F
+import numpy as np
+from torch import nn
+
class FeaturesInterface(object):
    """Interface for feature-extracting backbones.

    Implementations report the dimensionality of the feature vector they
    produce via :meth:`n_features`.
    """

    @abstractmethod
    def n_features(self):
        """Return the number of features output by the backbone."""
        pass
+
+import torch
+from torch.hub import download_url_to_file
+
+try:
+    from requests.utils import urlparse
+    from requests import get as urlopen
+    requests_available = True
+except ImportError:
+    requests_available = False
+    from urllib.request import urlopen
+    from urllib.parse import urlparse
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is only used for progress display (see embed_images below);
    # degrade to a no-op pass-through instead of binding None, which would
    # make every `tqdm(...)` call crash when the package is missing.
    def tqdm(iterable, *args, **kwargs):
        return iterable
+
+
+def _remove_prefix(s, prefix):
+    if s.startswith(prefix):
+        s = s[len(prefix):]
+    return s
+
+
def clean_state_dict(state_dict, prefix, filter=None):
    """Strip *prefix* from every key of *state_dict*, keeping only entries
    whose ORIGINAL key passes *filter* (keep-all when *filter* is None).

    NOTE: the parameter name `filter` shadows the builtin; it is kept because
    callers pass it by keyword.
    """
    if filter is None:
        filter = lambda *_: True
    prefix_len = len(prefix)
    cleaned = {}
    for key, value in state_dict.items():
        if not filter(key):
            continue
        new_key = key[prefix_len:] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned
+
+
def load_dox_url(url, filename, model_dir=None, map_location=None, progress=True):
    r"""Download (if needed) a checkpoint to a fixed local filename and load it.

    Adapted to fit the file format of mtdp pre-trained models, whose download
    URLs do not end with the checkpoint filename.

    Args:
        url (str): Direct download URL of the checkpoint.
        filename (str): Name under which the checkpoint is cached locally.
        model_dir (str|None): Cache directory. Defaults to $TORCH_MODEL_ZOO,
            falling back to $TORCH_HOME/models (~/.torch/models).
        map_location: Forwarded to ``torch.load``.
        progress (bool): Whether to display a download progress bar.

    Returns:
        The deserialized checkpoint object.
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    # exist_ok avoids the check-then-create race when several processes
    # initialize the cache directory concurrently.
    os.makedirs(model_dir, exist_ok=True)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        sys.stderr.flush()
        download_url_to_file(url, cached_file, None, progress=progress)
    return torch.load(cached_file, map_location=map_location)
+
+
+
# Architecture name -> (download URL, cached filename) for ResNet weights
# pre-trained in multi-task on digital pathology data (used by build_resnet
# when pretrained == "mtdp").
MTDRN_URLS = {
    "resnet50": ("https://dox.uliege.be/index.php/s/kvABLtVuMxW8iJy/download", "resnet50-mh-best-191205-141200.pth")
}
+
+
class NoHeadResNet(ResNet, FeaturesInterface):
    """ResNet backbone without the classification head: forward() stops after
    global average pooling and returns the pooled feature maps (not flattened).
    """

    def forward(self, x):
        # Stem: conv -> batch-norm -> relu -> max-pool.
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        # Four residual stages, then global average pooling; the fc head is
        # intentionally skipped.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return self.avgpool(out)

    def n_features(self):
        # Channel count produced by the last block of layer4: the last child
        # module exposing `num_features` (a batch-norm layer) carries it.
        with_num_features = [m for m in self.layer4[-1].children() if hasattr(m, 'num_features')]
        return with_num_features[-1].num_features
+
+
def build_resnet(download_dir, pretrained=None, arch="resnet50", model_class=NoHeadResNet, **kwargs):
    """Construct a ResNet model of the requested depth.

    Args:
        download_dir (str|None): Directory where pre-trained weights are cached.
        pretrained (str|None): If "imagenet", returns a model pre-trained on ImageNet. If "mtdp" returns a model
                              pre-trained in multi-task on digital pathology data. Otherwise (None), random weights.
        arch (str): Type of resnet (among: resnet18, resnet34, resnet50, resnet101 and resnet152)
        model_class (type): Actual resnet module class to instantiate.

    Returns:
        The constructed (and possibly pre-initialized) model.

    Raises:
        ValueError: If *arch* or *pretrained* is unknown.
    """
    # Per-architecture (block type, per-stage block counts) for torchvision's ResNet.
    params = {
        "resnet18": [BasicBlock, [2, 2, 2, 2]],
        "resnet34": [BasicBlock, [3, 4, 6, 3]],
        "resnet50": [Bottleneck, [3, 4, 6, 3]],
        "resnet101": [Bottleneck, [3, 4, 23, 3]],
        "resnet152":  [Bottleneck, [3, 8, 36, 3]]
    }
    if arch not in params:
        raise ValueError("Unknown ResNet architecture '{}'".format(arch))
    model = model_class(*params[arch], **kwargs)
    if isinstance(pretrained, str):
        if pretrained == "imagenet":
            url = resnet_urls[arch]  # default imagenet
            state_dict = model_zoo.load_url(url)
        elif pretrained == "mtdp":
            if arch not in MTDRN_URLS:
                raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch))
            url, filename = MTDRN_URLS[arch]
            state_dict = load_dox_url(url, filename, model_dir=download_dir, map_location="cpu")
            # mtdp checkpoints store the backbone under "features." and the
            # task heads under "heads."; keep only the backbone weights.
            state_dict = clean_state_dict(state_dict, prefix="features.", filter=lambda k: not k.startswith("heads."))
        else:
            raise ValueError("Unknown pre-training source")
        model.load_state_dict(state_dict)
    return model
+
# Architecture name -> (download URL, cached filename) for DenseNet weights
# pre-trained in multi-task on digital pathology data (used by build_densenet
# when pretrained == "mtdp").
MTDP_URLS = {
    "densenet121": ("https://dox.uliege.be/index.php/s/G72InP4xmJvOrVp/download", "densenet121-mh-best-191205-141200.pth")
}
+
+
class NoHeadDenseNet(DenseNet, FeaturesInterface):
    """DenseNet backbone without the classification head: forward() returns the
    globally average-pooled feature maps (not flattened).
    """

    def forward(self, x):
        feature_maps = self.features(x)
        # Global average pooling to a 1x1 spatial grid; the classifier is
        # intentionally skipped.
        return F.adaptive_avg_pool2d(feature_maps, (1, 1))

    def n_features(self):
        # The last module of `features` (a batch-norm layer) carries the
        # output channel count.
        last_module = self.features[-1]
        return last_module.num_features
+
+
def build_densenet(download_dir, pretrained=None, arch="densenet121", model_class=NoHeadDenseNet, **kwargs):
    r"""Densenet-XXX model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        download_dir (str|None): Directory where pre-trained weights are cached.
        pretrained (str|None): If "imagenet", returns a model pre-trained on ImageNet. If "mtdp" returns a model pre-trained
                           in multi-task on digital pathology data. Otherwise (None), random weights.
        arch (str): Type of densenet (among: densenet121, densenet169, densenet201 and densenet161)
        model_class (type): Actual densenet module class to instantiate.

    Returns:
        The constructed (and possibly pre-initialized) model.

    Raises:
        ValueError: If *arch* or *pretrained* is unknown.
    """
    # Per-architecture constructor parameters for torchvision's DenseNet.
    params = {
        "densenet121": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 24, 16)},
        "densenet169": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 32, 32)},
        "densenet201": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 48, 32)},
        "densenet161": {"num_init_features": 96, "growth_rate": 48, "block_config": (6, 12, 36, 24)}
    }
    if arch not in params:
        raise ValueError("Unknown DenseNet architecture '{}'".format(arch))
    model = model_class(**(params[arch]), **kwargs)
    if isinstance(pretrained, str):
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        if pretrained == "imagenet":
            pattern = re.compile(
                r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
            state_dict = model_zoo.load_url(densenet_urls[arch])
            for key in list(state_dict.keys()):
                res = pattern.match(key)
                if res:
                    new_key = res.group(1) + res.group(2)
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
        elif pretrained == "mtdp":
            if arch not in MTDP_URLS:
                raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch))
            url, filename = MTDP_URLS[arch]
            state_dict = load_dox_url(url, filename, model_dir=download_dir, map_location="cpu")
            # mtdp checkpoints store the backbone under "features." and the
            # task heads under "heads."; keep only the backbone weights.
            state_dict = clean_state_dict(state_dict, prefix="features.", filter=lambda k: not k.startswith("heads."))
        else:
            raise ValueError("Unknown pre-training source")
        model.load_state_dict(state_dict)
    return model
+
+
class ResNetBottom(nn.Module):
    """Wrap a model, keeping every child module except the last (the head),
    and flatten the resulting features to (batch, -1).
    """

    def __init__(self, original_model):
        super(ResNetBottom, self).__init__()
        children = list(original_model.children())
        # Drop the final child (e.g. the fc classifier of a ResNet).
        self.features = nn.Sequential(*children[:-1])

    def forward(self, x):
        feats = self.features(x)
        return torch.flatten(feats, 1)
+
+
class DenseNetEmbedder:
    """Compute image embeddings with a headless CNN backbone, with on-disk
    caching keyed by embedder name, dataset name and backbone identifier.
    """

    def __init__(self, model, preprocess, name, backbone):
        # model: callable backbone returning per-image feature tensors
        #        (assumed (B, C, 1, 1) or (B, C) — TODO confirm with callers).
        # preprocess: per-image transform applied by CLIPImageDataset.
        # name / backbone: strings composing the cache key.
        self.model = model
        self.preprocess = preprocess
        self.name = name
        self.backbone = backbone

    def image_embedder(self, list_of_images, device="cuda", num_workers=1, batch_size=32, additional_cache_name=""):
        """Return cached embeddings if present; otherwise compute and cache them.

        additional_cache_name: name of the validation dataset (e.g., Kather_7K).
        """
        cache_key = self.name + "img" + additional_cache_name
        cached = cache_hit_or_miss(cache_key, self.backbone)
        if cached is not None:
            return cached
        embeddings = self.embed_images(list_of_images, device=device, num_workers=num_workers, batch_size=batch_size)
        cache_numpy_object(embeddings, cache_key, self.backbone)
        return embeddings

    def embed_images(self, list_of_images, device="cuda", num_workers=1, batch_size=32):
        """Embed all images and return a (n_images, n_features) numpy array."""
        dataset = CLIPImageDataset(list_of_images, self.preprocess)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

        all_embs = []
        for batch_X in tqdm(dataloader):
            batch_X = batch_X.to(device)
            # no_grad: inference only, no autograd bookkeeping needed.
            with torch.no_grad():
                feats = self.model(batch_X).float()
            # Flatten trailing dims but KEEP the batch dimension. The previous
            # `.squeeze()` collapsed a final batch of size 1 into a 1-D vector,
            # which broke np.concatenate over batches.
            feats = feats.reshape(feats.size(0), -1)
            all_embs.append(feats.cpu().numpy())
        return np.concatenate(all_embs)
+
+