Diff of /attention_models.py [000000] .. [56a69a]

"""
Copyright (C) 2022 King Saud University, Saudi Arabia
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

Author:  Hamdi Altaheri
"""

#%%
import math
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense
from tensorflow.keras.layers import multiply, Permute, Concatenate, Conv2D, Add, Activation, Lambda
from tensorflow.keras.layers import Dropout, MultiHeadAttention, LayerNormalization
from tensorflow.keras import backend as K

#%% Create and apply the attention model
def attention_block(in_layer, attention_model, ratio=8, residual=False, apply_to_input=True):
    in_sh = in_layer.shape  # dimensions of the input tensor
    in_len = len(in_sh)
    expanded_axis = 2  # default = 2

    if attention_model == 'mha':     # Multi-head self-attention layer
        if in_len > 3:
            in_layer = Reshape((in_sh[1], -1))(in_layer)
        out_layer = mha_block(in_layer)
    elif attention_model == 'mhla':  # Multi-head local self-attention layer
        if in_len > 3:
            in_layer = Reshape((in_sh[1], -1))(in_layer)
        out_layer = mha_block(in_layer, vanilla=False)
    elif attention_model == 'se':    # Squeeze-and-excitation layer
        if in_len < 4:
            in_layer = tf.expand_dims(in_layer, axis=expanded_axis)
        out_layer = se_block(in_layer, ratio, residual, apply_to_input)
    elif attention_model == 'cbam':  # Convolutional block attention module
        if in_len < 4:
            in_layer = tf.expand_dims(in_layer, axis=expanded_axis)
        out_layer = cbam_block(in_layer, ratio=ratio, residual=residual)
    else:
        raise Exception("'{}' is not a supported attention module!".format(attention_model))

    # Restore the original input rank if the attention block changed it
    if in_len == 3 and len(out_layer.shape) == 4:
        out_layer = tf.squeeze(out_layer, expanded_axis)
    elif in_len == 4 and len(out_layer.shape) == 3:
        out_layer = Reshape((in_sh[1], in_sh[2], in_sh[3]))(out_layer)
    return out_layer

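#%% Usage sketch (illustrative, not part of the original module): routes a dummy
# channels-last feature map through each supported attention type. The
# (batch, 1, time, channels) input shape and the "_demo_" name are assumptions
# made for demonstration only.
def _demo_attention_block():
    x_in = tf.keras.Input(shape=(1, 64, 32))
    x_se = attention_block(x_in, 'se', ratio=8)
    x_cbam = attention_block(x_in, 'cbam', ratio=8, residual=True)
    x_mha = attention_block(x_in, 'mha')
    return tf.keras.Model(x_in, [x_se, x_cbam, x_mha])
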
#%% Multi-head self-attention (MHA) block
def mha_block(input_feature, key_dim=8, num_heads=2, dropout=0.5, vanilla=True):
    """Multi-head self-attention (MHA) block.

    Two types of MHA blocks are included:
        The original multi-head self-attention, as described in https://arxiv.org/abs/1706.03762
        The multi-head local self-attention, as described in https://arxiv.org/abs/2112.13492v1
    """
    # Layer normalization
    x = LayerNormalization(epsilon=1e-6)(input_feature)

    if vanilla:
        # Create a multi-head self-attention layer as described in
        # 'Attention Is All You Need' https://arxiv.org/abs/1706.03762
        x = MultiHeadAttention(key_dim=key_dim, num_heads=num_heads, dropout=dropout)(x, x)
    else:
        # Create a multi-head local self-attention layer as described in
        # 'Vision Transformer for Small-Size Datasets' https://arxiv.org/abs/2112.13492v1

        # Build the diagonal attention mask
        NUM_PATCHES = input_feature.shape[1]
        diag_attn_mask = 1 - tf.eye(NUM_PATCHES)
        diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)

        # Apply the multi-head local self-attention (LSA) layer
        x = MultiHeadAttention_LSA(key_dim=key_dim, num_heads=num_heads, dropout=dropout)(
            x, x, attention_mask=diag_attn_mask)
    x = Dropout(0.3)(x)
    # Skip connection
    mha_feature = Add()([input_feature, x])

    return mha_feature

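#%% Usage sketch (illustrative, not part of the original module): mha_block on a
# (batch, sequence, feature) tensor. The 22x64 token shape is an arbitrary
# assumption and the "_demo_" name is hypothetical.
def _demo_mha_block():
    tokens = tf.keras.Input(shape=(22, 64))
    vanilla_out = mha_block(tokens)               # standard multi-head self-attention
    local_out = mha_block(tokens, vanilla=False)  # locality self-attention (LSA) variant
    return tf.keras.Model(tokens, [vanilla_out, local_out])
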
#%% Multi-head self-attention (MHA) block: Locality Self-Attention (LSA)
class MultiHeadAttention_LSA(tf.keras.layers.MultiHeadAttention):
    """Local multi-head self-attention block.

    Locality Self-Attention as described in https://arxiv.org/abs/2112.13492v1
    This implementation is taken from https://keras.io/examples/vision/vit_small_ds/
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The trainable temperature term. The initial value is the square
        # root of the key dimension.
        self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)

    def _compute_attention(self, query, key, value, attention_mask=None, training=None):
        # Scale the query by the learned temperature before computing the attention scores
        query = tf.multiply(query, 1.0 / self.tau)
        attention_scores = tf.einsum(self._dot_product_equation, key, query)
        attention_scores = self._masked_softmax(attention_scores, attention_mask)
        attention_scores_dropout = self._dropout_layer(
            attention_scores, training=training
        )
        attention_output = tf.einsum(
            self._combine_equation, attention_scores_dropout, value
        )
        return attention_output, attention_scores

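#%% Usage sketch (illustrative, not part of the original module): applying the
# LSA layer directly with a diagonal mask, mirroring how mha_block uses it.
# Tensor sizes are arbitrary assumptions and the "_demo_" name is hypothetical.
def _demo_lsa_layer():
    seq = tf.random.normal((4, 10, 16))              # (batch, tokens, features)
    mask = tf.cast([1 - tf.eye(10)], dtype=tf.int8)  # suppress each token's own (diagonal) score
    lsa = MultiHeadAttention_LSA(key_dim=8, num_heads=2, dropout=0.5)
    return lsa(seq, seq, attention_mask=mask)
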
#%% Squeeze-and-excitation block
def se_block(input_feature, ratio=8, residual=False, apply_to_input=True):
    """Squeeze-and-Excitation (SE) block.

    As described in https://arxiv.org/abs/1709.01507
    The implementation is taken from https://github.com/kobiso/CBAM-keras
    """
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    channel = input_feature.shape[channel_axis]

    # Squeeze: global average pooling to a (1, 1, channel) descriptor
    se_feature = GlobalAveragePooling2D()(input_feature)
    se_feature = Reshape((1, 1, channel))(se_feature)
    assert se_feature.shape[1:] == (1, 1, channel)
    # Excitation: bottleneck MLP followed by a sigmoid gate
    if ratio != 0:
        se_feature = Dense(channel // ratio,
                           activation='relu',
                           kernel_initializer='he_normal',
                           use_bias=True,
                           bias_initializer='zeros')(se_feature)
        assert se_feature.shape[1:] == (1, 1, channel // ratio)
    se_feature = Dense(channel,
                       activation='sigmoid',
                       kernel_initializer='he_normal',
                       use_bias=True,
                       bias_initializer='zeros')(se_feature)
    assert se_feature.shape[1:] == (1, 1, channel)
    if K.image_data_format() == 'channels_first':
        se_feature = Permute((3, 1, 2))(se_feature)

    # Rescale the input channels with the learned weights
    if apply_to_input:
        se_feature = multiply([input_feature, se_feature])

    # Residual connection
    if residual:
        se_feature = Add()([se_feature, input_feature])

    return se_feature

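#%% Usage sketch (illustrative, not part of the original module): se_block on a
# dummy channels-last feature map. The (1, 64, 32) shape and the "_demo_" name
# are assumptions for demonstration only.
def _demo_se_block():
    feat = tf.keras.Input(shape=(1, 64, 32))
    recalibrated = se_block(feat, ratio=8, residual=True)
    return tf.keras.Model(feat, recalibrated)
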
#%% Convolutional block attention module
def cbam_block(input_feature, ratio=8, residual=False):
    """Convolutional Block Attention Module (CBAM) block.

    As described in https://arxiv.org/abs/1807.06521
    The implementation is taken from https://github.com/kobiso/CBAM-keras
    """
    cbam_feature = channel_attention(input_feature, ratio)
    cbam_feature = spatial_attention(cbam_feature)

    # Residual connection
    if residual:
        cbam_feature = Add()([input_feature, cbam_feature])

    return cbam_feature

def channel_attention(input_feature, ratio=8):
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    channel = input_feature.shape[channel_axis]

    # Shared two-layer MLP applied to both the average- and max-pooled descriptors
    shared_layer_one = Dense(channel // ratio,
                             activation='relu',
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')
    shared_layer_two = Dense(channel,
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')

    avg_pool = GlobalAveragePooling2D()(input_feature)
    avg_pool = Reshape((1, 1, channel))(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel)
    avg_pool = shared_layer_one(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel // ratio)
    avg_pool = shared_layer_two(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel)

    max_pool = GlobalMaxPooling2D()(input_feature)
    max_pool = Reshape((1, 1, channel))(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel)
    max_pool = shared_layer_one(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel // ratio)
    max_pool = shared_layer_two(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel)

    # Combine both descriptors and form the channel attention map
    cbam_feature = Add()([avg_pool, max_pool])
    cbam_feature = Activation('sigmoid')(cbam_feature)

    if K.image_data_format() == "channels_first":
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])

def spatial_attention(input_feature):
    kernel_size = 7

    if K.image_data_format() == "channels_first":
        channel = input_feature.shape[1]
        cbam_feature = Permute((2, 3, 1))(input_feature)
    else:
        channel = input_feature.shape[-1]
        cbam_feature = input_feature

    # Pool along the channel axis and concatenate the two spatial descriptors
    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    assert avg_pool.shape[-1] == 1
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    assert max_pool.shape[-1] == 1
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    assert concat.shape[-1] == 2
    # A single convolution produces the spatial attention map
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False)(concat)
    assert cbam_feature.shape[-1] == 1

    if K.image_data_format() == "channels_first":
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])

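#%% Usage sketch (illustrative, not part of the original module): cbam_block
# chaining channel and spatial attention on a dummy channels-last feature map.
# The (1, 64, 32) input shape and the "_demo_" name are assumptions.
def _demo_cbam_block():
    feat = tf.keras.Input(shape=(1, 64, 32))
    refined = cbam_block(feat, ratio=8, residual=True)
    return tf.keras.Model(feat, refined)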