Diff of /modules/cluster.py [000000] .. [03245f]

Switch to unified view

a b/modules/cluster.py
1
# import sys
2
import sys
3
sys.path.append('..')
4
5
# progress bar import
6
from tqdm import tqdm
7
8
# numpy, sklearn imports
9
import numpy as np
10
from sklearn.cluster import KMeans
11
from sklearn.decomposition import PCA
12
from sklearn.metrics import silhouette_samples, silhouette_score
13
from tensorflow.keras.preprocessing.image import load_img
14
15
# utils imports
16
from models import *
17
from utils.dataset import ImageCLEFDataset
18
19
20
class Cluster:
21
22
    def __init__(self, K:int, clef_dataset:ImageCLEFDataset):
23
        """ Class to perform K-Means clustering in ImageCLEF dataset. We used this system for the contest
24
25
        Args:
26
            K (int): The K clusters we want
27
            clef_dataset (ImageCLEFDataset): The dataset we employed. Only CLEF is acceptable.
28
        """
29
        self.K = K
30
        self.dataset = clef_dataset
31
32
    def do_PCA(self, features:dict) -> np.array:
33
        """ Perforrms Principal Component Analysis (PCA), to reduce the huge size of the arrays
34
35
        Args:
36
            features (dict): The image_ids, image_vectors pairs.
37
38
        Returns:
39
            np.array: The image_ids, image_vectors pairs, with reduced size.
40
        """
41
42
        feat = np.array(list(features.values()))
43
        feat = feat.reshape(-1, feat.shape[2])
44
45
        pca = PCA(n_components=100, random_state=22)
46
        pca.fit(feat)
47
        x = pca.transform(feat)
48
        return x
49
50
    def do_Kmeans(self, x:np.array) -> KMeans:
51
        """ Fit the K-Means
52
53
        Args:
54
            x (np.array): The image vectors
55
56
        Returns:
57
            KMeans: The fitted K-Means object
58
        """
59
        kmeans = KMeans(n_clusters=self.K, random_state=22)
60
        kmeans.fit(x)
61
        return kmeans
62
63
    def load_features(self) ->tuple[list[dict], list[dict], list[dict]]:
64
        """ Loads train, validation, test sets 
65
66
        Returns:
67
            tuple[list[dict], list[dict], list[dict]]: The train, validation, test sets in dictionary format
68
        """
69
        return self.dataset.get_splits_sets()
70
71
    def clustering(self) -> tuple[dict, dict, dict]:
72
        """ Performs the k-Means clustering using the fitted K-Means object.
73
74
        Returns:
75
            tuple[dict, dict, dict]: The clustered train, val, test image_ids, image_vectors pairs.
76
        """
77
        # load splits
78
        train_features, valid_features, test_features = self.load_features()
79
        # get the ids for each split
80
        train_ids, val_ids, test_ids = list(train_features[0].keys()), list(valid_features[0].keys()), list(test_features[0].keys())
81
82
        # concate all features to perform a more efficient K-Means
83
        all_features = dict(train_features, **valid_features)
84
        all_features = dict(all_features, **test_features)
85
86
        # reduce size for fast training
87
        pca = self.do_PCA(all_features)
88
        # perform clustering
89
        kmeans = self.do_Kmeans(pca)
90
        
91
        train_index_limit, val_index_limit = len(train_features), len(train_features)+len(valid_features)
92
        # get the clustering labels for each set
93
        train_k_means_labels = kmeans.labels_[:train_index_limit]
94
        valid_k_means_labels = kmeans.labels_[train_index_limit:val_index_limit]
95
        test_k_means_labels = kmeans.labels_[val_index_limit:]
96
        
97
        
98
        print('# train kmeans:',  len(train_k_means_labels))
99
        print('# dev kmeans:',  len(valid_k_means_labels))
100
        print('# test kmeans:',  len(test_k_means_labels))
101
        
102
103
        # store the clustered train, validation, test set images
104
        groups_train = {}
105
        for file, cluster in tqdm(zip(train_ids, train_k_means_labels)):
106
            if cluster not in groups_train.keys():
107
                groups_train[cluster] = []
108
                groups_train[cluster].append(file)
109
            else:
110
                groups_train[cluster].append(file)
111
112
        groups_valid = {}
113
        for file, cluster in tqdm(zip(val_ids, valid_k_means_labels)):
114
            if cluster not in groups_valid.keys():
115
                groups_valid[cluster] = []
116
                groups_valid[cluster].append(file)
117
            else:
118
                groups_valid[cluster].append(file)
119
        
120
        groups_test = {}
121
        for file, cluster in tqdm(zip(test_ids, test_k_means_labels)):
122
            if cluster not in groups_test.keys():
123
                groups_test[cluster] = []
124
                groups_test[cluster].append(file)
125
            else:
126
                groups_test[cluster].append(file)
127
128
        print('# train kmeans:',  len(groups_train))
129
        print('# dev kmeans:',  len(groups_valid))
130
        print('# test kmeans:',  len(groups_test))
131
        return groups_train, groups_valid, groups_test
132
133
134