--- /dev/null
+++ b/plip.py
@@ -0,0 +1,109 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+from typing import List, Union
+from torch.utils.data import DataLoader
+import PIL
+from transformers import CLIPModel, CLIPProcessor
+from datasets import Dataset, Image
+
+
+class PLIP:
+
+    def __init__(self, model_name, auth_token=None):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model_name = model_name
+        self.model, self.preprocess = self._load_model(model_name, auth_token=auth_token)
+        self.model = self.model.to(self.device)
+        # Image corpus embeddings used by `retrieval`; populated by the caller.
+        self.image_vectors = None
+
+    def _load_model(self, name: str, auth_token=None):
+        model = CLIPModel.from_pretrained(name, use_auth_token=auth_token)
+        preprocessing = CLIPProcessor.from_pretrained(name, use_auth_token=auth_token)
+        return model, preprocessing
+
+    def encode_images(self, images: Union[List[str], List[PIL.Image.Image]], batch_size: int):
+        def transform_fn(el):
+            # `el['image']` holds PIL images directly, or undecoded
+            # {'path': ..., 'bytes': ...} dicts when file paths were passed in.
+            if isinstance(el['image'][0], PIL.Image.Image):
+                imgs = el['image']
+            else:
+                imgs = [Image().decode_example(img) for img in el['image']]
+            return self.preprocess(images=imgs, return_tensors='pt')
+
+        dataset = Dataset.from_dict({'image': images})
+        if isinstance(images[0], str):
+            dataset = dataset.cast_column('image', Image(decode=False))
+        dataset.set_transform(transform_fn)
+        dataloader = DataLoader(dataset, batch_size=batch_size)
+
+        image_embeddings = []
+        with torch.no_grad():
+            for batch in tqdm(dataloader, position=0):
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                image_embeddings.extend(self.model.get_image_features(**batch).detach().cpu().numpy())
+        return np.stack(image_embeddings)
+
+    def encode_text(self, text: List[str], batch_size: int):
+        dataset = Dataset.from_dict({'text': text})
+        dataset = dataset.map(lambda el: self.preprocess(text=el['text'], return_tensors="pt",
+                                                         max_length=77, padding="max_length",
+                                                         truncation=True),
+                              batched=True,
+                              remove_columns=['text'])
+        dataset.set_format('torch')
+        dataloader = DataLoader(dataset, batch_size=batch_size)
+
+        text_embeddings = []
+        with torch.no_grad():
+            for batch in tqdm(dataloader, position=0):
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+                text_embeddings.extend(self.model.get_text_features(**batch).detach().cpu().numpy())
+        return np.stack(text_embeddings)
+
+    def _cosine_similarity(self, key_vectors: np.ndarray, space_vectors: np.ndarray, normalize=True):
+        # Normalize both sides so the dot product is a true cosine similarity.
+        if normalize:
+            key_vectors = key_vectors / np.linalg.norm(key_vectors, ord=2, axis=-1, keepdims=True)
+            space_vectors = space_vectors / np.linalg.norm(space_vectors, ord=2, axis=-1, keepdims=True)
+        return np.matmul(key_vectors, space_vectors.T)
+
+    def _nearest_neighbours(self, k, key_vectors, space_vectors, normalize=True):
+        key_vectors = np.asarray(key_vectors)
+        space_vectors = np.asarray(space_vectors)
+        cosine_sim = self._cosine_similarity(key_vectors, space_vectors, normalize=normalize)
+        # Indices of the k most similar space vectors per key, best match first.
+        return cosine_sim.argsort()[:, -k:][:, ::-1]
+
+    def zero_shot_classification(self, images, text_labels: List[str], debug=False):
+        """
+        Perform zero-shot image classification: assign each image the label
+        whose text embedding is most similar to its image embedding.
+        :return: list of predicted labels, one per image
+        """
+        # encode text
+        text_vectors = self.encode_text(text_labels, batch_size=8)
+        # encode images
+        image_vectors = self.encode_images(images, batch_size=8)
+        # compute cosine similarity
+        cosine_sim = self._cosine_similarity(image_vectors, text_vectors)
+        if debug:
+            print(cosine_sim)
+        preds = np.argmax(cosine_sim, axis=-1)
+        return [text_labels[idx] for idx in preds]
+
+    def retrieval(self, queries: List[str], top_k: int = 10):
+        """
+        Retrieve, for each text query, the indices of the `top_k` most similar images.
+        `self.image_vectors` must be populated first, e.g. with the output of `encode_images`.
+        :return: array of shape (len(queries), top_k) with image indices
+        """
+        if self.image_vectors is None:
+            raise ValueError("No image vectors available; assign the output of "
+                             "encode_images to self.image_vectors first.")
+        # encode text
+        text_vectors = self.encode_text(queries, batch_size=8)
+        return self._nearest_neighbours(k=top_k, key_vectors=text_vectors,
+                                        space_vectors=self.image_vectors)
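For reference, a minimal usage sketch of the class this patch adds (not part of the patch itself). The checkpoint id "vinid/plip", the image file names, and the prompt strings are placeholders; substitute whichever model, images, and labels you are actually using:

    from plip import PLIP

    plip = PLIP("vinid/plip")  # placeholder checkpoint id

    images = ["patch_1.png", "patch_2.png"]  # hypothetical sample files
    labels = ["an H&E image of tumor tissue", "an H&E image of normal tissue"]

    # Zero-shot classification: one predicted label per image.
    print(plip.zero_shot_classification(images, labels))

    # Retrieval: embed the image corpus once, then rank it against text queries.
    plip.image_vectors = plip.encode_images(images, batch_size=8)
    print(plip.retrieval(["an H&E image of tumor tissue"], top_k=1))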