|
a |
|
b/documentation/demo.py |
|
|
1 |
""" |
|
|
2 |
Download Merlin and test the model on sample data that is downloaded from huggingface |
|
|
3 |
""" |
|
|
4 |
|
|
|
5 |
import os |
|
|
6 |
import warnings |
|
|
7 |
import torch |
|
|
8 |
|
|
|
9 |
from merlin.data import download_sample_data |
|
|
10 |
from merlin.data import DataLoader |
|
|
11 |
from merlin import Merlin |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
warnings.filterwarnings("ignore") |
|
|
15 |
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
16 |
|
|
|
17 |
model = Merlin() |
|
|
18 |
model.eval() |
|
|
19 |
model.cuda() |
|
|
20 |
|
|
|
21 |
data_dir = os.path.join(os.path.dirname(__file__), "abct_data") |
|
|
22 |
cache_dir = data_dir.replace("abct_data", "abct_data_cache") |
|
|
23 |
|
|
|
24 |
datalist = [ |
|
|
25 |
{ |
|
|
26 |
"image": download_sample_data( |
|
|
27 |
data_dir |
|
|
28 |
), # function returns local path to nifti file |
|
|
29 |
"text": "Lower thorax: A small low-attenuating fluid structure is noted in the right cardiophrenic angle in keeping with a tiny pericardial cyst." |
|
|
30 |
"Liver and biliary tree: Normal. Gallbladder: Normal. Spleen: Normal. Pancreas: Normal. Adrenal glands: Normal. " |
|
|
31 |
"Kidneys and ureters: Symmetric enhancement and excretion of the bilateral kidneys, with no striated nephrogram to suggest pyelonephritis. " |
|
|
32 |
"Urothelial enhancement bilaterally, consistent with urinary tract infection. No renal/ureteral calculi. No hydronephrosis. " |
|
|
33 |
"Gastrointestinal tract: Normal. Normal gas-filled appendix. Peritoneal cavity: No free fluid. " |
|
|
34 |
"Bladder: Marked urothelial enhancement consistent with cystitis. Uterus and ovaries: Normal. " |
|
|
35 |
"Vasculature: Patent. Lymph nodes: Normal. Abdominal wall: Normal. " |
|
|
36 |
"Musculoskeletal: Degenerative change of the spine.", |
|
|
37 |
}, |
|
|
38 |
] |
|
|
39 |
|
|
|
40 |
dataloader = DataLoader( |
|
|
41 |
datalist=datalist, |
|
|
42 |
cache_dir=cache_dir, |
|
|
43 |
batchsize=8, |
|
|
44 |
shuffle=True, |
|
|
45 |
num_workers=0, |
|
|
46 |
) |
|
|
47 |
|
|
|
48 |
for batch in dataloader: |
|
|
49 |
outputs = model(batch["image"].to(device), batch["text"]) |
|
|
50 |
print("\n================== Output Shapes ==================") |
|
|
51 |
print(f"Contrastive image embeddings shape: {outputs[0].shape}") |
|
|
52 |
print(f"Phenotype predictions shape: {outputs[1].shape}") |
|
|
53 |
print(f"Contrastive text embeddings shape: {outputs[2].shape}") |
|
|
54 |
|
|
|
55 |
## Get the Image Embeddings |
|
|
56 |
model = Merlin(ImageEmbedding=True) |
|
|
57 |
model.eval() |
|
|
58 |
model.cuda() |
|
|
59 |
|
|
|
60 |
for batch in dataloader: |
|
|
61 |
outputs = model( |
|
|
62 |
batch["image"].to(device), |
|
|
63 |
) |
|
|
64 |
print("\n================== Output Shapes ==================") |
|
|
65 |
print( |
|
|
66 |
f"Image embeddings shape (Can be used for downstream tasks): {outputs[0].shape}" |
|
|
67 |
) |