--- a +++ b/documentation/demo.py @@ -0,0 +1,67 @@ +""" +Download Merlin and test the model on sample data that is downloaded from huggingface +""" + +import os +import warnings +import torch + +from merlin.data import download_sample_data +from merlin.data import DataLoader +from merlin import Merlin + + +warnings.filterwarnings("ignore") +device = "cuda" if torch.cuda.is_available() else "cpu" + +model = Merlin() +model.eval() +model.cuda() + +data_dir = os.path.join(os.path.dirname(__file__), "abct_data") +cache_dir = data_dir.replace("abct_data", "abct_data_cache") + +datalist = [ + { + "image": download_sample_data( + data_dir + ), # function returns local path to nifti file + "text": "Lower thorax: A small low-attenuating fluid structure is noted in the right cardiophrenic angle in keeping with a tiny pericardial cyst." + "Liver and biliary tree: Normal. Gallbladder: Normal. Spleen: Normal. Pancreas: Normal. Adrenal glands: Normal. " + "Kidneys and ureters: Symmetric enhancement and excretion of the bilateral kidneys, with no striated nephrogram to suggest pyelonephritis. " + "Urothelial enhancement bilaterally, consistent with urinary tract infection. No renal/ureteral calculi. No hydronephrosis. " + "Gastrointestinal tract: Normal. Normal gas-filled appendix. Peritoneal cavity: No free fluid. " + "Bladder: Marked urothelial enhancement consistent with cystitis. Uterus and ovaries: Normal. " + "Vasculature: Patent. Lymph nodes: Normal. Abdominal wall: Normal. " + "Musculoskeletal: Degenerative change of the spine.", + }, +] + +dataloader = DataLoader( + datalist=datalist, + cache_dir=cache_dir, + batchsize=8, + shuffle=True, + num_workers=0, +) + +for batch in dataloader: + outputs = model(batch["image"].to(device), batch["text"]) + print("\n================== Output Shapes ==================") + print(f"Contrastive image embeddings shape: {outputs[0].shape}") + print(f"Phenotype predictions shape: {outputs[1].shape}") + print(f"Contrastive text embeddings shape: {outputs[2].shape}") + +## Get the Image Embeddings +model = Merlin(ImageEmbedding=True) +model.eval() +model.cuda() + +for batch in dataloader: + outputs = model( + batch["image"].to(device), + ) + print("\n================== Output Shapes ==================") + print( + f"Image embeddings shape (Can be used for downstream tasks): {outputs[0].shape}" + )