|
a |
|
b/reproducibility/embedders/factory.py |
|
|
1 |
import torch |
|
|
2 |
import clip |
|
|
3 |
from reproducibility.embedders.plip import CLIPEmbedder |
|
|
4 |
from reproducibility.embedders.mudipath import build_densenet |
|
|
5 |
from torchvision import transforms |
|
|
6 |
from reproducibility.embedders.mudipath import DenseNetEmbedder |
|
|
7 |
import os |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class EmbedderFactory:
    """Build the embedder selected by a parsed argument namespace.

    Supported ``args.model_name`` values are listed in ``SUPPORTED_MODELS``.
    """

    SUPPORTED_MODELS = ("plip", "clip", "mudipath")

    # Directory passed to build_densenet as ``download_dir`` for the MuDiPath
    # ("mtdp") DenseNet. The original cluster path is kept as the default; set
    # the PC_MUDIPATH_DIR environment variable to use a different location.
    DEFAULT_MUDIPATH_DIR = "/oak/stanford/groups/jamesz/pathtweets/models/"

    def factory(self, args):
        """Construct the embedder selected by ``args.model_name``.

        Args:
            args: Object exposing ``model_name`` (one of ``SUPPORTED_MODELS``)
                and ``backbone`` (a checkpoint path). The checkpoint is only
                loaded for ``"plip"``; for the other models the path is just
                forwarded to the embedder constructor.

        Returns:
            A ``CLIPEmbedder`` for ``"plip"`` / ``"clip"``, or a
            ``DenseNetEmbedder`` for ``"mudipath"``.

        Raises:
            ValueError: If ``args.model_name`` is not a supported model.
            KeyError: If a CLIP-based model is requested and the
                ``PC_CLIP_ARCH`` environment variable is not set.
        """
        name = args.model_name
        path = args.backbone

        # Validate before touching torch/CUDA so a typo fails fast and clearly
        # instead of silently returning None.
        if name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unknown model_name {name!r}; expected one of "
                f"{', '.join(self.SUPPORTED_MODELS)}"
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"

        if name == "plip":
            return self._build_plip(name, path, device)
        if name == "clip":
            return self._build_clip(name, path, device)
        return self._build_mudipath(name, path, device)

    @staticmethod
    def _load_clip(device):
        """Load the CLIP model/preprocess pair named by ``PC_CLIP_ARCH``."""
        arch = os.environ.get("PC_CLIP_ARCH")
        if arch is None:
            raise KeyError(
                "PC_CLIP_ARCH environment variable must be set to the CLIP "
                "architecture passed to clip.load"
            )
        return clip.load(arch, device=device)

    def _build_plip(self, name, path, device):
        """Load CLIP, overwrite its weights from ``path``, and wrap it."""
        model, preprocess = self._load_clip(device)
        # map_location=device covers both the CPU-only case (remapping tensors
        # saved on CUDA) and the CUDA case, replacing separate cuda/cpu branches.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        return CLIPEmbedder(model, preprocess, name, path)

    def _build_clip(self, name, path, device):
        """Wrap the stock CLIP weights; ``path`` is forwarded but not loaded."""
        model, preprocess = self._load_clip(device)
        model.eval()
        return CLIPEmbedder(model, preprocess, name, path)

    def _build_mudipath(self, name, path, device):
        """Build the MuDiPath DenseNet backbone and its image preprocessing."""
        download_dir = os.environ.get("PC_MUDIPATH_DIR", self.DEFAULT_MUDIPATH_DIR)
        backbone = build_densenet(download_dir=download_dir, pretrained="mtdp")
        backbone.num_feats = backbone.n_features()
        backbone.forward_type = "image"
        backbone = backbone.to(device)
        backbone.eval()
        image_preprocess = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
        ])
        # ``path`` is not used to load weights here; it is only forwarded.
        return DenseNetEmbedder(backbone, image_preprocess, name, path)