# group_norm.py
from keras.engine import Layer, InputSpec
from keras import initializers
from keras import regularizers
from keras import constraints
from keras import backend as K

from keras.utils.generic_utils import get_custom_objects


class GroupNormalization(Layer):
    """Group normalization layer.

    Group Normalization divides the channels into groups and computes,
    within each group, the mean and variance used for normalization. Its
    computation is independent of the batch size, and its accuracy is
    stable over a wide range of batch sizes.

    # Arguments
        groups: Integer, the number of groups for Group Normalization.
            Must evenly divide the number of channels.
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `GroupNormalization`.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Group Normalization](https://arxiv.org/abs/1803.08494)
    """

    def __init__(self,
                 groups=32,
                 axis=-1,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        if dim < self.groups:
            raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
                             'more than the number of channels (' +
                             str(dim) + ').')

        if dim % self.groups != 0:
            raise ValueError('Number of channels (' + str(dim) + ') must be a '
                             'multiple of the number of groups (' +
                             str(self.groups) + ').')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, **kwargs):
        input_shape = K.int_shape(inputs)
        tensor_input_shape = K.shape(inputs)

        # Resolve a negative axis (e.g. -1) so the group axis can be
        # inserted directly before the channel axis.
        ndim = len(input_shape)
        axis = self.axis % ndim

        # Reshape so the channel axis is split in place into
        # (groups, channels_per_group).
        group_shape = [tensor_input_shape[i] for i in range(ndim)]
        group_shape[axis] = input_shape[axis] // self.groups
        group_shape.insert(axis, self.groups)
        group_shape = K.stack(group_shape)
        inputs = K.reshape(inputs, group_shape)

        # Normalize over every axis except the batch axis and the group axis,
        # i.e. over the spatial dimensions and the channels within each group.
        group_reduction_axes = [i for i in range(1, ndim + 1) if i != axis]
        mean = K.mean(inputs, axis=group_reduction_axes, keepdims=True)
        variance = K.var(inputs, axis=group_reduction_axes, keepdims=True)

        outputs = (inputs - mean) / K.sqrt(variance + self.epsilon)

        # Broadcast gamma and beta over the grouped shape.
        broadcast_shape = [1] * ndim
        broadcast_shape[axis] = input_shape[axis] // self.groups
        broadcast_shape.insert(axis, self.groups)

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            outputs = outputs * broadcast_gamma

        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            outputs = outputs + broadcast_beta

        # Restore the original input shape.
        outputs = K.reshape(outputs, tensor_input_shape)

        return outputs

    def get_config(self):
        config = {
            'groups': self.groups,
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape


get_custom_objects().update({'GroupNormalization': GroupNormalization})
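# The registration above adds GroupNormalization to Keras' global custom-object
# scope, so models saved with this layer can typically be reloaded (e.g. via
# keras.models.load_model) without passing custom_objects explicitly.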


if __name__ == '__main__':
    from keras.layers import Input
    from keras.models import Model
    ip = Input(shape=(None, None, 4))
    #ip = Input(batch_shape=(100, None, None, 2))
    x = GroupNormalization(groups=2, axis=-1, epsilon=0.1)(ip)
    model = Model(ip, x)
    model.summary()
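
    # Minimal, illustrative sanity check (assumes NumPy is available; the names
    # `data`, `out`, `ref` and the helper `reference_group_norm` are used here
    # only for this sketch). It feeds random data through the demo model above
    # and compares the result with a plain NumPy re-implementation of the
    # per-group normalization described in the docstring (gamma=1, beta=0).
    import numpy as np

    data = np.random.rand(8, 16, 16, 4).astype('float32')
    out = model.predict(data)
    print('output shape:', out.shape)  # expected to match the input shape

    def reference_group_norm(x, groups, eps):
        # Split the channel axis into (groups, channels_per_group) and
        # normalize each (sample, group) slice by its own mean and variance.
        n, h, w, c = x.shape
        xg = x.reshape(n, h, w, groups, c // groups)
        mean = xg.mean(axis=(1, 2, 4), keepdims=True)
        var = xg.var(axis=(1, 2, 4), keepdims=True)
        return ((xg - mean) / np.sqrt(var + eps)).reshape(n, h, w, c)

    ref = reference_group_norm(data, groups=2, eps=0.1)
    print('max abs difference vs. layer output:', np.abs(out - ref).max())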