Download this file

29 lines (20 with data), 863 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
from reproducibility.metrics import eval_metrics
import logging
class ZeroShotClassifier:
    """Label images by comparing image embeddings against text embeddings.

    No fitting takes place: every image receives the label whose text
    embedding yields the highest dot-product score against it.
    """

    def __init__(self):
        """Create the classifier; it keeps no state."""
        pass

    def predict(self, image_embeddings, text_embeddings, unique_labels):
        """Return the best-scoring label for every image embedding.

        Args:
            image_embeddings: Array-like exposing ``.dot``; presumably shaped
                (n_images, dim) -- TODO confirm against the caller.
            text_embeddings: Array-like exposing ``.T``; presumably shaped
                (n_labels, dim), one row per entry of ``unique_labels`` --
                TODO confirm against the caller.
            unique_labels: Indexable collection of labels, addressed by the
                position of the highest score in each row.

        Returns:
            A list with one label per row of the score matrix.
        """
        similarity = image_embeddings.dot(text_embeddings.T)
        return [unique_labels[np.argmax(row)] for row in similarity]

    def zero_shot_classification(self, image_embeddings, text_embeddings,
                                 unique_labels, target_labels, dump_path=None):
        """Predict labels and score them against ``target_labels``.

        Earlier revisions unconditionally pickled the predictions to
        ``pickle.pkl`` and then called ``exit()``, which terminated the
        process before the metrics could be returned. Dumping is now opt-in
        and the method always returns.

        Args:
            image_embeddings: Passed through to ``predict``.
            text_embeddings: Passed through to ``predict``.
            unique_labels: Passed through to ``predict``.
            target_labels: Ground-truth labels handed to ``eval_metrics``.
            dump_path: Optional file path. When given, the targets and
                predictions are pickled there for offline inspection.

        Returns:
            A ``(train_metrics, test_metrics)`` tuple. Both are produced by
            ``eval_metrics`` from the same targets and predictions and differ
            only in the value stored under their ``"split"`` key.
        """
        predictions = self.predict(image_embeddings, text_embeddings, unique_labels)

        # eval_metrics is called once per split so that each result is an
        # independent object before it is tagged with its split name.
        test_metrics = eval_metrics(target_labels, predictions)
        train_metrics = eval_metrics(target_labels, predictions)
        test_metrics["split"] = "test"
        train_metrics["split"] = "train"

        if dump_path is not None:
            # Imported lazily: only needed for the optional debugging dump.
            import pickle

            with open(dump_path, "wb") as f:
                pickle.dump({"target": target_labels,
                             "predictions": predictions}, f)

        logging.info("ZeroShot Done")
        return train_metrics, test_metrics