[3ff14f]: / example.py

Download this file

11 lines (8 with data), 223 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
import torch
from medpalm.model import MedPalm
def main() -> None:
    """Run one MedPalm forward pass on random inputs and print the output shape.

    This is a usage / smoke-test example: the image and caption tensors are
    random placeholders, not real data.
    """
    # Random stand-in image tensor of shape (1, 3, 256, 256); presumably
    # (batch, channels, height, width) — confirm against MedPalm's expectations.
    img = torch.randn(1, 3, 256, 256)

    # Random integer token ids in [0, 20000), shape (1, 1024); presumably
    # (batch, seq_len).
    # NOTE(review): 20000 is assumed to fit within MedPalm's vocabulary size —
    # confirm against the model's default configuration.
    caption = torch.randint(0, 20000, (1, 1024))

    model = MedPalm()

    # Only the output shape is inspected, so skip autograd graph construction.
    # NOTE(review): if MedPalm is a torch.nn.Module with dropout, model.eval()
    # would also be appropriate for inference — confirm.
    with torch.no_grad():
        output = model(img, caption)

    # The original example documents this as (1, 1024, 20000); the exact shape
    # depends on MedPalm's configuration — confirm against the model.
    print(output.shape)


if __name__ == "__main__":
    main()