Switch to side-by-side view

--- a
+++ b/reproducibility/embedders/factory.py
@@ -0,0 +1,47 @@
+import torch
+import clip
+from reproducibility.embedders.plip import CLIPEmbedder
+from reproducibility.embedders.mudipath import build_densenet
+from torchvision import transforms
+from reproducibility.embedders.mudipath import DenseNetEmbedder
+import os
+
+
+class EmbedderFactory:
+
+    def __init__(self):
+        pass
+    
+    def factory(self, args):
+        name = args.model_name
+        path = args.backbone
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if name == "plip":
+            model, preprocess = clip.load(os.environ["PC_CLIP_ARCH"], device=device)
+            if device == 'cuda':
+                model.load_state_dict(torch.load(path))
+            elif device == 'cpu':
+                model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
+            model.eval()
+            return CLIPEmbedder(model, preprocess, name, path)
+
+        elif name == "clip":
+            model, preprocess = clip.load(os.environ["PC_CLIP_ARCH"], device=device)
+            model.eval()
+            return CLIPEmbedder(model, preprocess, name, path)
+
+        elif name == "mudipath":
+            backbone = build_densenet(download_dir="/oak/stanford/groups/jamesz/pathtweets/models/",
+                                      pretrained="mtdp")  # TODO fixed path
+            backbone.num_feats = backbone.n_features()
+            backbone.forward_type = "image"
+            backbone = backbone.to(device)
+            backbone.eval()
+            image_preprocess = transforms.Compose([
+                transforms.Resize(224),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
+            ])
+            return DenseNetEmbedder(backbone, image_preprocess, name, path)
\ No newline at end of file