|
a |
|
b/modules/beam_search.py |
|
|
1 |
# numpy import |
|
|
2 |
import numpy as np |
|
|
3 |
|
|
|
4 |
# tensorflow imports |
|
|
5 |
import tensorflow as tf |
|
|
6 |
import tensorflow |
|
|
7 |
from tensorflow.keras.preprocessing.text import Tokenizer |
|
|
8 |
from tensorflow.keras.models import Model |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
class BeamSearch:
    """Caption decoder that runs beam search over one model or an ensemble.

    The instance stores the special tokens, the caption length limit, the
    vocabulary look-up tables and the beam size; the decoding methods take the
    model(s) and the per-instance inputs as arguments.
    """

    def __init__(self, start_token: str, end_token: str, max_length: int, tokenizer: "Tokenizer", idx_to_word: dict, word_to_idx: dict, beam_index: int):
        """ The Beam Search sampling method for generating captions.

        Args:
            start_token (str): The start-token used during pre-processing of the training captions
            end_token (str): The end-token used during pre-processing of the training captions
            max_length (int): The maximum length (limit) for the generated captions
            tokenizer (Tokenizer): The fitted tokenizer from the Vocabulary object. It is only stored;
                no method of this class reads it.
            idx_to_word (dict): Dictionary with keys to be the index number and values the words in the created vocabulary
            word_to_idx (dict): Dictionary with keys to be the words and values the index number in the created vocabulary
            beam_index (int): The beam size for the Beam Search algorithm.
        """
        self.start_token = start_token
        self.end_token = end_token
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.idx_to_word = idx_to_word
        self.word_to_idx = word_to_idx
        self.beam_index = beam_index

    def get_word(self, idx: int) -> str:
        """ Fetches the word for an index from the index-to-word vocab.

        Args:
            idx (int): The index for the index-to-word vocab.

        Returns:
            str: The word for the given index if it exists in the index-to-word vocab, else None
        """
        return self.idx_to_word.get(idx, None)

    def get_idx(self, word: str) -> int:
        """ Fetches the index number for a word from the word-to-index vocab.

        Args:
            word (str): The word for which we want its index in the word-to-index dictionary.

        Returns:
            int: The index for the given word if it exists in the word-to-index vocab, else -1.
        """
        return self.word_to_idx.get(word, -1)

    @staticmethod
    def _model_inputs(image, tag, padded_caption, dataset: str, multi_modal: bool) -> list:
        """ Builds the input list handed to ``model.predict`` for one decoding step.

        For ``dataset == 'iuxray'`` the two entries ``image[0]`` and ``image[1]`` are passed
        separately, otherwise ``image`` is passed as a single input. The tag input is inserted
        before the caption only when ``multi_modal`` is true.
        """
        inputs = [image[0], image[1]] if dataset == 'iuxray' else [image]
        if multi_modal:
            inputs.append(tag)
        inputs.append(padded_caption)
        return inputs

    def _decode_caption(self, sequence) -> str:
        """ Turns a sequence of vocabulary indices into the final caption string.

        The first position holds the start token and is never emitted. Decoding stops at the
        first end-token. Indices without an entry in the index-to-word vocab (``get_word``
        returns None, e.g. for a padding index) are skipped, because ``" ".join`` would raise
        a TypeError on a None element.
        """
        words = []
        for position, idx in enumerate(sequence):
            word = self.get_word(idx)
            if word == self.end_token:
                break
            if position == 0 or word is None:
                continue
            words.append(word)
        return " ".join(words)

    def beam_search_predict(self, model: "Model", image: np.ndarray, tag: np.ndarray, dataset: str = 'iuxray', multi_modal: bool = False) -> str:
        """ Executes the beam search algorithm employing the pre-trained model along with the test instance's data.

        Args:
            model (Model): The model we want to evaluate on our employed dataset
            image (np.ndarray): Current test image embedding. For 'iuxray' it is indexed as
                ``image[0]`` and ``image[1]`` and both entries are fed to the model.
            tag (np.ndarray): The tag embedding for the current test instance. It is passed to the
                model only when ``multi_modal`` is True.
            dataset (str, optional): The dataset we employed for the model. Defaults to 'iuxray'.
            multi_modal (bool, optional): If we want to use the multi-modal version of model. Defaults to False.

        Returns:
            str: The generated description for the given image using the beam search.
        """
        # Every hypothesis is a [token_indices, accumulated_score] pair.
        beams = [[[self.get_idx(self.start_token)], 0.0]]
        while len(beams[0][0]) < self.max_length:
            candidates = []
            for sequence, score in beams:
                # pad the sequence in order to fetch the next token
                padded = tf.keras.preprocessing.sequence.pad_sequences([sequence], maxlen=self.max_length, padding="post")
                preds = model.predict(self._model_inputs(image, tag, padded, dataset, multi_modal), verbose=0)
                # extend the hypothesis with each of the top <beam_index> next tokens
                # NOTE(review): the raw model outputs are summed (not log-probabilities) and
                # hypotheses that already produced the end-token keep being expanded; kept as-is.
                for w in np.argsort(preds[0])[-self.beam_index:]:
                    candidates.append([sequence + [w], score + preds[0][w]])

            # ascending sort, so the best hypotheses sit at the end of the list
            beams = sorted(candidates, reverse=False, key=lambda pair: pair[1])[-self.beam_index:]

        # the best path is the last one after the ascending sort
        return self._decode_caption(beams[-1][0])

    def ensemble_beam_search(self, models: list, images_list: list) -> str:
        """ Executes the search employing the pre-trained models along with the test instance's data.

        At every step each ensemble member proposes its top <beam_index> next tokens for the
        current best hypothesis; the best candidate of every member is kept.

        Args:
            models (list): The models we want to evaluate on our employed dataset
            images_list (list): Current test images embeddings for each encoder we used,
                paired with ``models`` by position.

        Returns:
            str: The generated description for the given image.
        """
        beams = [[[self.get_idx(self.start_token)], 0.0]]
        while len(beams[0][0]) < self.max_length:
            # Only the last (highest-scoring) hypothesis is expanded. The previous implementation
            # looped over every hypothesis but rebuilt the candidate list each time, so the
            # candidates of all but the last one were discarded after running every model on them.
            # NOTE(review): the result is therefore a greedy walk that follows the most confident
            # member at each step; pooling candidates across hypotheses would change outputs.
            sequence, score = beams[-1]
            padded = tf.keras.preprocessing.sequence.pad_sequences([sequence], maxlen=self.max_length, padding="post")

            member_best = []
            for ensemble_member, image in zip(models, images_list):
                prediction = ensemble_member.predict([image, padded], verbose=0)
                top_words = np.argsort(prediction[0])[-self.beam_index:]
                member_candidates = [[sequence + [word], score + prediction[0][word]] for word in top_words]
                ranked = sorted(member_candidates, reverse=False, key=lambda pair: pair[1])[-self.beam_index:]
                # keep this member's best candidate
                member_best.append(ranked[-1])

            beams = sorted(member_best, reverse=False, key=lambda pair: pair[1])

        # create the caption from the best path
        return self._decode_caption(beams[-1][0])