Diff of /networks.py [000000] .. [2095ed]

# Base / Native
import csv
from collections import Counter
import copy
import json
import functools
import gc
import logging
import math
import os
import pdb
import pickle
import random
import sys
import tables
import time
from tqdm import tqdm

# Numerical / Array
import numpy as np

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.autograd import Variable
from torch.nn import init, Parameter
from torch.utils.data import DataLoader
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torchvision import datasets, transforms
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.nn import GCNConv, SAGEConv, GraphConv, GatedGraphConv, GATConv
from torch_geometric.nn import TopKPooling, SAGPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.transforms.normalize_features import NormalizeFeatures

# Env
from fusion import *
from options import parse_args
from utils import *


################
# Network Utils
################
def define_net(opt, k):
    net = None
    act = define_act_layer(act_type=opt.act_type)
    init_max = True if opt.init_type == "max" else False

    if opt.mode == "path":
        net = get_vgg(path_dim=opt.path_dim, act=act, label_dim=opt.label_dim)
    elif opt.mode == "graph":
        net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, GNN=opt.GNN, use_edges=opt.use_edges, pooling_ratio=opt.pooling_ratio, act=act, label_dim=opt.label_dim, init_max=init_max)
    elif opt.mode == "omic":
        net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=init_max)
    elif opt.mode == "graphomic":
        net = GraphomicNet(opt=opt, act=act, k=k)
    elif opt.mode == "pathomic":
        net = PathomicNet(opt=opt, act=act, k=k)
    elif opt.mode == "pathgraphomic":
        net = PathgraphomicNet(opt=opt, act=act, k=k)
    elif opt.mode == "pathpath":
        net = PathpathNet(opt=opt, act=act, k=k)
    elif opt.mode == "graphgraph":
        net = GraphgraphNet(opt=opt, act=act, k=k)
    elif opt.mode == "omicomic":
        net = OmicomicNet(opt=opt, act=act, k=k)
    else:
        raise NotImplementedError('mode [%s] is not implemented' % opt.mode)
    return init_net(net, opt.init_type, opt.init_gain, opt.gpu_ids)
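
# Usage sketch (hedged; parse_args comes from options.py and the opt fields are
# assumed to match how they are read above). k selects a cross-validation fold's
# pretrained unimodal checkpoints; k=None skips checkpoint loading:
#
#     opt = parse_args()
#     net = define_net(opt, k=None)
#     optimizer = define_optimizer(opt, net)
#     scheduler = define_scheduler(opt, optimizer)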


def define_optimizer(opt, model):
    optimizer = None
    if opt.optimizer_type == 'adabound':
        import adabound  # optional dependency; only needed for this branch
        optimizer = adabound.AdaBound(model.parameters(), lr=opt.lr, final_lr=opt.final_lr)
    elif opt.optimizer_type == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    elif opt.optimizer_type == 'adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, initial_accumulator_value=0.1)
    else:
        raise NotImplementedError('optimizer [%s] is not implemented' % opt.optimizer_type)
    return optimizer


def define_reg(opt, model):
    loss_reg = None

    if opt.reg_type == 'none':
        loss_reg = 0
    elif opt.reg_type == 'path':
        loss_reg = regularize_path_weights(model=model)
    elif opt.reg_type == 'mm':
        loss_reg = regularize_MM_weights(model=model)
    elif opt.reg_type == 'all':
        loss_reg = regularize_weights(model=model)
    elif opt.reg_type == 'omic':
        loss_reg = regularize_MM_omic(model=model)
    else:
        raise NotImplementedError('reg method [%s] is not implemented' % opt.reg_type)
    return loss_reg
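
# Usage note (hedged): the regularize_* helpers are imported from utils.py and
# compute a penalty from the current weights, so define_reg is re-evaluated each
# iteration and added to the loss; the names below are illustrative only:
#
#     loss = surv_loss + opt.lambda_reg * define_reg(opt, model)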


def define_scheduler(opt, optimizer):
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, 0.1, last_epoch=-1)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
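
# Worked example (hedged, assumed option values): with epoch_count=1, niter=25,
# niter_decay=25, the 'linear' policy holds the lr multiplier at 1.0 through
# epoch 24, then decays linearly, e.g. epoch 30 gives 1 - (30 + 1 - 25)/26 ~= 0.77.
# The scheduler is stepped once per epoch:
#
#     for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
#         ...  # train one epoch
#         scheduler.step()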


def define_act_layer(act_type='Tanh'):
    if act_type == 'Tanh':
        act_layer = nn.Tanh()
    elif act_type == 'ReLU':
        act_layer = nn.ReLU()
    elif act_type == 'Sigmoid':
        act_layer = nn.Sigmoid()
    elif act_type == 'LSM':
        act_layer = nn.LogSoftmax(dim=1)
    elif act_type == "none":
        act_layer = None
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act_type)
    return act_layer


def define_bifusion(fusion_type, skip=1, use_bilinear=1, gate1=1, gate2=1, dim1=32, dim2=32, scale_dim1=1, scale_dim2=1, mmhid=64, dropout_rate=0.25):
    fusion = None
    if fusion_type == 'pofusion':
        fusion = BilinearFusion(skip=skip, use_bilinear=use_bilinear, gate1=gate1, gate2=gate2, dim1=dim1, dim2=dim2, scale_dim1=scale_dim1, scale_dim2=scale_dim2, mmhid=mmhid, dropout_rate=dropout_rate)
    else:
        raise NotImplementedError('fusion type [%s] is not found' % fusion_type)
    return fusion


def define_trifusion(fusion_type, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25):
    fusion = None
    if fusion_type == 'pofusion_A':
        fusion = TrilinearFusion_A(skip=skip, use_bilinear=use_bilinear, gate1=gate1, gate2=gate2, gate3=gate3, dim1=dim1, dim2=dim2, dim3=dim3, scale_dim1=scale_dim1, scale_dim2=scale_dim2, scale_dim3=scale_dim3, mmhid=mmhid, dropout_rate=dropout_rate)
    elif fusion_type == 'pofusion_B':
        fusion = TrilinearFusion_B(skip=skip, use_bilinear=use_bilinear, gate1=gate1, gate2=gate2, gate3=gate3, dim1=dim1, dim2=dim2, dim3=dim3, scale_dim1=scale_dim1, scale_dim2=scale_dim2, scale_dim3=scale_dim3, mmhid=mmhid, dropout_rate=dropout_rate)
    else:
        raise NotImplementedError('fusion type [%s] is not found' % fusion_type)
    return fusion
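
# Sketch (hedged): BilinearFusion/TrilinearFusion_* live in fusion.py (pulled in
# via `from fusion import *`); their forward signatures are assumed from how the
# fusion networks below call them, one embedding per modality:
#
#     fusion = define_bifusion('pofusion', dim1=32, dim2=32, mmhid=64)
#     v1, v2 = torch.randn(4, 32), torch.randn(4, 32)
#     fused = fusion(v1, v2)    # expected: (4, 64), consumed by an mmhid classifier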


############
# Omic Model
############
class MaxNet(nn.Module):
    def __init__(self, input_dim=80, omic_dim=32, dropout_rate=0.25, act=None, label_dim=1, init_max=True):
        super(MaxNet, self).__init__()
        hidden = [64, 48, 32, 32]
        self.act = act

        encoder1 = nn.Sequential(
            nn.Linear(input_dim, hidden[0]),
            nn.ELU(),
            nn.AlphaDropout(p=dropout_rate, inplace=False))

        encoder2 = nn.Sequential(
            nn.Linear(hidden[0], hidden[1]),
            nn.ELU(),
            nn.AlphaDropout(p=dropout_rate, inplace=False))

        encoder3 = nn.Sequential(
            nn.Linear(hidden[1], hidden[2]),
            nn.ELU(),
            nn.AlphaDropout(p=dropout_rate, inplace=False))

        encoder4 = nn.Sequential(
            nn.Linear(hidden[2], omic_dim),
            nn.ELU(),
            nn.AlphaDropout(p=dropout_rate, inplace=False))

        self.encoder = nn.Sequential(encoder1, encoder2, encoder3, encoder4)
        self.classifier = nn.Sequential(nn.Linear(omic_dim, label_dim))

        if init_max: init_max_weights(self)

        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, **kwargs):
        x = kwargs['x_omic']
        features = self.encoder(x)
        out = self.classifier(features)
        if self.act is not None:
            out = self.act(out)

            # rescale sigmoid output from (0, 1) to (-3, 3)
            if isinstance(self.act, nn.Sigmoid):
                out = out * self.output_range + self.output_shift

        return features, out
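
# Minimal sketch (hedged; shapes follow the defaults above):
#
#     net = MaxNet(input_dim=80, omic_dim=32, act=nn.Sigmoid(), init_max=False)
#     features, out = net(x_omic=torch.randn(8, 80))
#     # features: (8, 32); out: (8, 1), rescaled into (-3, 3)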


############
# Graph Model
############
class NormalizeFeaturesV2(object):
    r"""Max-normalizes the first 12 node feature columns and casts x to CUDA floats."""

    def __call__(self, data):
        data.x[:, :12] = data.x[:, :12] / data.x[:, :12].max(0, keepdim=True)[0]
        data.x = data.x.type(torch.cuda.FloatTensor)
        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)


class NormalizeEdgesV2(object):
    r"""Column-wise max-normalizes edge attributes (as CUDA floats)."""

    def __call__(self, data):
        data.edge_attr = data.edge_attr.type(torch.cuda.FloatTensor)
        data.edge_attr = data.edge_attr / data.edge_attr.max(0, keepdim=True)[0]
        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)


class GraphNet(torch.nn.Module):
    def __init__(self, features=1036, nhid=128, grph_dim=32, nonlinearity=torch.tanh,
        dropout_rate=0.25, GNN='GCN', use_edges=0, pooling_ratio=0.20, act=None, label_dim=1, init_max=True):
        super(GraphNet, self).__init__()

        self.dropout_rate = dropout_rate
        self.use_edges = use_edges
        self.act = act

        self.conv1 = SAGEConv(features, nhid)
        self.pool1 = SAGPooling(nhid, ratio=pooling_ratio, gnn=GNN)
        self.conv2 = SAGEConv(nhid, nhid)
        self.pool2 = SAGPooling(nhid, ratio=pooling_ratio, gnn=GNN)
        self.conv3 = SAGEConv(nhid, nhid)
        self.pool3 = SAGPooling(nhid, ratio=pooling_ratio, gnn=GNN)

        self.lin1 = torch.nn.Linear(nhid*2, nhid)
        self.lin2 = torch.nn.Linear(nhid, grph_dim)
        self.lin3 = torch.nn.Linear(grph_dim, label_dim)

        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

        if init_max:
            init_max_weights(self)
            print("Initializing with Max")

    def forward(self, **kwargs):
        data = kwargs['x_grph']
        data = NormalizeFeaturesV2()(data)
        data = NormalizeEdgesV2()(data)
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # three conv + self-attention-pooling stages; each contributes a
        # concatenated (global max pool || global mean pool) readout
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, edge_attr, batch, _ = self.pool1(x, edge_index, edge_attr, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, edge_attr, batch, _ = self.pool2(x, edge_index, edge_attr, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, edge_attr, batch, _ = self.pool3(x, edge_index, edge_attr, batch)
        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # sum the three stage readouts into a single graph embedding
        x = x1 + x2 + x3

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        features = F.relu(self.lin2(x))
        out = self.lin3(features)
        if self.act is not None:
            out = self.act(out)

            if isinstance(self.act, nn.Sigmoid):
                out = out * self.output_range + self.output_shift

        return features, out
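
# Sketch (hedged): the V2 transforms cast to torch.cuda.FloatTensor, so CUDA is
# required; node features must be `features` wide (1036 by default), and
# Data/Batch are the standard torch_geometric containers:
#
#     from torch_geometric.data import Data, Batch
#     g = Data(x=torch.rand(50, 1036), edge_index=torch.randint(0, 50, (2, 200)),
#              edge_attr=torch.rand(200, 1))
#     net = GraphNet(act=nn.Sigmoid(), init_max=False).cuda()
#     features, out = net(x_grph=Batch.from_data_list([g]).to('cuda'))
#     # features: (1, 32), out: (1, 1)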


############
# Path Model
############
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class PathNet(nn.Module):

    def __init__(self, features, path_dim=32, act=None, num_classes=1):
        super(PathNet, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 1024),
            nn.ReLU(True),
            nn.Dropout(0.25),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Dropout(0.25),
            nn.Linear(1024, path_dim),
            nn.ReLU(True),
            nn.Dropout(0.05)
        )

        self.linear = nn.Linear(path_dim, num_classes)
        self.act = act

        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

        dfs_freeze(self.features)

    def forward(self, **kwargs):
        x = kwargs['x_path']
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        features = self.classifier(x)
        hazard = self.linear(features)

        if self.act is not None:
            hazard = self.act(hazard)

            if isinstance(self.act, nn.Sigmoid):
                hazard = hazard * self.output_range + self.output_shift

        return features, hazard


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def get_vgg(arch='vgg19_bn', cfg='E', act=None, batch_norm=True, label_dim=1, pretrained=True, progress=True, **kwargs):
    model = PathNet(make_layers(cfgs[cfg], batch_norm=batch_norm), act=act, num_classes=label_dim, **kwargs)

    if pretrained:
        pretrained_dict = load_state_dict_from_url(model_urls[arch], progress=progress)

        # drop the torchvision classifier weights; only the conv trunk is reused
        for key in list(pretrained_dict.keys()):
            if 'classifier' in key: pretrained_dict.pop(key)

        model.load_state_dict(pretrained_dict, strict=False)
        print("Initializing Path Weights")

    return model
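
# Sketch (hedged): builds the vgg19_bn trunk from cfgs['E'] with pretrained conv
# weights and the PathNet head; extra kwargs (e.g. path_dim) pass through to PathNet.
#
#     net = get_vgg(path_dim=32, act=nn.Sigmoid(), label_dim=1)
#     features, hazard = net(x_path=torch.randn(2, 3, 512, 512))
#     # features: (2, 32), hazard: (2, 1); any input size works via AdaptiveAvgPool2d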


##############################################################################
# Graph + Omic
##############################################################################
class GraphomicNet(nn.Module):
    def __init__(self, opt, act, k):
        super(GraphomicNet, self).__init__()
        self.grph_net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, use_edges=1, pooling_ratio=0.20, label_dim=opt.label_dim, init_max=False)
        self.omic_net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=False)

        if k is not None:
            pt_fname = '_%d.pt' % k
            best_grph_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), map_location=torch.device('cpu'))
            best_omic_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname), map_location=torch.device('cpu'))
            self.grph_net.load_state_dict(best_grph_ckpt['model_state_dict'])
            self.omic_net.load_state_dict(best_omic_ckpt['model_state_dict'])
            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), "\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname))

        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.grph_gate, gate2=opt.omic_gate, dim1=opt.grph_dim, dim2=opt.omic_dim, scale_dim1=opt.grph_scale, scale_dim2=opt.omic_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
        self.act = act

        # freeze the pretrained unimodal branches; only fusion + classifier train
        dfs_freeze(self.grph_net)
        dfs_freeze(self.omic_net)
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, **kwargs):
        grph_vec, _ = self.grph_net(x_grph=kwargs['x_grph'])
        omic_vec, _ = self.omic_net(x_omic=kwargs['x_omic'])
        features = self.fusion(grph_vec, omic_vec)
        hazard = self.classifier(features)
        if self.act is not None:
            hazard = self.act(hazard)

            if isinstance(self.act, nn.Sigmoid):
                hazard = hazard * self.output_range + self.output_shift

        return features, hazard

    # Note: Python's built-in hasattr() never dispatches to __hasattr__; this
    # helper (repeated in the fusion networks below) only runs when called
    # explicitly as net.__hasattr__(name).
    def __hasattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return True
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return True
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return True
        return False


##############################################################################
# Path + Omic
##############################################################################
class PathomicNet(nn.Module):
    def __init__(self, opt, act, k):
        super(PathomicNet, self).__init__()
        self.omic_net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=False)

        if k is not None:
            pt_fname = '_%d.pt' % k
            best_omic_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname), map_location=torch.device('cpu'))
            self.omic_net.load_state_dict(best_omic_ckpt['model_state_dict'])
            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname))

        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.path_gate, gate2=opt.omic_gate, dim1=opt.path_dim, dim2=opt.omic_dim, scale_dim1=opt.path_scale, scale_dim2=opt.omic_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
        self.act = act

        dfs_freeze(self.omic_net)
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, **kwargs):
        # x_path is an already-extracted path embedding of width opt.path_dim
        # (the VGG trunk is run separately), not a raw image batch
        path_vec = kwargs['x_path']
        omic_vec, _ = self.omic_net(x_omic=kwargs['x_omic'])
        features = self.fusion(path_vec, omic_vec)
        hazard = self.classifier(features)
        if self.act is not None:
            hazard = self.act(hazard)

            if isinstance(self.act, nn.Sigmoid):
                hazard = hazard * self.output_range + self.output_shift

        return features, hazard

    def __hasattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return True
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return True
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return True
        return False
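
# Sketch (hedged): given a parsed opt supplying the fields referenced above,
# with k=None so no fold checkpoints are loaded:
#
#     net = PathomicNet(opt, act=nn.Sigmoid(), k=None)
#     path_vec = torch.randn(4, opt.path_dim)
#     x_omic = torch.randn(4, opt.input_size_omic)
#     features, hazard = net(x_path=path_vec, x_omic=x_omic)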


##############################################################################
# Path + Graph + Omic
##############################################################################
class PathgraphomicNet(nn.Module):
    def __init__(self, opt, act, k):
        super(PathgraphomicNet, self).__init__()
        self.grph_net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, use_edges=1, pooling_ratio=0.20, label_dim=opt.label_dim, init_max=False)
        self.omic_net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=False)

        if k is not None:
            pt_fname = '_%d.pt' % k
            best_grph_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), map_location=torch.device('cpu'))
            best_omic_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname), map_location=torch.device('cpu'))
            self.grph_net.load_state_dict(best_grph_ckpt['model_state_dict'])
            self.omic_net.load_state_dict(best_omic_ckpt['model_state_dict'])
            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), "\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname))

        self.fusion = define_trifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.path_gate, gate2=opt.grph_gate, gate3=opt.omic_gate, dim1=opt.path_dim, dim2=opt.grph_dim, dim3=opt.omic_dim, scale_dim1=opt.path_scale, scale_dim2=opt.grph_scale, scale_dim3=opt.omic_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
        self.act = act

        dfs_freeze(self.grph_net)
        dfs_freeze(self.omic_net)
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, **kwargs):
        path_vec = kwargs['x_path']
        grph_vec, _ = self.grph_net(x_grph=kwargs['x_grph'])
        omic_vec, _ = self.omic_net(x_omic=kwargs['x_omic'])
        features = self.fusion(path_vec, grph_vec, omic_vec)
        hazard = self.classifier(features)
        if self.act is not None:
            hazard = self.act(hazard)

            if isinstance(self.act, nn.Sigmoid):
                hazard = hazard * self.output_range + self.output_shift

        return features, hazard

    def __hasattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return True
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return True
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return True
        return False


##############################################################################
# Ensembling Effects
##############################################################################
class PathgraphNet(nn.Module):
    # Note: no 'pathgraph' mode is dispatched by define_net above; this class is
    # only reachable if constructed directly.
    def __init__(self, opt, act, k):
        super(PathgraphNet, self).__init__()
        self.grph_net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, use_edges=1, pooling_ratio=0.20, label_dim=opt.label_dim, init_max=False)

        if k is not None:
            pt_fname = '_%d.pt' % k
            best_grph_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), map_location=torch.device('cpu'))
            self.grph_net.load_state_dict(best_grph_ckpt['model_state_dict'])
            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname))

        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.path_gate, gate2=opt.grph_gate, dim1=opt.path_dim, dim2=opt.grph_dim, scale_dim1=opt.path_scale, scale_dim2=opt.grph_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
        self.act = act

        dfs_freeze(self.grph_net)
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, **kwargs):
        path_vec = kwargs['x_path']
        grph_vec, _ = self.grph_net(x_grph=kwargs['x_grph'])
        features = self.fusion(path_vec, grph_vec)
        hazard = self.classifier(features)
        if self.act is not None:
            hazard = self.act(hazard)

            if isinstance(self.act, nn.Sigmoid):
                hazard = hazard * self.output_range + self.output_shift

        return features, hazard

    def __hasattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return True
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return True
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return True
        return False


class PathpathNet(nn.Module):
    def __init__(self, opt, act, k):
        super(PathpathNet, self).__init__()
        # with a binary opt.path_gate, gate2 evaluates to 0 in both branches
        # (1-1 when the gate is on, 0 otherwise), so the second gate stays disabled
        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.path_gate, gate2=1-opt.path_gate if opt.path_gate else 0,
            dim1=opt.path_dim, dim2=opt.path_dim, scale_dim1=opt.path_scale, scale_dim2=opt.path_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
        self.act = act
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, **kwargs):
        path_vec = kwargs['x_path']
        features = self.fusion(path_vec, path_vec)
        hazard = self.classifier(features)
        if self.act is not None:
            hazard = self.act(hazard)
            if isinstance(self.act, nn.Sigmoid):
                hazard = hazard * self.output_range + self.output_shift
        return features, hazard

    def __hasattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return True
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return True
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return True
        return False


class GraphgraphNet(nn.Module):
    def __init__(self, opt, act, k):
        super(GraphgraphNet, self).__init__()
        self.grph_net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, use_edges=1, pooling_ratio=0.20, label_dim=opt.label_dim, init_max=False)
        if k is not None:
            pt_fname = '_%d.pt' % k
            best_grph_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), map_location=torch.device('cpu'))
            self.grph_net.load_state_dict(best_grph_ckpt['model_state_dict'])
            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname))
        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.grph_gate, gate2=1-opt.grph_gate if opt.grph_gate else 0,
            dim1=opt.grph_dim, dim2=opt.grph_dim, scale_dim1=opt.grph_scale, scale_dim2=opt.grph_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
        self.act = act
        dfs_freeze(self.grph_net)
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, **kwargs):
        grph_vec, _ = self.grph_net(x_grph=kwargs['x_grph'])
        features = self.fusion(grph_vec, grph_vec)
        hazard = self.classifier(features)
        if self.act is not None:
            hazard = self.act(hazard)
            if isinstance(self.act, nn.Sigmoid):
                hazard = hazard * self.output_range + self.output_shift
        return features, hazard

    def __hasattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return True
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return True
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return True
        return False


class OmicomicNet(nn.Module):
    def __init__(self, opt, act, k):
        super(OmicomicNet, self).__init__()
        self.omic_net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=False)
        if k is not None:
            pt_fname = '_%d.pt' % k
            best_omic_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname), map_location=torch.device('cpu'))
            self.omic_net.load_state_dict(best_omic_ckpt['model_state_dict'])
            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname))
        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.omic_gate, gate2=1-opt.omic_gate if opt.omic_gate else 0,
            dim1=opt.omic_dim, dim2=opt.omic_dim, scale_dim1=opt.omic_scale, scale_dim2=opt.omic_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
        self.act = act
        dfs_freeze(self.omic_net)
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, **kwargs):
        omic_vec, _ = self.omic_net(x_omic=kwargs['x_omic'])
        features = self.fusion(omic_vec, omic_vec)
        hazard = self.classifier(features)
        if self.act is not None:
            hazard = self.act(hazard)
            if isinstance(self.act, nn.Sigmoid):
                hazard = hazard * self.output_range + self.output_shift
        return features, hazard

    def __hasattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return True
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return True
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return True
        return False