a b/reproducibility/embedders/factory.py
1
import torch
2
import clip
3
from reproducibility.embedders.plip import CLIPEmbedder
4
from reproducibility.embedders.mudipath import build_densenet
5
from torchvision import transforms
6
from reproducibility.embedders.mudipath import DenseNetEmbedder
7
import os
8
9
10
class EmbedderFactory:
    """Builds image/text embedders from parsed CLI arguments.

    Supports three backbones: a fine-tuned PLIP checkpoint, vanilla CLIP,
    and a DenseNet pretrained via mtdp ("mudipath").
    """

    def factory(self, args):
        """Construct and return the embedder selected by ``args.model_name``.

        Args:
            args: namespace with ``model_name`` (one of ``"plip"``,
                ``"clip"``, ``"mudipath"``) and ``backbone`` (checkpoint
                path; only loaded for ``"plip"``).

        Returns:
            A ``CLIPEmbedder`` or ``DenseNetEmbedder`` in eval mode on the
            available device.

        Raises:
            ValueError: if ``args.model_name`` is not recognized (the
                original silently returned ``None`` here, crashing later
                at the caller).
            KeyError: if the ``PC_CLIP_ARCH`` environment variable is not
                set (for the CLIP-based models).
        """
        name = args.model_name
        path = args.backbone
        device = "cuda" if torch.cuda.is_available() else "cpu"

        if name in ("plip", "clip"):
            # Both models share the same CLIP architecture; "plip" additionally
            # loads fine-tuned weights from the provided checkpoint path.
            model, preprocess = clip.load(os.environ["PC_CLIP_ARCH"], device=device)
            if name == "plip":
                # map_location covers both the cuda and cpu cases in one call,
                # replacing the duplicated device-specific branches.
                model.load_state_dict(torch.load(path, map_location=torch.device(device)))
            model.eval()
            return CLIPEmbedder(model, preprocess, name, path)

        if name == "mudipath":
            backbone = build_densenet(download_dir="/oak/stanford/groups/jamesz/pathtweets/models/",
                                      pretrained="mtdp")  # TODO fixed path
            backbone.num_feats = backbone.n_features()
            backbone.forward_type = "image"
            backbone = backbone.to(device)
            backbone.eval()
            image_preprocess = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
            ])
            return DenseNetEmbedder(backbone, image_preprocess, name, path)

        # Fail fast on typos instead of returning None.
        raise ValueError(f"Unknown model name: {name!r}")