
reproducibility/embedders/plip.py
import clip
import tqdm
import numpy as np
import torch
from reproducibility.embedders.internal_datasets import CLIPImageDataset, CLIPCaptioningDataset
from torch.utils.data import DataLoader
from reproducibility.utils.cacher import cache_hit_or_miss, cache_numpy_object, cache_hit_or_miss_raw_filename, cache_numpy_object_raw_filename


class CLIPEmbedder:
    """Embed images and text with a CLIP-style model; computed embeddings are cached to disk."""

    def __init__(self, model, preprocess, name, backbone):
        self.model = model            # CLIP-style model exposing encode_image / encode_text
        self.preprocess = preprocess  # image preprocessing transform paired with the model
        self.name = name              # identifier used when building cache keys
        self.backbone = backbone      # backbone name, also part of the cache key

    def image_embedder(self, list_of_images, device="cuda", num_workers=1, batch_size=32, additional_cache_name=""):
        # Reuse cached image embeddings when present; otherwise compute and cache them.
        hit_or_miss = cache_hit_or_miss_raw_filename(self.name + "img" + additional_cache_name, self.backbone)

        if hit_or_miss is not None:
            return hit_or_miss
        else:
            hit = self.embed_images(list_of_images, device=device, num_workers=num_workers, batch_size=batch_size)
            cache_numpy_object_raw_filename(hit, self.name + "img" + additional_cache_name, self.backbone)
            return hit

    def text_embedder(self, list_of_labels, device="cuda", num_workers=1, batch_size=32, additional_cache_name=""):
        # Reuse cached text embeddings when present; otherwise compute and cache them.
        hit_or_miss = cache_hit_or_miss(self.name + "txt" + additional_cache_name, self.backbone)

        if hit_or_miss is not None:
            return hit_or_miss
        else:
            hit = self.embed_text(list_of_labels, device=device, num_workers=num_workers, batch_size=batch_size)
            cache_numpy_object(hit, self.name + "txt" + additional_cache_name, self.backbone)
            return hit

    def embed_images(self, list_of_images, device="cuda", num_workers=1, batch_size=32):
        # Encode images in batches and return an (N, D) array of L2-normalized embeddings.
        train_dataset = CLIPImageDataset(list_of_images, self.preprocess)
        dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)

        image_embeddings = []

        # len(dataloader) counts the final partial batch, so the progress bar total is exact.
        pbar = tqdm.tqdm(total=len(dataloader), position=0)
        with torch.no_grad():
            for images in dataloader:
                images = images.to(device)
                image_embeddings.extend(self.model.encode_image(images).detach().cpu().numpy())
                pbar.update(1)
            pbar.close()

        image_embeddings = np.array(image_embeddings)
        image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
        return image_embeddings

    def embed_text(self, list_of_labels, device="cuda", num_workers=1, batch_size=32):
        # Tokenize and encode captions in batches; return an (N, D) array of L2-normalized embeddings.
        train_dataset = CLIPCaptioningDataset(list_of_labels)
        dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)

        text_embeddings = []

        # len(dataloader) counts the final partial batch, so the progress bar total is exact.
        pbar = tqdm.tqdm(total=len(dataloader), position=0)
        with torch.no_grad():
            for captions in dataloader:
                idx = clip.tokenize(captions, truncate=True).to(device)
                text_embeddings.extend(self.model.encode_text(idx).detach().cpu().numpy())
                pbar.update(1)
            pbar.close()

        text_embeddings = np.array(text_embeddings)
        text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
        return text_embeddings
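

# Illustrative usage sketch (not part of the original file): how a CLIPEmbedder might be
# constructed and queried. The backbone name "ViT-B/32", the "clip_baseline" name, and the
# image paths / captions below are placeholder assumptions; clip.load() returns the model
# together with its matching preprocessing transform.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    embedder = CLIPEmbedder(model, preprocess, name="clip_baseline", backbone="ViT-B/32")

    # Embeddings are L2-normalized, so a dot product gives cosine similarity.
    image_embs = embedder.image_embedder(["example_0.png", "example_1.png"], device=device, batch_size=2)
    text_embs = embedder.text_embedder(["an H&E image", "an IHC image"], device=device, batch_size=2)
    print(image_embs @ text_embs.T)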