--- a
+++ b/networks.py
@@ -0,0 +1,747 @@
+# Base / Native
+import csv
+from collections import Counter
+import copy
+import json
+import functools
+import gc
+import logging
+import math
+import os
+import pdb
+import pickle
+import random
+import sys
+import tables
+import time
+from tqdm import tqdm
+
+# Numerical / Array
+import numpy as np
+
+# Torch
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch import Tensor
+from torch.autograd import Variable
+from torch.nn import init, Parameter
+from torch.utils.data import DataLoader
+from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from torchvision import datasets, transforms
+import torch.optim.lr_scheduler as lr_scheduler
+import adabound  # AdaBound optimizer used in define_optimizer (pip install adabound)
+from torch_geometric.nn import GCNConv, SAGEConv, GraphConv, GatedGraphConv, GATConv
+from torch_geometric.nn import TopKPooling, SAGPooling
+from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
+from torch_geometric.transforms.normalize_features import NormalizeFeatures
+
+# Env
+from fusion import *
+from options import parse_args
+from utils import *
+
+
+################
+# Network Utils
+################
+def define_net(opt, k):
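+    """Instantiate the network selected by opt.mode; k indexes the CV fold
+    whose pretrained unimodal checkpoints the multimodal nets load."""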
+    net = None
+    act = define_act_layer(act_type=opt.act_type)
+    init_max = (opt.init_type == "max")
+
+    if opt.mode == "path":
+        net = get_vgg(path_dim=opt.path_dim, act=act, label_dim=opt.label_dim)
+    elif opt.mode == "graph":
+        net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, GNN=opt.GNN, use_edges=opt.use_edges, pooling_ratio=opt.pooling_ratio, act=act, label_dim=opt.label_dim, init_max=init_max)
+    elif opt.mode == "omic":
+        net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=init_max)
+    elif opt.mode == "graphomic":
+        net = GraphomicNet(opt=opt, act=act, k=k)
+    elif opt.mode == "pathomic":
+        net = PathomicNet(opt=opt, act=act, k=k)
+    elif opt.mode == "pathgraphomic":
+        net = PathgraphomicNet(opt=opt, act=act, k=k)
+    elif opt.mode == "pathpath":
+        net = PathpathNet(opt=opt, act=act, k=k)
+    elif opt.mode == "graphgraph":
+        net = GraphgraphNet(opt=opt, act=act, k=k)
+    elif opt.mode == "omicomic":
+        net = OmicomicNet(opt=opt, act=act, k=k)
+    else:
+        raise NotImplementedError('mode [%s] is not implemented' % opt.mode)
+    return init_net(net, opt.init_type, opt.init_gain, opt.gpu_ids)
+
+
+def define_optimizer(opt, model):
+    optimizer = None
+    if opt.optimizer_type == 'adabound':
+        optimizer = adabound.AdaBound(model.parameters(), lr=opt.lr, final_lr=opt.final_lr)
+    elif opt.optimizer_type == 'adam':
+        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
+    elif opt.optimizer_type == 'adagrad':
+        optimizer = torch.optim.Adagrad(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay, initial_accumulator_value=0.1)
+    else:
+        raise NotImplementedError('optimizer [%s] is not implemented' % opt.optimizer_type)
+    return optimizer
+
+
+def define_reg(opt, model):
+    loss_reg = None
+
+    if opt.reg_type == 'none':
+        loss_reg = 0
+    elif opt.reg_type == 'path':
+        loss_reg = regularize_path_weights(model=model)
+    elif opt.reg_type == 'mm':
+        loss_reg = regularize_MM_weights(model=model)
+    elif opt.reg_type == 'all':
+        loss_reg = regularize_weights(model=model)
+    elif opt.reg_type == 'omic':
+        loss_reg = regularize_MM_omic(model=model)
+    else:
+        raise NotImplementedError('reg method [%s] is not implemented' % opt.reg_type)
+    return loss_reg
+
+
+def define_scheduler(opt, optimizer):
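+    """Return the learning-rate scheduler selected by opt.lr_policy."""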
+    if opt.lr_policy == 'linear':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'exp':
+        scheduler = lr_scheduler.ExponentialLR(optimizer, 0.1, last_epoch=-1)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    elif opt.lr_policy == 'cosine':
+        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
+    else:
+        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
+    return scheduler
+
+
+def define_act_layer(act_type='Tanh'):
+    if act_type == 'Tanh':
+        act_layer = nn.Tanh()
+    elif act_type == 'ReLU':
+        act_layer = nn.ReLU()
+    elif act_type == 'Sigmoid':
+        act_layer = nn.Sigmoid()
+    elif act_type == 'LSM':
+        act_layer = nn.LogSoftmax(dim=1)
+    elif act_type == "none":
+        act_layer = None
+    else:
+        raise NotImplementedError('activation layer [%s] is not found' % act_type)
+    return act_layer
+
+
+def define_bifusion(fusion_type, skip=1, use_bilinear=1, gate1=1, gate2=1, dim1=32, dim2=32, scale_dim1=1, scale_dim2=1, mmhid=64, dropout_rate=0.25):
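+    """Build a two-modality fusion module (BilinearFusion from fusion.py)."""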
+    fusion = None
+    if fusion_type == 'pofusion':
+        fusion = BilinearFusion(skip=skip, use_bilinear=use_bilinear, gate1=gate1, gate2=gate2, dim1=dim1, dim2=dim2, scale_dim1=scale_dim1, scale_dim2=scale_dim2, mmhid=mmhid, dropout_rate=dropout_rate)
+    else:
+        raise NotImplementedError('fusion type [%s] is not found' % fusion_type)
+    return fusion
+
+
+def define_trifusion(fusion_type, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25):
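+    """Build a three-modality fusion module (TrilinearFusion_A/B from fusion.py)."""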
+    fusion = None
+    if fusion_type == 'pofusion_A':
+        fusion = TrilinearFusion_A(skip=skip, use_bilinear=use_bilinear, gate1=gate1, gate2=gate2, gate3=gate3, dim1=dim1, dim2=dim2, dim3=dim3, scale_dim1=scale_dim1, scale_dim2=scale_dim2, scale_dim3=scale_dim3, mmhid=mmhid, dropout_rate=dropout_rate)
+    elif fusion_type == 'pofusion_B':
+        fusion = TrilinearFusion_B(skip=skip, use_bilinear=use_bilinear, gate1=gate1, gate2=gate2, gate3=gate3, dim1=dim1, dim2=dim2, dim3=dim3, scale_dim1=scale_dim1, scale_dim2=scale_dim2, scale_dim3=scale_dim3, mmhid=mmhid, dropout_rate=dropout_rate)
+    else:
+        raise NotImplementedError('fusion type [%s] is not found' % fusion_type)
+    return fusion
+
+
+
+############
+# Omic Model
+############
+class MaxNet(nn.Module):
+    def __init__(self, input_dim=80, omic_dim=32, dropout_rate=0.25, act=None, label_dim=1, init_max=True):
+        super(MaxNet, self).__init__()
+        hidden = [64, 48, 32, 32]
+        self.act = act
+
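+        # Four linear blocks with ELU + AlphaDropout (the dropout variant
+        # designed for self-normalizing networks).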
+        encoder1 = nn.Sequential(
+            nn.Linear(input_dim, hidden[0]),
+            nn.ELU(),
+            nn.AlphaDropout(p=dropout_rate, inplace=False))
+        
+        encoder2 = nn.Sequential(
+            nn.Linear(hidden[0], hidden[1]),
+            nn.ELU(),
+            nn.AlphaDropout(p=dropout_rate, inplace=False))
+        
+        encoder3 = nn.Sequential(
+            nn.Linear(hidden[1], hidden[2]),
+            nn.ELU(),
+            nn.AlphaDropout(p=dropout_rate, inplace=False))
+
+        encoder4 = nn.Sequential(
+            nn.Linear(hidden[2], omic_dim),
+            nn.ELU(),
+            nn.AlphaDropout(p=dropout_rate, inplace=False))
+        
+        self.encoder = nn.Sequential(encoder1, encoder2, encoder3, encoder4)
+        self.classifier = nn.Sequential(nn.Linear(omic_dim, label_dim))
+
+        if init_max: init_max_weights(self)
+
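+        # With a Sigmoid head, forward() affinely rescales the (0, 1) output
+        # to (-3, 3) via out * output_range + output_shift.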
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+    def forward(self, **kwargs):
+        x = kwargs['x_omic']
+        features = self.encoder(x)
+        out = self.classifier(features)
+        if self.act is not None:
+            out = self.act(out)
+
+            if isinstance(self.act, nn.Sigmoid):
+                out = out * self.output_range + self.output_shift
+
+        return features, out
+
+
+
+############
+# Graph Model
+############
+class NormalizeFeaturesV2(object):
+    r"""Scales the first 12 node-feature columns by their column-wise maximum
+    and casts node features to CUDA floats."""
+
+    def __call__(self, data):
+        data.x[:, :12] = data.x[:, :12] / data.x[:, :12].max(0, keepdim=True)[0]
+        data.x = data.x.type(torch.cuda.FloatTensor)
+        return data
+
+    def __repr__(self):
+        return '{}()'.format(self.__class__.__name__)
+
+
+class NormalizeEdgesV2(object):
+    r"""Column-normalizes node features to sum-up to one."""
+
+    def __call__(self, data):
+        data.edge_attr = data.edge_attr.type(torch.cuda.FloatTensor)
+        data.edge_attr = data.edge_attr / data.edge_attr.max(0, keepdim=True)[0]
+        return data
+
+    def __repr__(self):
+        return '{}()'.format(self.__class__.__name__)
+
+
+class GraphNet(torch.nn.Module):
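+    r"""Hierarchical GNN: three SAGEConv + SAGPooling blocks whose global
+    max/mean-pooled readouts are summed before the MLP head."""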
+    def __init__(self, features=1036, nhid=128, grph_dim=32, nonlinearity=torch.tanh, 
+        dropout_rate=0.25, GNN='GCN', use_edges=0, pooling_ratio=0.20, act=None, label_dim=1, init_max=True):
+        super(GraphNet, self).__init__()
+
+        self.dropout_rate = dropout_rate
+        self.use_edges = use_edges
+        self.act = act
+
+        self.conv1 = SAGEConv(features, nhid)
+        self.pool1 = SAGPooling(nhid, ratio=pooling_ratio, gnn=GNN)
+        self.conv2 = SAGEConv(nhid, nhid)
+        self.pool2 = SAGPooling(nhid, ratio=pooling_ratio, gnn=GNN)
+        self.conv3 = SAGEConv(nhid, nhid)
+        self.pool3 = SAGPooling(nhid, ratio=pooling_ratio, gnn=GNN)
+
+        self.lin1 = torch.nn.Linear(nhid*2, nhid)
+        self.lin2 = torch.nn.Linear(nhid, grph_dim)
+        self.lin3 = torch.nn.Linear(grph_dim, label_dim)
+
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+        if init_max: 
+            init_max_weights(self)
+            print("Initialzing with Max")
+
+    def forward(self, **kwargs):
+        data = kwargs['x_grph']
+        data = NormalizeFeaturesV2()(data)
+        data = NormalizeEdgesV2()(data)
+        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
+
+        x = F.relu(self.conv1(x, edge_index))
+        x, edge_index, edge_attr, batch, _ = self.pool1(x, edge_index, edge_attr, batch)
+        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
+
+        x = F.relu(self.conv2(x, edge_index))
+        x, edge_index, edge_attr, batch, _ = self.pool2(x, edge_index, edge_attr, batch)
+        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
+
+        x = F.relu(self.conv3(x, edge_index))
+        x, edge_index, edge_attr, batch, _ = self.pool3(x, edge_index, edge_attr, batch)
+        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
+
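+        # Sum the per-scale readouts (hierarchical pooling readout).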
+        x = x1 + x2 + x3 
+
+        x = F.relu(self.lin1(x))
+        x = F.dropout(x, p=self.dropout_rate, training=self.training)
+        features = F.relu(self.lin2(x))
+        out = self.lin3(features)
+        if self.act is not None:
+            out = self.act(out)
+
+            if isinstance(self.act, nn.Sigmoid):
+                out = out * self.output_range + self.output_shift
+
+        return features, out
+
+
+
+############
+# Path Model
+############
+model_urls = {
+    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
+    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
+    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
+    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
+    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
+    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
+    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+}
+
+
+class PathNet(nn.Module):
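+    """Histology CNN: a frozen VGG-style trunk followed by a trainable
+    embedding head and a linear hazard layer."""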
+
+    def __init__(self, features, path_dim=32, act=None, num_classes=1):
+        super(PathNet, self).__init__()
+        self.features = features
+        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+        
+        self.classifier = nn.Sequential(
+            nn.Linear(512 * 7 * 7, 1024),
+            nn.ReLU(True),
+            nn.Dropout(0.25),
+            nn.Linear(1024, 1024),
+            nn.ReLU(True),
+            nn.Dropout(0.25),
+            nn.Linear(1024, path_dim),
+            nn.ReLU(True),
+            nn.Dropout(0.05)
+        )
+
+        self.linear = nn.Linear(path_dim, num_classes)
+        self.act = act
+
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+        dfs_freeze(self.features)
+
+    def forward(self, **kwargs):
+        x = kwargs['x_path']
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        features = self.classifier(x)
+        hazard = self.linear(features)
+
+        if self.act is not None:
+            hazard = self.act(hazard)
+
+            if isinstance(self.act, nn.Sigmoid):
+                hazard = hazard * self.output_range + self.output_shift
+
+        return features, hazard
+
+
+def make_layers(cfg, batch_norm=False):
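+    """Assemble a VGG trunk from a config list: 'M' inserts a 2x2 max-pool,
+    integers add 3x3 convs with that many output channels."""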
+    layers = []
+    in_channels = 3
+    for v in cfg:
+        if v == 'M':
+            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+        else:
+            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
+            if batch_norm:
+                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+            else:
+                layers += [conv2d, nn.ReLU(inplace=True)]
+            in_channels = v
+    return nn.Sequential(*layers)
+
+
+cfgs = {
+    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+
+def get_vgg(arch='vgg19_bn', cfg='E', act=None, batch_norm=True, label_dim=1, pretrained=True, progress=True, **kwargs):
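+    """Build PathNet on a VGG trunk; if pretrained, load ImageNet weights for
+    the convolutional layers (torchvision classifier weights are dropped)."""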
+    model = PathNet(make_layers(cfgs[cfg], batch_norm=batch_norm), act=act, num_classes=label_dim, **kwargs)
+    
+    if pretrained:
+        pretrained_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+
+        for key in list(pretrained_dict.keys()):
+            if 'classifier' in key: pretrained_dict.pop(key)
+
+        model.load_state_dict(pretrained_dict, strict=False)
+        print("Initializing Path Weights")
+
+    return model
+
+
+
+##############################################################################
+# Graph + Omic
+##############################################################################
+class GraphomicNet(nn.Module):
+    def __init__(self, opt, act, k):
+        super(GraphomicNet, self).__init__()
+        self.grph_net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, use_edges=1, pooling_ratio=0.20, label_dim=opt.label_dim, init_max=False)
+        self.omic_net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=False)
+
+        if k is not None:
+            pt_fname = '_%d.pt' % k
+            best_grph_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), map_location=torch.device('cpu'))
+            best_omic_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname), map_location=torch.device('cpu'))
+            self.grph_net.load_state_dict(best_grph_ckpt['model_state_dict'])
+            self.omic_net.load_state_dict(best_omic_ckpt['model_state_dict'])
+            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), "\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname))
+
+        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.grph_gate, gate2=opt.omic_gate, dim1=opt.grph_dim, dim2=opt.omic_dim, scale_dim1=opt.grph_scale, scale_dim2=opt.omic_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
+        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
+        self.act = act
+
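+        # Freeze the pretrained unimodal encoders; only fusion + classifier train.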
+        dfs_freeze(self.grph_net)
+        dfs_freeze(self.omic_net)
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+    def forward(self, **kwargs):
+        grph_vec, _ = self.grph_net(x_grph=kwargs['x_grph'])
+        omic_vec, _ = self.omic_net(x_omic=kwargs['x_omic'])
+        features = self.fusion(grph_vec, omic_vec)
+        hazard = self.classifier(features)
+        if self.act is not None:
+            hazard = self.act(hazard)
+
+            if isinstance(self.act, nn.Sigmoid):
+                hazard = hazard * self.output_range + self.output_shift
+
+        return features, hazard
+
+
+##############################################################################
+# Path + Omic
+##############################################################################
+class PathomicNet(nn.Module):
+    def __init__(self, opt, act, k):
+        super(PathomicNet, self).__init__()
+        self.omic_net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=False)
+
+        if k is not None:
+            pt_fname = '_%d.pt' % k
+            best_omic_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname), map_location=torch.device('cpu'))
+            self.omic_net.load_state_dict(best_omic_ckpt['model_state_dict'])
+            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname))
+
+        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.path_gate, gate2=opt.omic_gate, dim1=opt.path_dim, dim2=opt.omic_dim, scale_dim1=opt.path_scale, scale_dim2=opt.omic_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
+        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
+        self.act = act
+
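+        # Freeze the omic encoder; x_path is expected to be a precomputed
+        # PathNet embedding of size opt.path_dim.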
+        dfs_freeze(self.omic_net)
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+    def forward(self, **kwargs):
+        path_vec = kwargs['x_path']
+        omic_vec, _ = self.omic_net(x_omic=kwargs['x_omic'])
+        features = self.fusion(path_vec, omic_vec)
+        hazard = self.classifier(features)
+        if self.act is not None:
+            hazard = self.act(hazard)
+
+            if isinstance(self.act, nn.Sigmoid):
+                hazard = hazard * self.output_range + self.output_shift
+
+        return features, hazard
+
+
+
+#############################################################################
+# Path + Graph + Omic
+##############################################################################
+class PathgraphomicNet(nn.Module):
+    def __init__(self, opt, act, k):
+        super(PathgraphomicNet, self).__init__()
+        self.grph_net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, use_edges=1, pooling_ratio=0.20, label_dim=opt.label_dim, init_max=False)
+        self.omic_net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=False)
+
+        if k is not None:
+            pt_fname = '_%d.pt' % k
+            best_grph_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), map_location=torch.device('cpu'))
+            best_omic_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname), map_location=torch.device('cpu'))
+            self.grph_net.load_state_dict(best_grph_ckpt['model_state_dict'])
+            self.omic_net.load_state_dict(best_omic_ckpt['model_state_dict'])
+            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), "\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname))
+
+        self.fusion = define_trifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.path_gate, gate2=opt.grph_gate, gate3=opt.omic_gate, dim1=opt.path_dim, dim2=opt.grph_dim, dim3=opt.omic_dim, scale_dim1=opt.path_scale, scale_dim2=opt.grph_scale, scale_dim3=opt.omic_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
+        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
+        self.act = act
+
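+        # Freeze graph and omic encoders; x_path is a precomputed PathNet embedding.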
+        dfs_freeze(self.grph_net)
+        dfs_freeze(self.omic_net)
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+    def forward(self, **kwargs):
+        path_vec = kwargs['x_path']
+        grph_vec, _ = self.grph_net(x_grph=kwargs['x_grph'])
+        omic_vec, _ = self.omic_net(x_omic=kwargs['x_omic'])
+        features = self.fusion(path_vec, grph_vec, omic_vec)
+        hazard = self.classifier(features)
+        if self.act is not None:
+            hazard = self.act(hazard)
+
+            if isinstance(self.act, nn.Sigmoid):
+                hazard = hazard * self.output_range + self.output_shift
+
+        return features, hazard
+
+
+
+##############################################################################
+# Ensembling Effects
+##############################################################################
+class PathgraphNet(nn.Module):
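+    """Bimodal fusion of histology (path) and graph embeddings."""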
+    def __init__(self, opt, act, k):
+        super(PathgraphNet, self).__init__()
+        self.grph_net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, use_edges=1, pooling_ratio=0.20, label_dim=opt.label_dim, init_max=False)
+
+        if k is not None:
+            pt_fname = '_%d.pt' % k
+            best_grph_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), map_location=torch.device('cpu'))
+            self.grph_net.load_state_dict(best_grph_ckpt['model_state_dict'])
+            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname))
+
+        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.path_gate, gate2=opt.grph_gate, dim1=opt.path_dim, dim2=opt.grph_dim, scale_dim1=opt.path_scale, scale_dim2=opt.grph_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
+        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
+        self.act = act
+
+        dfs_freeze(self.grph_net)
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+    def forward(self, **kwargs):
+        path_vec = kwargs['x_path']
+        grph_vec, _ = self.grph_net(x_grph=kwargs['x_grph'])
+        features = self.fusion(path_vec, grph_vec)
+        hazard = self.classifier(features)
+        if self.act is not None:
+            hazard = self.act(hazard)
+
+            if isinstance(self.act, nn.Sigmoid):
+                hazard = hazard * self.output_range + self.output_shift
+
+        return features, hazard
+
+
+class PathpathNet(nn.Module):
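+    """Self-fusion control: fuses the path embedding with itself."""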
+    def __init__(self, opt, act, k):
+        super(PathpathNet, self).__init__()
+        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.path_gate, gate2=(1 - opt.path_gate) if opt.path_gate else 0,
+            dim1=opt.path_dim, dim2=opt.path_dim, scale_dim1=opt.path_scale, scale_dim2=opt.path_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
+        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
+        self.act = act
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+    def forward(self, **kwargs):
+        path_vec = kwargs['x_path']
+        features = self.fusion(path_vec, path_vec)
+        hazard = self.classifier(features)
+        if self.act is not None:
+            hazard = self.act(hazard)
+            if isinstance(self.act, nn.Sigmoid):
+                hazard = hazard * self.output_range + self.output_shift
+        return features, hazard
+
+
+class GraphgraphNet(nn.Module):
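+    """Self-fusion control: fuses the graph embedding with itself."""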
+    def __init__(self, opt, act, k):
+        super(GraphgraphNet, self).__init__()
+        self.grph_net = GraphNet(grph_dim=opt.grph_dim, dropout_rate=opt.dropout_rate, use_edges=1, pooling_ratio=0.20, label_dim=opt.label_dim, init_max=False)
+        if k is not None:
+            pt_fname = '_%d.pt' % k
+            best_grph_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname), map_location=torch.device('cpu'))
+            self.grph_net.load_state_dict(best_grph_ckpt['model_state_dict'])
+            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'graph', 'graph'+pt_fname))
+        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.grph_gate, gate2=(1 - opt.grph_gate) if opt.grph_gate else 0,
+            dim1=opt.grph_dim, dim2=opt.grph_dim, scale_dim1=opt.grph_scale, scale_dim2=opt.grph_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
+        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
+        self.act = act
+        dfs_freeze(self.grph_net)
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+    def forward(self, **kwargs):
+        grph_vec, _ = self.grph_net(x_grph=kwargs['x_grph'])
+        features = self.fusion(grph_vec, grph_vec)
+        hazard = self.classifier(features)
+        if self.act is not None:
+            hazard = self.act(hazard)
+            if isinstance(self.act, nn.Sigmoid):
+                hazard = hazard * self.output_range + self.output_shift
+        return features, hazard
+
+
+class OmicomicNet(nn.Module):
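+    """Self-fusion control: fuses the omic embedding with itself."""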
+    def __init__(self, opt, act, k):
+        super(OmicomicNet, self).__init__()
+        self.omic_net = MaxNet(input_dim=opt.input_size_omic, omic_dim=opt.omic_dim, dropout_rate=opt.dropout_rate, act=act, label_dim=opt.label_dim, init_max=False)
+        if k is not None:
+            pt_fname = '_%d.pt' % k
+            best_omic_ckpt = torch.load(os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname), map_location=torch.device('cpu'))
+            self.omic_net.load_state_dict(best_omic_ckpt['model_state_dict'])
+            print("Loading Models:\n", os.path.join(opt.checkpoints_dir, opt.exp_name, 'omic', 'omic'+pt_fname))
+        self.fusion = define_bifusion(fusion_type=opt.fusion_type, skip=opt.skip, use_bilinear=opt.use_bilinear, gate1=opt.omic_gate, gate2=(1 - opt.omic_gate) if opt.omic_gate else 0,
+            dim1=opt.omic_dim, dim2=opt.omic_dim, scale_dim1=opt.omic_scale, scale_dim2=opt.omic_scale, mmhid=opt.mmhid, dropout_rate=opt.dropout_rate)
+        self.classifier = nn.Sequential(nn.Linear(opt.mmhid, opt.label_dim))
+        self.act = act
+        dfs_freeze(self.omic_net)
+        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
+        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)
+
+    def forward(self, **kwargs):
+        omic_vec, _ = self.omic_net(x_omic=kwargs['x_omic'])
+        features = self.fusion(omic_vec, omic_vec)
+        hazard = self.classifier(features)
+        if self.act is not None:
+            hazard = self.act(hazard)
+            if isinstance(self.act, nn.Sigmoid):
+                hazard = hazard * self.output_range + self.output_shift
+        return features, hazard