# shepherd/node_embedder_model.py

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import BatchNorm, LayerNorm, GATv2Conv

# Pytorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# General
import numpy as np
import math
import tqdm
import time
import wandb

# Own
from utils.pretrain_utils import sample_node_for_et, get_batched_data, get_edges, calc_metrics, plot_roc_curve, metrics_per_rel
from decoders import bilinear, trans, dot

# Global variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

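# Expected hp_dict keys (as read in NodeEmbeder.__init__ below): lr, lr_factor,
# lr_patience, lr_threshold, lr_threshold_mode, lr_cooldown, min_lr, eps, wd,
# decoder_type, pred_threshold, nfeat, hidden, output, n_heads, dropout,
# norm_method, loss, negative_sampler_approach, and time (a bool toggling the
# timing printouts).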
class NodeEmbeder(pl.LightningModule):

    def __init__(self, all_data, edge_attr_dict, hp_dict=None, num_nodes=None, combined_training=False, spl_mat=[]):
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters("hp_dict", ignore=["spl_mat"])

        # Data
        self.all_data = all_data
        self.edge_attr_dict = edge_attr_dict

        # Model parameters
        self.lr = self.hparams.hp_dict['lr']
        self.lr_factor = self.hparams.hp_dict['lr_factor']
        self.lr_patience = self.hparams.hp_dict['lr_patience']
        self.lr_threshold = self.hparams.hp_dict['lr_threshold']
        self.lr_threshold_mode = self.hparams.hp_dict['lr_threshold_mode']
        self.lr_cooldown = self.hparams.hp_dict['lr_cooldown']
        self.min_lr = self.hparams.hp_dict['min_lr']
        self.eps = self.hparams.hp_dict['eps']

        self.wd = self.hparams.hp_dict['wd']
        self.decoder_type = self.hparams.hp_dict['decoder_type']
        self.pred_threshold = self.hparams.hp_dict['pred_threshold']

        # Shortest-path-length features are disabled here; the spl_mat argument is ignored
        self.use_spl = None
        self.spl_mat = []
        self.spl_dim = 0

        self.nfeat = self.hparams.hp_dict['nfeat']
        self.nhid1 = self.hparams.hp_dict['hidden'] * 2
        self.nhid2 = self.hparams.hp_dict['hidden']
        self.output = self.hparams.hp_dict['output']

        self.node_emb = nn.Embedding(num_nodes, self.nfeat)

        self.num_nodes = num_nodes
        self.num_relations = len(edge_attr_dict)
        self.n_heads = self.hparams.hp_dict['n_heads']
        self.dropout = self.hparams.hp_dict['dropout']
        self.norm_method = self.hparams.hp_dict['norm_method']

        # Select decoder
        if self.decoder_type == "bilinear":
            self.decoder = bilinear
        elif self.decoder_type == "trans":
            self.decoder = trans
        elif self.decoder_type == "dot":
            self.decoder = dot
        else:
            raise ValueError(f"Unknown decoder type: {self.decoder_type}")

        self.n_layers = 3

        self.loss_type = self.hparams.hp_dict['loss']
        self.combined_training = combined_training

        # Conv layers
        self.convs = torch.nn.ModuleList()
        self.convs.append(GATv2Conv(self.nfeat, self.nhid1, self.n_heads))  # input = nfeat, output = nhid1*n_heads
        if self.n_layers == 3:
            self.convs.append(GATv2Conv(self.nhid1*self.n_heads, self.nhid2, self.n_heads))  # input = nhid1*n_heads, output = nhid2*n_heads
            self.convs.append(GATv2Conv(self.nhid2*self.n_heads, self.output, self.n_heads))  # input = nhid2*n_heads, output = output*n_heads
        else:
            self.convs.append(GATv2Conv(self.nhid1*self.n_heads, self.output, self.n_heads))  # input = nhid1*n_heads, output = output*n_heads

        # Relation learnable weights
        self.relation_weights = nn.Parameter(torch.Tensor(self.num_relations, self.output * self.n_heads))

        # Normalization (applied after each conv layer except the last)
        if self.norm_method == "batch":
            self.norms = torch.nn.ModuleList()
            self.norms.append(BatchNorm(self.nhid1*self.n_heads))
            self.norms.append(BatchNorm(self.nhid2*self.n_heads))
        elif self.norm_method == "layer":
            self.norms = torch.nn.ModuleList()
            self.norms.append(LayerNorm(self.nhid1*self.n_heads))
            self.norms.append(LayerNorm(self.nhid2*self.n_heads))
        elif self.norm_method == "batch_layer":
            self.batch_norms = torch.nn.ModuleList()
            self.batch_norms.append(BatchNorm(self.nhid1*self.n_heads))
            if self.n_layers == 3: self.batch_norms.append(BatchNorm(self.nhid2*self.n_heads))
            self.layer_norms = torch.nn.ModuleList()
            self.layer_norms.append(LayerNorm(self.nhid1*self.n_heads))
            if self.n_layers == 3: self.layer_norms.append(LayerNorm(self.nhid2*self.n_heads))

        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        nn.init.xavier_uniform_(self.relation_weights, gain=nn.init.calculate_gain('leaky_relu'))

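    # NOTE: `adjs` is assumed to follow the per-layer convention of a PyG
    # neighbor sampler, extended with edge types: one
    # (edge_index, e_id, edge_type, size) tuple per conv layer, with the
    # target nodes placed first in `n_ids`.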
    def forward(self, n_ids, adjs):
        x = self.node_emb(n_ids)

        gat_attn = []
        assert len(adjs) == self.n_layers
        for i, (edge_index, _, edge_type, size) in enumerate(adjs):

            # Update node embeddings
            x_target = x[:size[1]]  # Target nodes are always placed first
            x, (edge_i, alpha) = self.convs[i]((x, x_target), edge_index, return_attention_weights=True)

            # Store attention weights, mapped back to global node IDs
            edge_i = edge_i.detach().cpu()
            alpha = alpha.detach().cpu()
            edge_i[0,:] = n_ids[edge_i[0,:]]
            edge_i[1,:] = n_ids[edge_i[1,:]]
            gat_attn.append((edge_i, alpha))

            # Normalize
            if i != self.n_layers - 1:
                if self.norm_method in ["batch", "layer"]:
                    x = self.norms[i](x)
                elif self.norm_method == "batch_layer":
                    x = self.layer_norms[i](x)
                x = F.leaky_relu(x)
                if self.norm_method == "batch_layer":
                    x = self.batch_norms[i](x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        return x, gat_attn

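    # Negative sampling: 'all' shuffles positive targets uniformly, while
    # 'by_edge_type' shuffles targets only among edges of the same type, so a
    # negative edge keeps a type-plausible target node.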
    def get_negative_target_nodes(self, data, pos_target_embeds, curr_pos_target_embeds, all_edge_types):
        if self.hparams.hp_dict['negative_sampler_approach'] == 'all':
            # get negative targets by shuffling positive targets
            if 'index_to_node_features_pos' in data:
                rand_index = torch.randperm(data.index_to_node_features_pos.size(0))
            else:
                rand_index = torch.randperm(curr_pos_target_embeds.size(0))

        elif self.hparams.hp_dict['negative_sampler_approach'] == 'by_edge_type':
            # get negative targets by shuffling positive targets within each edge type
            et_ids, et_counts = all_edge_types.unique(return_counts=True)
            targets_dict = self.create_target_dict(all_edge_types, et_ids)  # indices into all_edge_types for each edge type
            rand_index = torch.tensor(np.vectorize(sample_node_for_et)(all_edge_types.cpu(), targets_dict)).to(device)

        if 'index_to_node_features_pos' in data:
            index_to_node_features_neg = data.index_to_node_features_pos[rand_index]  # NOTE: currently possible to get the same node as positive & negative target
            curr_neg_target_embeds = pos_target_embeds[index_to_node_features_neg,:]
        else:
            curr_neg_target_embeds = curr_pos_target_embeds[rand_index,:]

        return curr_neg_target_embeds

    def create_target_dict(self, all_edge_types, et_ids):
        targets_dict = {}
        for k in et_ids:
            indices = (all_edge_types == int(k)).nonzero().cpu()
            targets_dict[int(k)] = indices
        return targets_dict

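    # decode() scores each positive edge and one sampled negative edge per
    # positive. Outputs are stacked positives-first, which is what
    # get_link_labels() below relies on when building the label vector.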
    def decode(self, data, source_embeds, pos_target_embeds, all_edge_types):
        curr_source_embeds = source_embeds[data.index_to_node_features_pos,:]
        curr_pos_target_embeds = pos_target_embeds[data.index_to_node_features_pos,:]

        ts = time.time()
        curr_neg_target_embeds = self.get_negative_target_nodes(data, pos_target_embeds, curr_pos_target_embeds, all_edge_types)
        te = time.time()
        if self.hparams.hp_dict['time']:
            print(f"Negative sampling took {te - ts:0.4f} seconds")

        # Get sources & targets for positive & negative edges
        source = torch.cat([curr_source_embeds, curr_source_embeds])
        target = torch.cat([curr_pos_target_embeds, curr_neg_target_embeds])
        all_edge_types = torch.cat([all_edge_types, all_edge_types])
        data.all_edge_types = all_edge_types

        if self.decoder_type == "dot":
            return data, self.decoder(source, target)
        else:
            relation = self.relation_weights[all_edge_types]
            return data, self.decoder(source, relation, target)

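    # The activation mirrors the loss: sigmoid produces probabilities for the
    # BCE loss, while tanh keeps scores in [-1, 1] for the max-margin loss.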
    def get_predictions(self, data, embed):

        # Apply decoder
        source_embed, target_embed = embed.split(embed.size(0) // 2, dim=0)
        data, raw_pred = self.decode(data, source_embed, target_embed, data.pos_edge_types)

        # Apply activation
        if self.loss_type != "max-margin":
            pred = torch.sigmoid(raw_pred)
        else:
            pred = torch.tanh(raw_pred)

        return data, raw_pred, pred

    def get_link_labels(self, edge_types):
        num_links = edge_types.size(0)
        link_labels = torch.zeros(num_links, dtype=torch.float, device=edge_types.device)
        link_labels[:num_links // 2] = 1.  # first half of the stacked edges are positives; the rest stay 0
        return link_labels

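    # _step is shared by train/val/test: it batches the data (unless batching
    # already happened upstream for combined training), runs the GNN encoder,
    # decodes edge scores, and computes the loss plus ROC/AP/acc/F1 metrics.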
    def _step(self, data, dataset_type):

        if not self.combined_training:
            ts = time.time()
            data = get_batched_data(data, self.all_data)
            tm = time.time()
            data = get_edges(data, self.all_data, dataset_type)
            te = time.time()
        else:
            ts = tm = te = time.time()  # batching is handled upstream; avoids undefined timings below
        data = data.to(device)

        # Get predictions
        t0 = time.time()
        out, gat_attn = self.forward(data.n_id, data.adjs)
        t1 = time.time()
        data, raw_pred, pred = self.get_predictions(data, out)
        t2 = time.time()

        # Calculate loss
        link_labels = self.get_link_labels(data.all_edge_types)
        loss = self.calc_loss(pred, link_labels)
        t3 = time.time()

        # Calculate metrics
        if self.loss_type == "max-margin":
            metric_pred = torch.sigmoid(raw_pred)
            self.logger.experiment.log({f'{dataset_type}/node_predicted_probs': wandb.Histogram(metric_pred.detach().cpu().numpy())})
        else:
            metric_pred = pred
        roc_score, ap_score, acc, f1 = calc_metrics(metric_pred.detach().cpu().numpy(), link_labels.detach().cpu().numpy(), self.pred_threshold)
        self.logger.experiment.log({f'{dataset_type}/node_roc_curve': plot_roc_curve(metric_pred.detach().cpu().numpy(), link_labels.detach().cpu().numpy())})

        t4 = time.time()
        if self.hparams.hp_dict['time']:
            print(f'It took {tm-ts:0.2f}s to get batched data, {te-tm:0.2f}s to get edges, {t1-t0:0.2f}s to complete forward pass, {t2-t1:0.2f}s to decode, {t3-t2:0.2f}s to calc loss, and {t4-t3:0.2f}s to calc other metrics.')

        return data, loss, pred, link_labels, roc_score, ap_score, acc, f1

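    # The hooks below use the pytorch-lightning <2.0 API, where per-batch
    # outputs are aggregated manually in *_epoch_end.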
    def training_step(self, data, data_idx):
        data, loss, pred, link_labels, roc_score, ap_score, acc, f1 = self._step(data, 'train')

        logs = {"train/node_batch_loss": loss.detach(),
                "train/node_roc": roc_score,
                "train/node_ap": ap_score,
                "train/node_acc": acc,
                "train/node_f1": f1
               }

        rel_logs = metrics_per_rel(pred, link_labels, self.edge_attr_dict, data.all_edge_types, "train", self.pred_threshold)
        logs.update(rel_logs)
        self._logger(logs)
        return {'loss': loss, 'logs': logs}

    def training_epoch_end(self, outputs):
        roc_train = []
        ap_train = []
        acc_train = []
        f1_train = []
        total_train_loss = []

        for batch_log in outputs:
            roc_train.append(batch_log['logs']["train/node_roc"])
            ap_train.append(batch_log['logs']["train/node_ap"])
            acc_train.append(batch_log['logs']["train/node_acc"])
            f1_train.append(batch_log['logs']["train/node_f1"])
            total_train_loss.append(batch_log['logs']["train/node_batch_loss"])

        self._logger({"train/node_total_loss": torch.mean(torch.Tensor(total_train_loss)),
                      "train/node_total_roc": np.mean(roc_train),
                      "train/node_total_ap": np.mean(ap_train),
                      "train/node_total_acc": np.mean(acc_train),
                      "train/node_total_f1": np.mean(f1_train)})
        self._logger({'node_curr_epoch': self.current_epoch})

    def validation_step(self, data, data_idx):
        data, loss, pred, link_labels, roc_score, ap_score, acc, f1 = self._step(data, 'val')

        logs = {"val/node_batch_loss": loss.detach().cpu(),
                "val/node_roc": roc_score,
                "val/node_ap": ap_score,
                "val/node_acc": acc,
                "val/node_f1": f1
               }

        rel_logs = metrics_per_rel(pred, link_labels, self.edge_attr_dict, data.all_edge_types, "val", self.pred_threshold)
        logs.update(rel_logs)
        self._logger(logs)
        return logs

    def validation_epoch_end(self, outputs):
        roc_val = []
        ap_val = []
        acc_val = []
        f1_val = []
        total_val_loss = []

        for batch_log in outputs:
            roc_val.append(batch_log["val/node_roc"])
            ap_val.append(batch_log["val/node_ap"])
            acc_val.append(batch_log["val/node_acc"])
            f1_val.append(batch_log["val/node_f1"])
            total_val_loss.append(batch_log["val/node_batch_loss"])

        self._logger({"val/node_total_loss": torch.mean(torch.Tensor(total_val_loss)),
                      "val/node_total_roc": np.mean(roc_val),
                      "val/node_total_ap": np.mean(ap_val),
                      "val/node_total_acc": np.mean(acc_val),
                      "val/node_total_f1": np.mean(f1_val)})
        self._logger({'node_curr_epoch': self.current_epoch})

    def test_step(self, data, data_idx):
        data, loss, pred, link_labels, roc_score, ap_score, acc, f1 = self._step(data, 'test')

        logs = {"test/node_batch_loss": loss.detach().cpu(),
                "test/node_roc": roc_score,
                "test/node_ap": ap_score,
                "test/node_acc": acc,
                "test/node_f1": f1
               }

        rel_logs = metrics_per_rel(pred, link_labels, self.edge_attr_dict, data.all_edge_types, "test", self.pred_threshold)
        logs.update(rel_logs)
        self._logger(logs)
        return logs

    def test_epoch_end(self, outputs):
        roc = []
        ap = []
        acc = []
        f1 = []

        for batch_log in outputs:
            roc.append(batch_log["test/node_roc"])
            ap.append(batch_log["test/node_ap"])
            acc.append(batch_log["test/node_acc"])
            f1.append(batch_log["test/node_f1"])

        self._logger({"test/node_total_roc": np.mean(roc),
                      "test/node_total_ap": np.mean(ap),
                      "test/node_total_acc": np.mean(acc),
                      "test/node_total_f1": np.mean(f1)})
        self._logger({'node_curr_epoch': self.current_epoch})

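    # predict() runs full-graph inference: every conv layer sees the complete
    # edge_index rather than sampled adjacencies, and dropout is skipped so the
    # returned embeddings are deterministic.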
    def predict(self, data):
        n_id = torch.arange(self.node_emb.weight.shape[0], device=self.device)

        x = self.node_emb(n_id)

        gat_attn = []
        for i in range(len(self.convs)):

            # Update node embeddings
            x, (edge_i, alpha) = self.convs[i](x, data.edge_index.to(self.device), return_attention_weights=True)

            edge_i = edge_i.detach().cpu()
            alpha = alpha.detach().cpu()
            edge_i[0,:] = n_id[edge_i[0,:]]
            edge_i[1,:] = n_id[edge_i[1,:]]
            gat_attn.append((edge_i, alpha))

            # Normalize
            if i != self.n_layers - 1:
                if self.norm_method in ["batch", "layer"]:
                    x = self.norms[i](x)
                elif self.norm_method == "batch_layer":
                    x = self.layer_norms[i](x)
                x = F.leaky_relu(x)
                if self.norm_method == "batch_layer":
                    x = self.batch_norms[i](x)

        assert x.shape[0] == self.node_emb.weight.shape[0]

        return x, gat_attn

    def predict_step(self, data, data_idx):
        x, gat_attn = self.predict(data)
        return x, gat_attn

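    # Adam with ReduceLROnPlateau; the scheduler monitors the epoch-level
    # "val/node_total_loss" logged in validation_epoch_end above.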
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=self.lr_factor,
                                                               patience=self.lr_patience, threshold=self.lr_threshold,
                                                               threshold_mode=self.lr_threshold_mode,
                                                               cooldown=self.lr_cooldown, min_lr=self.min_lr, eps=self.eps)
        return {
                "optimizer": optimizer,
                "lr_scheduler":
                    {
                    "scheduler": scheduler,
                    "monitor": "val/node_total_loss",
                    "name": "curr_lr"
                    },
                }

    def _logger(self, logs):
        for k, v in logs.items():
            self.log(k, v)

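    # BCE expects `pred` to already be probabilities (sigmoid applied in
    # get_predictions). The max-margin variant is a hinge loss,
    # mean(clamp(1 - (pos_score - neg_score), min=0)), pairing each positive
    # score with one negative score.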
    def calc_loss(self, pred, y):
        if self.loss_type == "BCE":
            loss = F.binary_cross_entropy(pred, y, reduction='none')
            norm_loss = torch.mean(loss)

        elif self.loss_type == "max-margin":
            loss = ((1 - (pred[y == 1] - pred[y != 1])).clamp(min=0).mean())
            norm_loss = loss

        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        return norm_loss
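
# Hypothetical usage sketch (not part of the training pipeline): the data
# objects and hp_dict come from the project's own preprocessing/config, so the
# names below are illustrative only.
#
#   model = NodeEmbeder(all_data, edge_attr_dict, hp_dict=hp_dict,
#                       num_nodes=n_nodes)
#   trainer = pl.Trainer(max_epochs=n_epochs, logger=WandbLogger())
#   trainer.fit(model, train_dataloader, val_dataloader)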