a b/biovil_t/pretrained.py
1
#  -------------------------------------------------------------------------------------------
2
#  Copyright (c) Microsoft Corporation. All rights reserved.
3
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
#  -------------------------------------------------------------------------------------------
5
6
from __future__ import annotations
7
8
import tempfile
9
from pathlib import Path
10
11
from torchvision.datasets.utils import download_url
12
13
from .model import ImageModel
14
from .types import ImageEncoderType
15
16
17
# Value passed as ``joint_feature_size`` to ``ImageModel`` by the factory functions in this module.
JOINT_FEATURE_SIZE = 128

# Hugging Face host and the repository ids that hold the pretrained checkpoints.
HF_URL = "https://huggingface.co"
BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
BIOMED_VLP_BIOVIL_T = "microsoft/BiomedVLP-BioViL-T"

# Repository revisions (tags) the download URLs are pinned to.
CXR_BERT_COMMIT_TAG = "v1.1"
BIOVIL_T_COMMIT_TAG = "v1.0"

# BioViL image checkpoint: file name, download URL and expected MD5 digest.
BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
BIOVIL_IMAGE_WEIGHTS_URL = "/".join(
    (
        HF_URL,
        BIOMED_VLP_CXR_BERT_SPECIALIZED,
        "resolve",
        CXR_BERT_COMMIT_TAG,
        BIOVIL_IMAGE_WEIGHTS_NAME,
    )
)
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"

# BioViL-T image checkpoint: file name, download URL and expected MD5 digest.
BIOVIL_T_IMAGE_WEIGHTS_NAME = "biovil_t_image_model_proj_size_128.pt"
BIOVIL_T_IMAGE_WEIGHTS_URL = "/".join(
    (
        HF_URL,
        BIOMED_VLP_BIOVIL_T,
        "resolve",
        BIOVIL_T_COMMIT_TAG,
        BIOVIL_T_IMAGE_WEIGHTS_NAME,
    )
)
BIOVIL_T_IMAGE_WEIGHTS_MD5 = "a83080e2f23aa584a4f2b24c39b1bb64"
33
34
35
def _download_biovil_image_model_weights() -> Path:
    """Fetch the BioViL image model checkpoint from Hugging Face.

    The file is written to the system temporary directory under the name
    ``BIOVIL_IMAGE_WEIGHTS_NAME``. The expected MD5 digest is handed to
    ``download_url`` together with the URL.

    Model page: https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized.

    :return: Path of the checkpoint file inside the temporary directory.
    """
    download_dir = tempfile.gettempdir()
    weights_path = Path(download_dir, BIOVIL_IMAGE_WEIGHTS_NAME)

    download_url(
        BIOVIL_IMAGE_WEIGHTS_URL,
        root=download_dir,
        filename=BIOVIL_IMAGE_WEIGHTS_NAME,
        md5=BIOVIL_IMAGE_WEIGHTS_MD5,
    )
    return weights_path
48
49
50
def _download_biovil_t_image_model_weights() -> Path:
    """Download the BioViL-T image model weights from Hugging Face.

    The checkpoint is saved in the system temporary directory as
    ``BIOVIL_T_IMAGE_WEIGHTS_NAME``. The expected MD5 digest is passed to
    ``download_url`` along with the URL.

    More information available at https://huggingface.co/microsoft/BiomedVLP-BioViL-T.

    :return: Path of the checkpoint file inside the temporary directory.
    """
    root_dir = tempfile.gettempdir()
    download_url(
        BIOVIL_T_IMAGE_WEIGHTS_URL,
        root=root_dir,
        filename=BIOVIL_T_IMAGE_WEIGHTS_NAME,
        md5=BIOVIL_T_IMAGE_WEIGHTS_MD5,
    )
    return Path(root_dir, BIOVIL_T_IMAGE_WEIGHTS_NAME)
63
64
65
def get_biovil_image_encoder(pretrained: bool = True) -> ImageModel:
    """Build the BioViL image model, optionally with pretrained weights.

    :param pretrained: If ``True``, the checkpoint is downloaded from Hugging Face
        and its path is passed to ``ImageModel``. Otherwise ``None`` is passed as
        ``pretrained_model_path``.
    :return: An ``ImageModel`` configured with ``ImageEncoderType.RESNET50`` and
        ``JOINT_FEATURE_SIZE``.
    """
    checkpoint_path = None
    if pretrained:
        checkpoint_path = _download_biovil_image_model_weights()

    return ImageModel(
        img_encoder_type=ImageEncoderType.RESNET50,
        joint_feature_size=JOINT_FEATURE_SIZE,
        pretrained_model_path=checkpoint_path,
    )
75
76
77
def get_biovil_t_image_encoder() -> ImageModel:
    """Build the BioViL-T image model from the checkpoint hosted on Hugging Face.

    The weights are always downloaded first; the resulting file path is passed to
    ``ImageModel`` as ``pretrained_model_path``.

    :return: An ``ImageModel`` configured with
        ``ImageEncoderType.RESNET50_MULTI_IMAGE`` and ``JOINT_FEATURE_SIZE``.
    """
    checkpoint_path = _download_biovil_t_image_model_weights()
    return ImageModel(
        img_encoder_type=ImageEncoderType.RESNET50_MULTI_IMAGE,
        joint_feature_size=JOINT_FEATURE_SIZE,
        pretrained_model_path=checkpoint_path,
    )