Diff of /layers/group_norm.py [000000] .. [408896]

Switch to unified view

a b/layers/group_norm.py
1
"""
2
    Contains Keras group normalization class from
3
    https://github.com/titu1994/Keras-Group-Normalization/blob/master/group_norm.py
4
"""
5
import tensorflow as tf
6
from tensorflow.keras import initializers, constraints, regularizers
7
8
9
class GroupNormalization(tf.keras.layers.Layer):
10
    def __init__(self,
11
                 groups=8,
12
                 axis=-1,
13
                 epsilon=1e-5,
14
                 center=True,
15
                 scale=True,
16
                 beta_initializer='zeros',
17
                 gamma_initializer='ones',
18
                 beta_regularizer=None,
19
                 gamma_regularizer=None,
20
                 beta_constraint=None,
21
                 gamma_constraint=None,
22
                 **kwargs):
23
        """ Initializes one group normalization layer.
24
25
            References:
26
                - [Group Normalization](https://arxiv.org/abs/1803.08494)
27
        """
28
        super(GroupNormalization, self).__init__(**kwargs)
29
        self.supports_masking = True
30
        self.groups = groups
31
        self.axis = axis
32
        self.epsilon = epsilon
33
        self.center = center
34
        self.scale = scale
35
        self.beta_initializer = initializers.get(beta_initializer)
36
        self.gamma_initializer = initializers.get(gamma_initializer)
37
        self.beta_regularizer = regularizers.get(beta_regularizer)
38
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
39
        self.beta_constraint = constraints.get(beta_constraint)
40
        self.gamma_constraint = constraints.get(gamma_constraint)
41
42
    def build(self, input_shape):
43
        dim = input_shape[self.axis]
44
45
        if dim is None:
46
            raise ValueError('Axis ' + str(self.axis) + ' of '
47
                             'input tensor should have a defined dimension '
48
                             'but the layer received an input with shape ' +
49
                             str(input_shape) + '.')
50
51
        if dim < self.groups:
52
            raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
53
                             'more than the number of channels (' +
54
                             str(dim) + ').')
55
56
        if dim % self.groups != 0:
57
            raise ValueError('Number of groups (' + str(self.groups) + ') must be a '
58
                             'multiple of the number of channels (' +
59
                             str(dim) + ').')
60
61
        self.input_spec = tf.keras.layers.InputSpec(ndim=len(input_shape),
62
                                    axes={self.axis: dim})
63
        shape = (dim,)
64
65
        if self.scale:
66
            self.gamma = self.add_weight(shape=shape,
67
                                         name='gamma',
68
                                         initializer=self.gamma_initializer,
69
                                         regularizer=self.gamma_regularizer,
70
                                         constraint=self.gamma_constraint)
71
        else:
72
            self.gamma = None
73
        if self.center:
74
            self.beta = self.add_weight(shape=shape,
75
                                        name='beta',
76
                                        initializer=self.beta_initializer,
77
                                        regularizer=self.beta_regularizer,
78
                                        constraint=self.beta_constraint)
79
        else:
80
            self.beta = None
81
        self.built = True
82
83
    def call(self, inputs, training=None, **kwargs):
84
        input_shape = list(inputs.shape)
85
86
        # Prepare broadcasting shape.
87
        reduction_axes = list(range(len(input_shape)))
88
        del reduction_axes[self.axis]
89
        broadcast_shape = [1] * len(input_shape)
90
        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
91
        broadcast_shape.insert(1, self.groups)
92
93
        group_axes = [input_shape[i] for i in range(len(input_shape))]
94
        group_axes[self.axis] = input_shape[self.axis] // self.groups
95
        group_axes.insert(1, self.groups)
96
97
        # Reshape inputs to new group shape.
98
        group_shape = [group_axes[0], self.groups] + group_axes[2:]
99
        group_shape = tf.stack(group_shape)
100
        inputs = tf.reshape(inputs, group_shape)
101
102
        group_reduction_axes = list(range(len(group_axes)))
103
        group_reduction_axes = group_reduction_axes[2:]
104
105
        mean, variance = tf.nn.moments(inputs, axes=group_reduction_axes, keepdims=True)
106
107
        inputs = (inputs - mean) / (tf.math.sqrt(variance + self.epsilon))
108
109
        # Prepare broadcast shape.
110
        inputs = tf.reshape(inputs, group_shape)
111
        outputs = inputs
112
113
        # In this case we must explicitly broadcast all parameters.
114
        if self.scale:
115
            broadcast_gamma = tf.reshape(self.gamma, broadcast_shape)
116
            outputs = outputs * broadcast_gamma
117
118
        if self.center:
119
            broadcast_beta = tf.reshape(self.beta, broadcast_shape)
120
            outputs = outputs + broadcast_beta
121
122
        outputs = tf.reshape(outputs, input_shape)
123
124
        return outputs
125
126
    def get_config(self):
127
        config = {
128
            'groups': self.groups,
129
            'axis': self.axis,
130
            'epsilon': self.epsilon,
131
            'center': self.center,
132
            'scale': self.scale,
133
            'beta_initializer': initializers.serialize(self.beta_initializer),
134
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
135
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
136
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
137
            'beta_constraint': constraints.serialize(self.beta_constraint),
138
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
139
        }
140
        base_config = super(GroupNormalization, self).get_config()
141
        return dict(list(base_config.items()) + list(config.items()))
142
143
    def compute_output_shape(self, input_shape):
144
        return input_shape