Diff of /modules/image_encoder.py [000000] .. [03245f]

# os imports
import os

# numpy, progress bar and pickle imports
import numpy as np
from tqdm import tqdm
import pickle

# tensorflow imports
import tensorflow
# CNN image encoders
from tensorflow.keras.applications.densenet import DenseNet201 as dn201
from tensorflow.keras.applications.densenet import DenseNet121 as dn121
from tensorflow.keras.applications.densenet import DenseNet169 as dn169
from tensorflow.keras.applications.efficientnet import EfficientNetB0
from tensorflow.keras.applications.efficientnet import EfficientNetB5 as enb5
from tensorflow.keras.applications.efficientnet import EfficientNetB7 as enb7
from tensorflow.keras.applications.resnet_v2 import ResNet50V2 as rn50v2
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2 as incres
# preprocessing functions
from tensorflow.keras.applications.densenet import preprocess_input as dense_preprocess
from tensorflow.keras.applications.efficientnet import preprocess_input as efficient_preprocess
from tensorflow.keras.applications.resnet_v2 import preprocess_input as resnet_preprocess
from tensorflow.keras.applications.inception_resnet_v2 import preprocess_input as inceptionresnet_preprocess

# layers imports
from tensorflow.keras.layers import Flatten, Dropout, Dense, Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image as img

# cotnet import (uncomment the line below to enable the 'cotnet' encoder)
# from keras_cv_attention_models import cotnet


def load_encoded_vecs(filename:str) -> dict:
    """ Loads the image embeddings (one vector per image id) that we extracted offline during our research

    Args:
        filename (str): the full path of the pickle file

    Returns:
        dict: the encoded vectors loaded from filename
    """
    with open(filename, 'rb') as f:
        print("Image Encoded Vectors loaded from directory path:", filename)
        return pickle.load(f)


def save_encoded_vecs(image_vecs:dict, output_path:str, filename:str) -> None:
    """ Saves the encoded image vectors into a pickle file

    Args:
        image_vecs (dict): the encoded image vectors we extracted using the encode_images function
        output_path (str): the output directory in which we want to save our image embeddings (it must end with a path separator)
        filename (str): the name we want to use for the pickle file (e.g. densenet201_image_vecs). It is not necessary to append '.pkl' at the end!
    """
    path = output_path + filename + '.pkl'
    with open(path, 'wb') as f:
        pickle.dump(image_vecs, f, pickle.HIGHEST_PROTOCOL)
    print("Image Encoded Vectors stored in:", path)


class ImageEncoder:

    def __init__(self, encoder:str, images_dir_path:str, weights:str='imagenet'):
        """ This class helps us extract image embeddings with different Keras CNN encoders.

        Args:
            encoder (str): the name of the encoder you want to use (e.g. densenet201 for DenseNet201)
            images_dir_path (str): the directory that contains the images we want to encode
            weights (str, optional): the pretrained weights to load into the model. Defaults to 'imagenet', which is the most common choice.
        """
        self.encoder_weights = weights
        self.image_dir_path = images_dir_path

        # we take the output of the last average pooling layer of each encoder

        if encoder == 'densenet201':
            self.image_shape = 224
            self.preprocess = 'densenet'
            model = dn201(include_top=True, weights=self.encoder_weights,
                          input_shape=(self.image_shape, self.image_shape, 3))
            self.image_encoder = Model(
                inputs=model.input, outputs=model.get_layer('avg_pool').output)

        elif encoder == 'densenet121':
            self.image_shape = 224
            self.preprocess = 'densenet'
            model = dn121(include_top=True, weights=self.encoder_weights,
                          input_shape=(self.image_shape, self.image_shape, 3))
            self.image_encoder = Model(
                inputs=model.input, outputs=model.get_layer('avg_pool').output)

        elif encoder == 'densenet169':
            self.image_shape = 224
            self.preprocess = 'densenet'
            model = dn169(include_top=True, weights=self.encoder_weights,
                          input_shape=(self.image_shape, self.image_shape, 3))
            self.image_encoder = Model(
                inputs=model.input, outputs=model.get_layer('avg_pool').output)

        elif encoder == 'efficientnet5':
            self.image_shape = 456
            self.preprocess = 'efficientnet'
            model = enb5(include_top=True, weights=self.encoder_weights,
                         input_shape=(self.image_shape, self.image_shape, 3))
            self.image_encoder = Model(
                inputs=model.input, outputs=model.get_layer('avg_pool').output)

        elif encoder == 'efficientnet0':
            self.image_shape = 224
            self.preprocess = 'efficientnet'
            model = EfficientNetB0(include_top=True, weights=self.encoder_weights,
                                   input_shape=(self.image_shape, self.image_shape, 3))
            self.image_encoder = Model(
                inputs=model.input, outputs=model.get_layer('avg_pool').output)

        elif encoder == 'resnet50v2':
            self.image_shape = 224
            self.preprocess = 'resnet'
            model = rn50v2(include_top=True, weights=self.encoder_weights,
                           input_shape=(self.image_shape, self.image_shape, 3))
            self.image_encoder = Model(
                inputs=model.input, outputs=model.get_layer('avg_pool').output)

        elif encoder == 'inceptionresnet':
            self.image_shape = 299
            self.preprocess = 'inceptionresnet'
            model = incres(include_top=True, weights=self.encoder_weights,
                           input_shape=(self.image_shape, self.image_shape, 3))
            self.image_encoder = Model(
                inputs=model.input, outputs=model.get_layer('avg_pool').output)

        elif encoder == 'cotnet':
            self.image_shape = 224
            self.preprocess = 'cotnet'
            # NOTE: this branch requires the keras_cv_attention_models import at the top of the file to be uncommented
            model = cotnet.CotNet50(pretrained="imagenet", num_classes=0)
            self.image_encoder = Model(
                inputs=model.input, outputs=model.output)

        else:
            print("You have to initialize a valid image encoder\n"
                  "Choices are: [densenet201, densenet121, densenet169, efficientnet0, efficientnet5, resnet50v2, inceptionresnet, cotnet]")
            print("Exiting...")
            return

    def get_preprocessor(self) -> str:
        """ Gets the name of the pre-processing function

        Returns:
            str: The pre-processing name we initialized
        """
        return self.preprocess

    def get_image_shape(self) -> int:
        """ Gets the input shape

        Returns:
            int: The input image shape for the employed encoder
        """
        return self.image_shape

    def get_image_encoder(self) -> Model:
        """ Gets the image encoder we built

        Returns:
            Model: The CNN encoder
        """
        return self.image_encoder

    def get_images_dirpath(self) -> str:
        """ Gets the directory path that contains the images we want to encode

        Returns:
            str: The image directory path
        """
        return self.image_dir_path

    def encode(self, _image:str, verbose:int=0) -> np.array:
        """ Loads an image and passes it through the CNN encoder to extract its image embeddings.

        Args:
            _image (str): The id of the image we want to encode
            verbose (int, optional): Whether to display the extraction progress. Defaults to 0.

        Returns:
            np.array: The encoded version of the given image
        """
        # case CoTNet
        if self.get_preprocessor() == 'cotnet':
            image = img.load_img(self.image_dir_path + _image + '.jpg')
            image_array = img.img_to_array(image)

            imm = tensorflow.keras.applications.imagenet_utils.preprocess_input(image_array, mode='torch')
            image_encoded = self.image_encoder(
                tensorflow.expand_dims(tensorflow.image.resize(imm, self.image_encoder.input_shape[1:3]), 0)).numpy()
        else:
            # case other encoders
            # load the image and convert it to np.array
            image = img.load_img(self.image_dir_path + _image + '.jpg',
                                 target_size=(self.image_shape, self.image_shape))
            image_array = img.img_to_array(image)
            image_array = np.expand_dims(image_array, axis=0)
            # pre-process the array so that it fits the employed encoder
            if self.get_preprocessor() == 'densenet':
                preprocessed_image_array = dense_preprocess(image_array)
            elif self.get_preprocessor() == 'efficientnet':
                preprocessed_image_array = efficient_preprocess(image_array)
            elif self.get_preprocessor() == 'resnet':
                preprocessed_image_array = resnet_preprocess(image_array)
            elif self.get_preprocessor() == 'inceptionresnet':
                preprocessed_image_array = inceptionresnet_preprocess(
                    image_array)
            # extract image embeddings
            image_encoded = self.image_encoder.predict(preprocessed_image_array, verbose=verbose)
        return image_encoded

    def encode_images(self, images:list) -> dict:
        """ Takes a list of image ids and extracts the image embeddings for each one

        Args:
            images (list): Image IDs list

        Returns:
            dict: A dictionary that maps each image id to its image vector
        """
        image_vecs = {_image: self.encode(_image) for _image in
                      tqdm(images, desc="Encoding images", position=0, leave=True)}
        return image_vecs
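

# A minimal usage sketch, kept behind a __main__ guard so that importing this
# module stays side-effect free. The images directory, image ids and output
# path below are illustrative placeholders, not values used elsewhere in the
# codebase.
if __name__ == "__main__":
    image_encoder = ImageEncoder(encoder='densenet201',
                                 images_dir_path='data/images/',  # hypothetical folder with <image_id>.jpg files
                                 weights='imagenet')
    image_ids = ['img_001', 'img_002']  # hypothetical image ids (without the .jpg suffix)
    image_vecs = image_encoder.encode_images(image_ids)
    save_encoded_vecs(image_vecs, 'data/', 'densenet201_image_vecs')
    restored_vecs = load_encoded_vecs('data/densenet201_image_vecs.pkl')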