Diff of /modules/beam_search.py [000000] .. [03245f]

Switch to unified view

a b/modules/beam_search.py
1
# numpy import
2
import numpy as np
3
4
# tensorflow imports
5
import tensorflow as tf
6
import tensorflow
7
from tensorflow.keras.preprocessing.text import Tokenizer
8
from tensorflow.keras.models import Model
9
10
11
class BeamSearch:
12
    def __init__(self, start_token:str, end_token:str, max_length:int, tokenizer:Tokenizer, idx_to_word:dict, word_to_idx:dict, beam_index:int):
13
        """ The Beam Search sampling method for generating captions. An illustration of the algorithm is provided in my Thesis paper.
14
15
        Args:
16
            start_token (str): The start-token used during pre-processing of the training captions
17
            end_token (str): The end-token used during pre-processing of the training captions
18
            max_length (int): The maximum length (limit) for the generated captions
19
            tokenizer (Tokenizer): The fitted tokenizer from the Vocabulary object
20
            idx_to_word (dict): Dictionary with keys to be the index number and values the words in the created vocabulary
21
            word_to_idx (dict): Dictionary with keys to be the words and values the index number in the created vocabulary  
22
            beam_index (int): The beam size for the Beam Seach algorithm.
23
        """
24
        self.start_token = start_token
25
        self.end_token = end_token
26
        self.max_length = max_length
27
        self.tokenizer = tokenizer
28
        self.idx_to_word = idx_to_word
29
        self.word_to_idx = word_to_idx
30
        self.beam_index = beam_index
31
32
    def get_word(self, idx:int) -> str:
33
        """ Fetches the word from the index-to-word vocab, which was created after the pre-processing of the Training captions
34
35
        Args:
36
            idx (int): The index for the index-to-word vocab.
37
38
        Returns:
39
            str: The word for the given index if exist in the created index-to-word vocab, else None
40
        """
41
        return self.idx_to_word.get(idx, None)
42
43
    def get_idx(self, word:str)->int:
44
        """ Fetches the index number from the word-to-index vocab, which was created after the pre-processing of the Training captions
45
46
        Args:
47
            word (str): The word for which we want its index in the word-to-index dictionary.
48
49
        Returns:
50
            int: The index for the given word if exist in the created word-to-index vocab, else -1. The latter number refer to None
51
        """
52
        return self.word_to_idx.get(word, -1)
53
54
    def beam_search_predict(self, model:Model, image:np.array, tag:np.array, dataset:str='iuxray', multi_modal:bool=False)->str:
55
        """ Executes the beam search algorithm employing the pre-trained model along with the test instance's data.
56
57
        Args:
58
            model (Model): The model we want to evaluate on our employed dataset
59
            image (np.array): Current test image embedding
60
            tag (np.array): The tag embedding for the current test instance. This is used only for IU X-Ray dataset.
61
            dataset (str, optional): The dataset we employed for the model. Defaults to 'iuxray'.
62
            multi_modal (bool, optional): If we want to use the multi-modal version of model. This is used only for IU X-Ray dataset. Defaults to False.
63
64
        Returns:
65
            str: The generated description for the given image using the beam search.
66
        """
67
        start = [self.get_idx(self.start_token)]
68
        start_word = [[start, 0.0]]
69
        while len(start_word[0][0]) < self.max_length:
70
            # store current word,probs pairs
71
            temp = []
72
            # for current sequence
73
            for s in start_word:
74
                # pad the sequence in order to fetch the next token
75
                par_caps = tf.keras.preprocessing.sequence.pad_sequences([s[0]], maxlen=self.max_length, padding="post")
76
                if multi_modal:
77
                    if dataset == 'iuxray':
78
                        preds = model.predict(
79
                            [image[0], image[1], tag, par_caps], verbose=0)
80
                    else:
81
                        preds = model.predict(
82
                            [image, tag, par_caps], verbose=0)
83
                else:
84
                    if dataset == 'iuxray':
85
                        preds = model.predict(
86
                            [image[0], image[1], par_caps], verbose=0)
87
                    else:
88
                        preds = model.predict([image, par_caps], verbose=0)
89
                # get the best paths
90
                word_preds = np.argsort(preds[0])[-self.beam_index:]
91
92
                # Getting the top <self.self.beam_index>(n) predictions and creating a
93
                # new list so as to put them via the model again
94
                for w in word_preds:
95
                    next_cap, prob = s[0][:], s[1]
96
                    next_cap.append(w)
97
                    prob += preds[0][w]
98
                    temp.append([next_cap, prob])
99
100
            start_word = temp
101
            # Sorting according to the probabilities
102
            start_word = sorted(start_word, reverse=False, key=lambda l: l[1])
103
            # Getting the top words
104
            start_word = start_word[-self.beam_index:]
105
106
        # get the best path
107
        start_word = start_word[-1][0]
108
        intermediate_caption = [self.get_word(i) for i in start_word]
109
        final_caption = []
110
111
        for i in intermediate_caption:
112
            if i != self.end_token:
113
                final_caption.append(i)
114
            else:
115
                break
116
117
        final_caption = " ".join(final_caption[1:])
118
        return final_caption
119
120
    def ensemble_beam_search(self, models:list, images_list:list)->str:
121
        """ Executes the beam search algorithm employing the pre-trained models along with the test instances data.
122
        This utilises the beam search algorithm for each model in models list.
123
124
        Args:
125
            models (list): The models we want to evaluate on our employed dataset
126
            images_list (list): Current test images embeddings for each encoder we used.
127
128
        Returns:
129
            str: The generated description for the given image using the beam search.
130
        """
131
        start = [self.get_idx(self.start_token)]
132
        start_word = [[start, 0.0]]
133
        while len(start_word[0][0]) < self.max_length:
134
            # for current seq
135
            for s in start_word:
136
                # pad current caption
137
                current_caption = tf.keras.preprocessing.sequence.pad_sequences([s[0]], maxlen=self.max_length, padding="post")
138
                # get all predictions from the pre-trained models
139
                ensemble_predictions = [ ensemble_member.predict([image, current_caption], verbose=0) for ensemble_member, image in zip(models, images_list) ]
140
                # get the best pairs
141
                ensemble_word_predictions = [ np.argsort(prediction[0])[-self.beam_index:] for prediction in ensemble_predictions ]
142
143
                # and store them with word,pairs for each ensemble member
144
                ensemble_current_probs = list()
145
                for member_prediction, member_word_predictions in zip(ensemble_predictions, ensemble_word_predictions):
146
                    temp_current_seq = list()
147
148
                    for word in member_word_predictions:
149
                        next_cap, prob = s[0][:], s[1]
150
                        next_cap.append(word)
151
                        prob += member_prediction[0][word]
152
                        temp_current_seq.append([next_cap, prob])
153
154
                    ensemble_current_probs.append(temp_current_seq)
155
            # get all the best candidates
156
            ensemble_starting_words = [
157
                sorted(current_probs, reverse=False, key=lambda l: l[1])[-self.beam_index:] for current_probs in ensemble_current_probs
158
            ]
159
160
            start_words = [member_starting_words[-1] for member_starting_words in ensemble_starting_words]
161
            start_word = sorted(start_words, reverse=False, key=lambda l: l[1])
162
163
        # create the caption
164
        start_word = start_word[-1][0]
165
        intermediate_caption = [self.get_word(i) for i in start_word]
166
        final_caption = []
167
168
        for i in intermediate_caption:
169
            if i != self.end_token:
170
                final_caption.append(i)
171
            else:
172
                break
173
174
        final_caption = " ".join(final_caption[1:])
175
        return final_caption