Diff of /layers/aggregator.py [000000] .. [c0da92]

Switch to unified view

a b/layers/aggregator.py
1
# -*- coding: utf-8 -*-
2
3
from keras.engine.topology import Layer
4
from keras import backend as K
5
6
# class AvgAggregator(Layer):
7
#     def __init__(self, activation: str ='relu', initializer='glorot_normal', regularizer=None,
8
#                  **kwargs):
9
#         super(AvgAggregator, self).__init__(**kwargs)
10
#         if activation == 'relu':
11
#             self.activation = K.relu
12
#         elif activation == 'tanh':
13
#             self.activation = K.tanh
14
#         else:
15
#             raise ValueError(f'`activation` not understood: {activation}')
16
#         self.initializer = initializer
17
#         self.regularizer = regularizer
18
#     def build(self, input_shape):
19
#         ent_embed_dim = input_shape[0][-1]
20
#         self.w = self.add_weight(name=self.name+'_w', shape=(ent_embed_dim, ent_embed_dim),
21
#                                  initializer=self.initializer, regularizer=self.regularizer)
22
#         self.b = self.add_weight(name=self.name+'_b', shape=(ent_embed_dim,), initializer='zeros')
23
#         super(SumAggregator, self).build(input_shape) 
24
25
26
27
class SumAggregator(Layer):
28
    def __init__(self, activation: str ='relu', initializer='glorot_normal', regularizer=None,
29
                 **kwargs):
30
        super(SumAggregator, self).__init__(**kwargs)
31
        if activation == 'relu':
32
            self.activation = K.relu
33
        elif activation == 'tanh':
34
            self.activation = K.tanh
35
        else:
36
            raise ValueError(f'`activation` not understood: {activation}')
37
        self.initializer = initializer
38
        self.regularizer = regularizer
39
40
    def build(self, input_shape):
41
        ent_embed_dim = input_shape[0][-1]
42
        self.w = self.add_weight(name=self.name+'_w', shape=(ent_embed_dim, ent_embed_dim),
43
                                 initializer=self.initializer, regularizer=self.regularizer)
44
        self.b = self.add_weight(name=self.name+'_b', shape=(ent_embed_dim,), initializer='zeros')
45
        super(SumAggregator, self).build(input_shape)
46
47
    def call(self, inputs, **kwargs):
48
        entity, neighbor = inputs
49
        return self.activation(K.dot((entity + neighbor), self.w) + self.b)
50
51
    def compute_output_shape(self, input_shape):
52
        return input_shape[0]
53
54
55
class ConcatAggregator(Layer):
56
    def __init__(self, activation: str = 'relu', initializer='glorot_normal', regularizer=None,
57
                 **kwargs):
58
        super(ConcatAggregator, self).__init__(**kwargs)
59
        if activation == 'relu':
60
            self.activation = K.relu
61
        elif activation == 'tanh':
62
            self.activation = K.tanh
63
        else:
64
            raise ValueError(f'`activation` not understood: {activation}')
65
        self.initializer = initializer
66
        self.regularizer = regularizer
67
68
    def build(self, input_shape):
69
        ent_embed_dim = input_shape[0][-1]
70
        neighbor_embed_dim = input_shape[1][-1]
71
        self.w = self.add_weight(name=self.name + '_w',
72
                                 shape=(ent_embed_dim+neighbor_embed_dim, ent_embed_dim),
73
                                 initializer=self.initializer, regularizer=self.regularizer)
74
        self.b = self.add_weight(name=self.name + '_b', shape=(ent_embed_dim,),
75
                                 initializer='zeros')
76
        super(ConcatAggregator, self).build(input_shape)
77
78
    def call(self, inputs, **kwargs):
79
        entity, neighbor = inputs
80
        return self.activation(K.dot(K.concatenate([entity, neighbor]), self.w) + self.b)
81
82
    def compute_output_shape(self, input_shape):
83
        return input_shape[0]
84
85
86
class NeighAggregator(Layer):
87
    def __init__(self, activation: str = 'relu', initializer='glorot_normal', regularizer=None,
88
                 **kwargs):
89
        super(NeighAggregator, self).__init__()
90
        if activation == 'relu':
91
            self.activation = K.relu
92
        elif activation == 'tanh':
93
            self.activation = K.tanh
94
        else:
95
            raise ValueError(f'`activation` not understood: {activation}')
96
        self.initializer = initializer
97
        self.regularizer = regularizer
98
99
    def build(self, input_shape):
100
        ent_embed_dim = input_shape[0][-1]
101
        neighbor_embed_dim = input_shape[1][-1]
102
        self.w = self.add_weight(name=self.name + '_w',
103
                                 shape=(neighbor_embed_dim, ent_embed_dim),
104
                                 initializer=self.initializer, regularizer=self.regularizer)
105
        self.b = self.add_weight(name=self.name + '_b', shape=(ent_embed_dim,),
106
                                 initializer='zeros')
107
        super(NeighAggregator, self).build(input_shape)
108
109
    def call(self, inputs, **kwargs):
110
        entity, neighbor = inputs
111
        return self.activation(K.dot(neighbor, self.w) + self.b)
112
113
    def compute_output_shape(self, input_shape):
114
        return input_shape[0]