Switch to unified view

a b/reproducibility/evaluation/retrieval/retrieval.py
1
from reproducibility.metrics import retrieval_metrics
2
import logging
3
4
class ImageRetrieval:
5
6
    def __init__(self):
7
        pass
8
9
    def retrieval(self, image_embeddings, text_embeddings):
10
11
        best_scores = []
12
13
        for t in text_embeddings:
14
            arr = t.dot(image_embeddings.T)
15
16
            best = arr.argsort()[-50:][::-1]
17
18
            best_scores.append(best)
19
20
        targets = list(range(0, len(image_embeddings)))
21
22
        test_metrics = retrieval_metrics(targets, best_scores)
23
        train_metrics = retrieval_metrics(targets, best_scores)
24
25
        test_metrics["split"] = "test"
26
        train_metrics["split"] = "train"
27
28
        logging.info(f"Retrieval Done")
29
30
        return train_metrics, test_metrics