Diff of /modules/beam_search.py [000000] .. [03245f]

--- a
+++ b/modules/beam_search.py
@@ -0,0 +1,204 @@
+# numpy import
+import numpy as np
+
+# tensorflow imports
+import tensorflow as tf
+from tensorflow.keras.preprocessing.text import Tokenizer
+from tensorflow.keras.models import Model
+
+
+class BeamSearch:
+    def __init__(self, start_token: str, end_token: str, max_length: int, tokenizer: Tokenizer, idx_to_word: dict, word_to_idx: dict, beam_index: int):
+        """ The Beam Search sampling method for generating captions. An illustration of the algorithm is provided in my thesis.
+
+        Args:
+            start_token (str): The start token used during pre-processing of the training captions
+            end_token (str): The end token used during pre-processing of the training captions
+            max_length (int): The maximum length (limit) for the generated captions
+            tokenizer (Tokenizer): The fitted tokenizer from the Vocabulary object
+            idx_to_word (dict): Dictionary mapping each index number to its word in the created vocabulary
+            word_to_idx (dict): Dictionary mapping each word to its index number in the created vocabulary
+            beam_index (int): The beam size for the Beam Search algorithm.
+        """
+        self.start_token = start_token
+        self.end_token = end_token
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.idx_to_word = idx_to_word
+        self.word_to_idx = word_to_idx
+        self.beam_index = beam_index
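+
+    # How the beam evolves (an illustrative, made-up example): with
+    # beam_index = 2, each kept sequence is extended with its 2 most probable
+    # next words, every candidate is scored by the sum of the
+    # log-probabilities of its words, and only the 2 best-scoring candidates
+    # survive to the next step, e.g.
+    #   step 1: [start] -> [start, the] (-0.4), [start, a] (-0.9)
+    #   step 2: the 4 extended candidates are re-scored and the 2 best kept,
+    #           until max_length tokens have been produced.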
+
+    def get_word(self, idx: int) -> str:
+        """ Fetches the word from the index-to-word vocab, which was created during the pre-processing of the training captions.
+
+        Args:
+            idx (int): The index to look up in the index-to-word vocab.
+
+        Returns:
+            str: The word for the given index if it exists in the created index-to-word vocab, else None
+        """
+        return self.idx_to_word.get(idx, None)
+
+    def get_idx(self, word: str) -> int:
+        """ Fetches the index number from the word-to-index vocab, which was created during the pre-processing of the training captions.
+
+        Args:
+            word (str): The word whose index we want from the word-to-index dictionary.
+
+        Returns:
+            int: The index for the given word if it exists in the created word-to-index vocab, else -1. The latter value denotes an out-of-vocabulary word.
+        """
+        return self.word_to_idx.get(word, -1)
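+
+    # For illustration, with a hypothetical vocabulary where index 42 maps to
+    # the word "effusion":
+    #   self.get_word(42)        -> "effusion"
+    #   self.get_idx("effusion") -> 42
+    #   self.get_idx("no-such")  -> -1  (out-of-vocabulary)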
+
+    def beam_search_predict(self, model: Model, image: np.ndarray, tag: np.ndarray, dataset: str = 'iuxray', multi_modal: bool = False) -> str:
+        """ Executes the beam search algorithm employing the pre-trained model along with the test instance's data.
+
+        Args:
+            model (Model): The model we want to evaluate on our employed dataset
+            image (np.ndarray): Current test image embedding. For IU X-Ray this holds the two view embeddings.
+            tag (np.ndarray): The tag embedding for the current test instance. This is used only for the IU X-Ray dataset.
+            dataset (str, optional): The dataset we employed for the model. Defaults to 'iuxray'.
+            multi_modal (bool, optional): Whether to use the multi-modal version of the model. This is used only for the IU X-Ray dataset. Defaults to False.
+
+        Returns:
+            str: The generated description for the given image using beam search.
+        """
+        start = [self.get_idx(self.start_token)]
+        start_word = [[start, 0.0]]
+        while len(start_word[0][0]) < self.max_length:
+            # store the candidate (sequence, score) pairs of this step
+            temp = []
+            # for current sequence
+            for s in start_word:
+                # pad the sequence in order to fetch the next token
+                par_caps = tf.keras.preprocessing.sequence.pad_sequences([s[0]], maxlen=self.max_length, padding="post")
+                if multi_modal:
+                    if dataset == 'iuxray':
+                        preds = model.predict(
+                            [image[0], image[1], tag, par_caps], verbose=0)
+                    else:
+                        preds = model.predict(
+                            [image, tag, par_caps], verbose=0)
+                else:
+                    if dataset == 'iuxray':
+                        preds = model.predict(
+                            [image[0], image[1], par_caps], verbose=0)
+                    else:
+                        preds = model.predict([image, par_caps], verbose=0)
+                # get the indices of the <beam_index> most probable next words
+                word_preds = np.argsort(preds[0])[-self.beam_index:]
+
+                # Getting the top <self.beam_index> (n) predictions and creating a
+                # new list so as to feed them through the model again
+                for w in word_preds:
+                    next_cap, prob = s[0][:], s[1]
+                    next_cap.append(w)
+                    # accumulate log-probabilities, so that the beam score
+                    # corresponds to the product of the word probabilities
+                    prob += np.log(preds[0][w])
+                    temp.append([next_cap, prob])
+
+            start_word = temp
+            # Sorting the candidates by their (ascending) scores
+            start_word = sorted(start_word, reverse=False, key=lambda l: l[1])
+            # Keeping the <beam_index> best candidates
+            start_word = start_word[-self.beam_index:]
+
+        # get the best path
+        start_word = start_word[-1][0]
+        intermediate_caption = [self.get_word(i) for i in start_word]
+        final_caption = []
+
+        for i in intermediate_caption:
+            if i != self.end_token:
+                final_caption.append(i)
+            else:
+                break
+
+        final_caption = " ".join(final_caption[1:])
+        return final_caption
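+
+    # Hypothetical usage sketch (the names and tokens below are placeholders
+    # for objects produced by the pre-processing and training steps, not part
+    # of this module). For a multi-modal IU X-Ray model, `image` carries the
+    # two view embeddings of the exam:
+    #   searcher = BeamSearch('startseq', 'endseq', max_length, tokenizer,
+    #                         idx_to_word, word_to_idx, beam_index=3)
+    #   caption = searcher.beam_search_predict(
+    #       model, [frontal_emb, lateral_emb], tag_emb,
+    #       dataset='iuxray', multi_modal=True)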
+
+    def ensemble_beam_search(self, models: list, images_list: list) -> str:
+        """ Executes the beam search algorithm employing the pre-trained models along with the test instance's data.
+        Every model in the models list contributes its candidate words to a shared beam.
+
+        Args:
+            models (list): The models we want to evaluate on our employed dataset
+            images_list (list): Current test image embeddings, one for each encoder we used.
+
+        Returns:
+            str: The generated description for the given image using beam search.
+        """
+        start = [self.get_idx(self.start_token)]
+        start_word = [[start, 0.0]]
+        while len(start_word[0][0]) < self.max_length:
+            # store the candidate (sequence, score) pairs proposed by every member
+            temp = []
+            # for current sequence
+            for s in start_word:
+                # pad current caption
+                current_caption = tf.keras.preprocessing.sequence.pad_sequences([s[0]], maxlen=self.max_length, padding="post")
+                # get all predictions from the pre-trained models, pairing each
+                # model with the image embedding of its own encoder
+                ensemble_predictions = [ ensemble_member.predict([image, current_caption], verbose=0) for ensemble_member, image in zip(models, images_list) ]
+                # get the <beam_index> most probable next-word indices from each member
+                ensemble_word_predictions = [ np.argsort(prediction[0])[-self.beam_index:] for prediction in ensemble_predictions ]
+
+                # extend the current sequence with every proposed word and score it
+                for member_prediction, member_word_predictions in zip(ensemble_predictions, ensemble_word_predictions):
+                    for word in member_word_predictions:
+                        next_cap, prob = s[0][:], s[1]
+                        next_cap.append(word)
+                        # accumulate log-probabilities, as in beam_search_predict
+                        prob += np.log(member_prediction[0][word])
+                        temp.append([next_cap, prob])
+
+            # keep the <beam_index> best-scoring candidates across all members
+            start_word = sorted(temp, reverse=False, key=lambda l: l[1])[-self.beam_index:]
+
+        # create the caption
+        start_word = start_word[-1][0]
+        intermediate_caption = [self.get_word(i) for i in start_word]
+        final_caption = []
+
+        for i in intermediate_caption:
+            if i != self.end_token:
+                final_caption.append(i)
+            else:
+                break
+
+        final_caption = " ".join(final_caption[1:])
+        return final_caption
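+
+# Hypothetical usage sketch for the ensemble variant (all names below are
+# placeholders): each model is paired positionally with the test image
+# embedding produced by its own encoder.
+#   searcher = BeamSearch('startseq', 'endseq', max_length, tokenizer,
+#                         idx_to_word, word_to_idx, beam_index=3)
+#   caption = searcher.ensemble_beam_search(
+#       models=[densenet_model, efficientnet_model],
+#       images_list=[densenet_emb, efficientnet_emb])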