reproducibility/embedders/plip.py
import clip
import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader

from reproducibility.embedders.internal_datasets import CLIPImageDataset, CLIPCaptioningDataset
from reproducibility.utils.cacher import (
    cache_hit_or_miss,
    cache_numpy_object,
    cache_hit_or_miss_raw_filename,
    cache_numpy_object_raw_filename,
)


class CLIPEmbedder:
    """Wraps a CLIP model and returns cached, L2-normalized image and text embeddings."""

    def __init__(self, model, preprocess, name, backbone):
        self.model = model
        self.preprocess = preprocess
        self.name = name
        self.backbone = backbone

    def image_embedder(self, list_of_images, device="cuda", num_workers=1, batch_size=32, additional_cache_name=""):
        # Return cached image embeddings if present; otherwise compute and cache them.
        hit_or_miss = cache_hit_or_miss_raw_filename(self.name + "img" + additional_cache_name, self.backbone)

        if hit_or_miss is not None:
            return hit_or_miss

        hit = self.embed_images(list_of_images, device=device, num_workers=num_workers, batch_size=batch_size)
        cache_numpy_object_raw_filename(hit, self.name + "img" + additional_cache_name, self.backbone)
        return hit

    def text_embedder(self, list_of_labels, device="cuda", num_workers=1, batch_size=32, additional_cache_name=""):
        # Return cached text embeddings if present; otherwise compute and cache them.
        hit_or_miss = cache_hit_or_miss(self.name + "txt" + additional_cache_name, self.backbone)

        if hit_or_miss is not None:
            return hit_or_miss

        hit = self.embed_text(list_of_labels, device=device, num_workers=num_workers, batch_size=batch_size)
        cache_numpy_object(hit, self.name + "txt" + additional_cache_name, self.backbone)
        return hit

    def embed_images(self, list_of_images, device="cuda", num_workers=1, batch_size=32):
        dataset = CLIPImageDataset(list_of_images, self.preprocess)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

        image_embeddings = []

        # len(dataloader) includes the final partial batch, unlike len(list_of_images) // batch_size.
        pbar = tqdm.tqdm(total=len(dataloader), position=0)
        with torch.no_grad():
            for images in dataloader:
                images = images.to(device)
                image_embeddings.extend(self.model.encode_image(images).detach().cpu().numpy())
                pbar.update(1)
            pbar.close()

        # L2-normalize so that cosine similarity reduces to a dot product.
        image_embeddings = np.array(image_embeddings)
        image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
        return image_embeddings

    def embed_text(self, list_of_labels, device="cuda", num_workers=1, batch_size=32):
        dataset = CLIPCaptioningDataset(list_of_labels)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

        text_embeddings = []

        pbar = tqdm.tqdm(total=len(dataloader), position=0)
        with torch.no_grad():
            for captions in dataloader:
                # Tokenize each batch of captions, truncating to CLIP's context length.
                idx = clip.tokenize(captions, truncate=True).to(device)
                text_embeddings.extend(self.model.encode_text(idx).detach().cpu().numpy())
                pbar.update(1)
            pbar.close()

        # L2-normalize so that cosine similarity reduces to a dot product.
        text_embeddings = np.array(text_embeddings)
        text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
        return text_embeddings
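Usage sketch (not part of the original file): the snippet below shows how CLIPEmbedder might be driven end to end, assuming the standard OpenAI clip.load API and that CLIPImageDataset accepts image file paths. The backbone string, file names, and caption are illustrative placeholders, not values taken from this repository.

import clip

from reproducibility.embedders.plip import CLIPEmbedder

# Load a CLIP checkpoint (the backbone name is an assumption for illustration).
model, preprocess = clip.load("ViT-B/32", device="cuda")
model.eval()

embedder = CLIPEmbedder(model, preprocess, name="clip_baseline", backbone="ViT-B/32")

# Hypothetical inputs; real paths and captions depend on the dataset being reproduced.
image_vecs = embedder.image_embedder(["patch_001.png", "patch_002.png"], device="cuda", batch_size=2)
text_vecs = embedder.text_embedder(["an H&E patch showing tumor tissue"], device="cuda", batch_size=1)

# Both outputs are L2-normalized, so cosine similarity is a plain dot product.
scores = image_vecs @ text_vecs.T
print(scores.shape)  # (2, 1)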