[278d8a]: reproducibility/embedders/factory.py

import os

import torch
import clip
from torchvision import transforms

from reproducibility.embedders.plip import CLIPEmbedder
from reproducibility.embedders.mudipath import build_densenet, DenseNetEmbedder


class EmbedderFactory:

    def __init__(self):
        pass

    def factory(self, args):
        name = args.model_name
        path = args.backbone
        device = "cuda" if torch.cuda.is_available() else "cpu"

        if name == "plip":
            # Load the CLIP architecture named by PC_CLIP_ARCH, then overwrite its
            # weights with the fine-tuned checkpoint given by args.backbone.
            model, preprocess = clip.load(os.environ["PC_CLIP_ARCH"], device=device)
            if device == "cuda":
                model.load_state_dict(torch.load(path))
            elif device == "cpu":
                model.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
            model.eval()
            return CLIPEmbedder(model, preprocess, name, path)

        elif name == "clip":
            # Off-the-shelf CLIP weights; no extra checkpoint is loaded.
            model, preprocess = clip.load(os.environ["PC_CLIP_ARCH"], device=device)
            model.eval()
            return CLIPEmbedder(model, preprocess, name, path)

        elif name == "mudipath":
            # DenseNet backbone with "mtdp" pretrained weights, used as an image-only embedder.
            backbone = build_densenet(download_dir="/oak/stanford/groups/jamesz/pathtweets/models/",
                                      pretrained="mtdp")  # TODO: fixed path
            backbone.num_feats = backbone.n_features()
            backbone.forward_type = "image"
            backbone = backbone.to(device)
            backbone.eval()
            image_preprocess = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),  # ImageNet stats
            ])
            return DenseNetEmbedder(backbone, image_preprocess, name, path)
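
For context, a minimal usage sketch follows (not part of factory.py). It assumes an args object exposing model_name and backbone, as produced here by argparse, a PC_CLIP_ARCH environment variable such as "ViT-B/32", and a local checkpoint path; the architecture name and path below are illustrative placeholders, not values defined by this repository.

# Hypothetical usage sketch; arch name and checkpoint path are assumptions.
import argparse
import os

from reproducibility.embedders.factory import EmbedderFactory

os.environ["PC_CLIP_ARCH"] = "ViT-B/32"  # assumed CLIP architecture name

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="plip")
parser.add_argument("--backbone", default="./checkpoints/plip.pt")  # hypothetical path
args = parser.parse_args([])

embedder = EmbedderFactory().factory(args)  # returns a CLIPEmbedder for "plip"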