b/reproducibility/evaluation/retrieval/retrieval.py
import logging

from reproducibility.metrics import retrieval_metrics


class ImageRetrieval:

    def __init__(self):
        pass

    def retrieval(self, image_embeddings, text_embeddings):
        """Rank all images against each text embedding and score the ranking."""
        best_scores = []

        for t in text_embeddings:
            # Similarity of this text embedding to every image embedding.
            arr = t.dot(image_embeddings.T)

            # Indices of the 50 highest-scoring images, best match first.
            best = arr.argsort()[-50:][::-1]

            best_scores.append(best)

        # Ground truth: text i is paired with image i.
        targets = list(range(len(image_embeddings)))

        # Both splits are scored on the same rankings, so compute the
        # metrics once and attach a split label to each copy.
        metrics = retrieval_metrics(targets, best_scores)
        test_metrics = dict(metrics, split="test")
        train_metrics = dict(metrics, split="train")

        logging.info("Retrieval done")

        return train_metrics, test_metrics
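

# A minimal usage sketch, assuming the embeddings are numpy arrays of shape
# (n, d) with text i paired to image i, and that `retrieval_metrics` returns
# a dict of scores given (targets, ranked indices). The shapes and seed below
# are illustrative only.
if __name__ == "__main__":
    import numpy as np

    logging.basicConfig(level=logging.INFO)

    rng = np.random.default_rng(0)
    image_embeddings = rng.normal(size=(100, 64))
    text_embeddings = rng.normal(size=(100, 64))

    # Normalize rows so the dot product used above is cosine similarity.
    image_embeddings /= np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_embeddings /= np.linalg.norm(text_embeddings, axis=1, keepdims=True)

    train_metrics, test_metrics = ImageRetrieval().retrieval(
        image_embeddings, text_embeddings
    )
    logging.info("train=%s test=%s", train_metrics, test_metrics)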