|
a |
|
b/biovil_t/pretrained.py |
|
|
1 |
# ------------------------------------------------------------------------------------------- |
|
|
2 |
# Copyright (c) Microsoft Corporation. All rights reserved. |
|
|
3 |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. |
|
|
4 |
# ------------------------------------------------------------------------------------------- |
|
|
5 |
|
|
|
6 |
from __future__ import annotations |
|
|
7 |
|
|
|
8 |
import tempfile |
|
|
9 |
from pathlib import Path |
|
|
10 |
|
|
|
11 |
from torchvision.datasets.utils import download_url |
|
|
12 |
|
|
|
13 |
from .model import ImageModel |
|
|
14 |
from .types import ImageEncoderType |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
# Value passed as ``joint_feature_size`` when building the image models in this module.
JOINT_FEATURE_SIZE = 128

# Hugging Face repository identifiers and the hub base URL.
BIOMED_VLP_CXR_BERT_SPECIALIZED = "microsoft/BiomedVLP-CXR-BERT-specialized"
BIOMED_VLP_BIOVIL_T = "microsoft/BiomedVLP-BioViL-T"
HF_URL = "https://huggingface.co"

# Revision tags inserted into the ``resolve`` URLs below.
CXR_BERT_COMMIT_TAG = "v1.1"
BIOVIL_T_COMMIT_TAG = "v1.0"

# BioViL image weights: file name, download URL and expected MD5 checksum.
BIOVIL_IMAGE_WEIGHTS_NAME = "biovil_image_resnet50_proj_size_128.pt"
BIOVIL_IMAGE_WEIGHTS_URL = (
    f"{HF_URL}/{BIOMED_VLP_CXR_BERT_SPECIALIZED}"
    f"/resolve/{CXR_BERT_COMMIT_TAG}/{BIOVIL_IMAGE_WEIGHTS_NAME}"
)
BIOVIL_IMAGE_WEIGHTS_MD5 = "02ce6ee460f72efd599295f440dbb453"

# BioViL-T image weights: file name, download URL and expected MD5 checksum.
BIOVIL_T_IMAGE_WEIGHTS_NAME = "biovil_t_image_model_proj_size_128.pt"
BIOVIL_T_IMAGE_WEIGHTS_URL = (
    f"{HF_URL}/{BIOMED_VLP_BIOVIL_T}"
    f"/resolve/{BIOVIL_T_COMMIT_TAG}/{BIOVIL_T_IMAGE_WEIGHTS_NAME}"
)
BIOVIL_T_IMAGE_WEIGHTS_MD5 = "a83080e2f23aa584a4f2b24c39b1bb64"
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def _download_biovil_image_model_weights() -> Path:
    """Fetch the BioViL image model checkpoint from Hugging Face.

    The file is requested from ``BIOVIL_IMAGE_WEIGHTS_URL`` and stored in the
    system temporary directory under ``BIOVIL_IMAGE_WEIGHTS_NAME``; the expected
    MD5 checksum is handed to ``download_url``.

    See https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized for details.

    :return: Path of the checkpoint file inside the temporary directory.
    """
    download_dir = tempfile.gettempdir()
    download_url(
        BIOVIL_IMAGE_WEIGHTS_URL,
        root=download_dir,
        filename=BIOVIL_IMAGE_WEIGHTS_NAME,
        md5=BIOVIL_IMAGE_WEIGHTS_MD5,
    )
    checkpoint_path = Path(download_dir, BIOVIL_IMAGE_WEIGHTS_NAME)
    return checkpoint_path
|
|
48 |
|
|
|
49 |
|
|
|
50 |
def _download_biovil_t_image_model_weights() -> Path:
    """Download the BioViL-T image model weights from Hugging Face.

    The file is requested from ``BIOVIL_T_IMAGE_WEIGHTS_URL`` and stored in the
    system temporary directory under ``BIOVIL_T_IMAGE_WEIGHTS_NAME``; the
    expected MD5 checksum is handed to ``download_url``.

    More information available at https://huggingface.co/microsoft/BiomedVLP-BioViL-T.

    :return: Path of the checkpoint file inside the temporary directory.
    """
    root_dir = tempfile.gettempdir()
    download_url(
        BIOVIL_T_IMAGE_WEIGHTS_URL,
        root=root_dir,
        filename=BIOVIL_T_IMAGE_WEIGHTS_NAME,
        md5=BIOVIL_T_IMAGE_WEIGHTS_MD5,
    )
    return Path(root_dir, BIOVIL_T_IMAGE_WEIGHTS_NAME)
|
|
63 |
|
|
|
64 |
|
|
|
65 |
def get_biovil_image_encoder(pretrained: bool = True) -> ImageModel:
    """Instantiate the BioViL image model with a ResNet-50 encoder.

    :param pretrained: If true, the checkpoint is downloaded from Hugging Face
        and its path is passed to the model; otherwise ``None`` is passed as
        the checkpoint path.
    :return: The constructed ``ImageModel``.
    """
    checkpoint_path = None
    if pretrained:
        checkpoint_path = _download_biovil_image_model_weights()

    return ImageModel(
        img_encoder_type=ImageEncoderType.RESNET50,
        joint_feature_size=JOINT_FEATURE_SIZE,
        pretrained_model_path=checkpoint_path,
    )
|
|
75 |
|
|
|
76 |
|
|
|
77 |
def get_biovil_t_image_encoder() -> ImageModel:
    """Instantiate the BioViL-T image model with its downloaded weights.

    The checkpoint is fetched from Hugging Face first, then passed to an
    ``ImageModel`` configured with the ``RESNET50_MULTI_IMAGE`` encoder type.

    :return: The constructed ``ImageModel``.
    """
    # Download before touching the model so the call order matches the checkpoint dependency.
    weights_path = _download_biovil_t_image_model_weights()
    return ImageModel(
        img_encoder_type=ImageEncoderType.RESNET50_MULTI_IMAGE,
        joint_feature_size=JOINT_FEATURE_SIZE,
        pretrained_model_path=weights_path,
    )