In [None]:
import socket
# On the 'dlm' host only, restrict CUDA to a single physical GPU:
# PCI_BUS_ID makes device indices follow PCI bus order, and
# CUDA_VISIBLE_DEVICES=3 then exposes only that device to this kernel.
# CUDA reads these variables when it initializes, so this cell must run
# before anything touches the GPU.
if socket.gethostname() == 'dlm':
  %env CUDA_DEVICE_ORDER=PCI_BUS_ID
  %env CUDA_VISIBLE_DEVICES=3

In [None]:
import os
import sys
import re
import collections
import functools
import requests, zipfile, io
import pickle
import copy

import pandas
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import sklearn
import sklearn.decomposition
import sklearn.metrics
import networkx

import torch
import torch.nn as nn

# Locate the shared `dl` / `autoencoder` library on whichever machine this runs.
# The first existing candidate wins; if neither of the first two exists, lib_path
# falls back to the cluster path, and it is only added to sys.path if it exists.
lib_path = next(
    (candidate for candidate in ('I:/code', '/media/6T/.tianle/.lib')
     if os.path.exists(candidate)),
    '/projects/academic/azhang/tianlema/lib')
if lib_path not in sys.path and os.path.exists(lib_path):
  sys.path.append(lib_path)
  
from dl.models.basic_models import *
from dl.utils.visualization.visualization import *
from dl.utils.outlier import *
from dl.utils.train import *
from autoencoder.autoencoder import *
from dl.utils.utils import get_overlap_samples, filter_clinical_dict, get_target_variable
from dl.utils.utils import get_shuffled_data, target_to_numpy

%load_ext autoreload
%autoreload 2


# Run on the GPU when it is both requested and available; otherwise use the CPU.
use_gpu = True
device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
print('Using GPU:)' if device.type == 'cuda' else 'Using CPU:(')

In [None]:
# ---- Model architecture ----------------------------------------------------------
# Neural-net model types: 'nn' (mlp), 'resnet', 'densenet'; 'ml' selects classical
# machine learning instead. model_type, dense and residual must agree with each other
# (here 'resnet' is paired with dense=False, residual=True).
model_type, dense, residual = 'resnet', False, True
hidden_dim = [100, 100]

# ---- Data splitting --------------------------------------------------------------
train_portion, val_portion, test_portion = 0.7, 0.1, 0.2
# Number of disease types per split for group-based splitting; -1 means "not used".
# The test-type count is almost never used.
num_train_types, num_val_types, num_test_types = -1, -1, -1
num_sets = 10        # number of random data splits generated / stored
num_folds = 10       # no longer used
sel_set_idx = 0      # which of the num_sets splits this run uses
cv_type = 'instance-shuffle'  # split shuffle method: 'instance-shuffle' or 'group-shuffle'
sel_disease_types = 'all'
# Keep only disease types with at least min_num_samples_per_type_cls[0] samples in
# total and at least min_num_samples_per_type_cls[1] samples in every class.
min_num_samples_per_type_cls = [100, 0]
# How the data split is obtained:
#   'auto-search'  - read the split file named after the current settings if it exists;
#                    otherwise generate random splits and write them to that file.
#   other string   - treat it as the name of a split file and read it.
#   False          - generate random splits without writing them.
#   True           - generate random splits and write them to file.
predefined_sample_set_file = 'auto-search'

# ---- Prediction target -----------------------------------------------------------
target_variable = 'PFI'             # TODO: a list of targets is only partially handled
target_variable_type = 'discrete'   # or 'continuous' (real numbers)
target_variable_range = [0, 1]

# ---- Input features --------------------------------------------------------------
data_type = ['gene', 'methy', 'rppa', 'mirna']
normal_transform_feature = True
# Extra clinical covariates, e.g. ['age_at_initial_pathologic_diagnosis', 'gender']
# with types ['continuous', 'discrete'] and ranges [[0, 100], ['MALE', 'FEMALE']].
additional_vars = []
additional_var_types = []
additional_var_ranges = []
randomize_labels = False

# ---- Optimization ----------------------------------------------------------------
lr, weight_decay = 5e-4, 1e-4
num_epochs, reduce_every = 1000, 500
show_results_in_notebook = True

## Prepare data

In [None]:
# Output folders (results and stored data splits) and the processed input data.
result_folder = 'results'
data_split_idx_folder = f'{result_folder}/data_split_idx'
print_stats = True
# Try the relative checkout location first, then the Windows data drive.
project_folder = '../../pan-can-atlas'
if not os.path.exists(project_folder):
  project_folder = 'F:/TCGA/Pan-Cancer-Atlas'
filepath = f'{project_folder}/data/processed/combined2.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only load this trusted local file.
with open(filepath, 'rb') as f:
  data = pickle.load(f)
(patient_clinical, feature_mat_dict, feature_interaction_mat_dict,
 feature_id_dict, aliquot_id_dict) = (data[key] for key in (
    'patient_clinical', 'feature_mat_dict', 'feature_interaction_mat_dict',
    'feature_id_dict', 'aliquot_id_dict'))

# Quick sanity summary of every matrix (max / min / mean / fraction of positive entries)
# and of the id arrays (shape and first entry).
if print_stats:
  for title, mats in (('feature_mat', feature_mat_dict),
                      ('feature_interaction_mat', feature_interaction_mat_dict)):
    for k, v in mats.items():
      print(f'{title}: {k}, max={v.max():.3f}, min={v.min():.3f}, '
            f'mean={v.mean():.3f}, {np.mean(v>0):.3f}')
  for ids in (feature_id_dict, aliquot_id_dict):
    for k, v in ids.items():
      print(k, v.shape, v[0])

In [None]:
# Keep patients whose clinical record passes the filter on the target variable
# (and on any additional variables).
clinical_dict = filter_clinical_dict(target_variable, target_variable_type=target_variable_type,
                                     target_variable_range=target_variable_range,
                                     clinical_dict=patient_clinical)
if len(additional_vars) > 0:
  clinical_dict = filter_clinical_dict(additional_vars, target_variable_type=additional_var_types,
                                       target_variable_range=additional_var_ranges,
                                       clinical_dict=clinical_dict)

# Keep patients that have a feature matrix for every requested data type;
# a patient id is the first 12 characters of an aliquot id.
if isinstance(data_type, str):
  sample_list = {s[:12] for s in aliquot_id_dict[data_type]}
  data_type_str = data_type
elif isinstance(data_type, (list, tuple)):
  sample_list = get_overlap_samples([aliquot_id_dict[dtype] for dtype in data_type],
                                    common_list=None, start=0, end=12, return_common_list=True)
  data_type_str = '-'.join(sorted(data_type))
else:
  raise ValueError(f'data_type must be str or list/tuple, but is {type(data_type)}')
# set() makes this work whether get_overlap_samples returns a set or a list.
sample_list = set(sample_list).intersection(clinical_dict)

# Keep patients of the selected disease type(s).
sel_disease_type_str = sel_disease_types  # overwritten below if it is a list
if isinstance(sel_disease_types, (list, tuple)):
  sample_list = [s for s in sample_list if clinical_dict[s]['type'] in sel_disease_types]
  sel_disease_type_str = '-'.join(sorted(sel_disease_types))
elif isinstance(sel_disease_types, str) and sel_disease_types != 'all':
  sample_list = [s for s in sample_list if clinical_dict[s]['type'] == sel_disease_types]
else:
  assert sel_disease_types == 'all'

# For classification, keep only disease types with at least min_num_samples_per_type_cls[0]
# samples in total and at least min_num_samples_per_type_cls[1] samples in every class.
type_cnt = collections.Counter(clinical_dict[s]['type'] for s in sample_list)
if sum(min_num_samples_per_type_cls) > 0 and (target_variable_type == 'discrete'
                                              or target_variable_type[0] == 'discrete'):
  min_per_type = min_num_samples_per_type_cls[0]
  min_per_cls = min_num_samples_per_type_cls[1]
  # With a list of target variables, the first one defines the classes.
  cls_variable = target_variable if isinstance(target_variable, str) else target_variable[0]
  # Check counts against every class seen among the candidate samples, so a disease type
  # that lacks a class entirely (0 samples of it) fails any positive per-class minimum.
  all_classes = {clinical_dict[s][cls_variable] for s in sample_list}
  type_cnt = {k: v for k, v in type_cnt.items() if v >= min_per_type}
  disease_type_cnt = {}
  for k in type_cnt:
    cls_cnt = collections.Counter(clinical_dict[s][cls_variable]
                                  for s in sample_list if clinical_dict[s]['type'] == k)
    if all(cls_cnt.get(c, 0) >= min_per_cls for c in all_classes):
      disease_type_cnt[k] = dict(cls_cnt)
      print(k, disease_type_cnt[k])
  sample_list = [s for s in sample_list if clinical_dict[s]['type'] in disease_type_cnt]
sel_patient_ids = sorted(sample_list)
# Count disease types from the selected patients themselves: disease_type_cnt only exists
# when the filter above ran (discrete target with a positive minimum).
num_sel_disease_types = len({clinical_dict[s]['type'] for s in sel_patient_ids})
print(f'Selected {len(sel_patient_ids)} patients from {num_sel_disease_types} disease_types')

### Split data into training, validation, and test sets

In [None]:
# File name for the stored data splits. It encodes every setting that affects the split,
# so runs with different settings never share a split file.
predefined_sample_set_filename = (target_variable if isinstance(target_variable, str)
                                  else '-'.join(target_variable))
predefined_sample_set_filename += f'_{cv_type}'
if len(additional_vars) > 0:
  predefined_sample_set_filename += f"_{'-'.join(sorted(additional_vars))}"

predefined_sample_set_filename += (f"_{data_type_str}_{sel_disease_type_str}_"
                                   f"{'-'.join(map(str, min_num_samples_per_type_cls))}")
predefined_sample_set_filename += f"_{'-'.join(map(str, [train_portion, val_portion, test_portion]))}"
if cv_type == 'group-shuffle' and num_train_types > 0:
  predefined_sample_set_filename += f"_{'-'.join(map(str, [num_train_types, num_val_types, num_test_types]))}"
predefined_sample_set_filename += f'_{num_sets}sets'
# Result file name: split settings + selected split index + hidden dims + model type.
res_file = f"{predefined_sample_set_filename}_{sel_set_idx}_{'-'.join(map(str, hidden_dim))}_{model_type}.pkl"
predefined_sample_set_filename += '.pkl'
predefined_sample_set_path = f'{data_split_idx_folder}/{predefined_sample_set_filename}'

# Decide whether the splits are read from a file:
#   'auto-search' -> the settings-derived file, if it exists;
#   other string  -> that file name inside data_split_idx_folder (must exist);
#   True / False  -> never read; splits are generated.
split_read_path = None
if predefined_sample_set_file == 'auto-search':
  if os.path.exists(predefined_sample_set_path):
    split_read_path = predefined_sample_set_path
elif isinstance(predefined_sample_set_file, str):
  # Check the existence of exactly the path that will be opened.
  split_read_path = f'{data_split_idx_folder}/{predefined_sample_set_file}'
  if not os.path.exists(split_read_path):
    raise ValueError(f'predefined_sample_set_file: {split_read_path} does not exist!')

if split_read_path is not None:
  # NOTE(review): pickle.load can execute arbitrary code; only read trusted split files.
  with open(split_read_path, 'rb') as f:
    print(f'Read predefined_sample_set_file: {split_read_path}')
    predefined_sample_sets = pickle.load(f)['predefined_sample_sets']
else:
  # Random splits are generated only when they are not read from a file.
  predefined_sample_sets = [
      get_shuffled_data(sel_patient_ids, clinical_dict, cv_type=cv_type,
                        instance_portions=[train_portion, val_portion, test_portion],
                        group_sizes=[num_train_types, num_val_types, num_test_types],
                        group_variable_name='type', seed=None, verbose=False)
      for _ in range(num_sets)]
  # 'auto-search' caches newly generated splits, True always (re)writes them, False never does.
  if predefined_sample_set_file == 'auto-search' or predefined_sample_set_file is True:
    os.makedirs(data_split_idx_folder, exist_ok=True)
    with open(predefined_sample_set_path, 'wb') as f:
      print(f'Write predefined_sample_set_file: {predefined_sample_set_path}')
      pickle.dump({'predefined_sample_sets': predefined_sample_sets}, f)

sel_patient_ids, idx_splits = predefined_sample_sets[sel_set_idx]
train_idx, val_idx, test_idx = idx_splits

In [None]:
# For each data type, find the row indices of the selected patients in that type's
# aliquot list (patients are matched on the first 12 characters of the aliquot id).
assert isinstance(data_type, str) or isinstance(data_type, (list, tuple))
dtype_names = [data_type] if isinstance(data_type, str) else list(data_type)
sample_lists = [aliquot_id_dict[name] for name in dtype_names]
idx_lists = get_overlap_samples(sample_lists=sample_lists, common_list=sel_patient_ids,
                                start=0, end=12, return_common_list=False)
sample_idx_sel_dict = dict(zip(dtype_names, idx_lists))

In [None]:
def _standardize_rows(a):
  """Scale each row of a 2-D array to mean 0 and standard deviation 1.

  Constant rows (standard deviation 0) are only centered rather than divided by zero,
  which would otherwise fill them with NaN and poison training.
  """
  row_sd = a.std(axis=1, keepdims=True)
  return (a - a.mean(axis=1, keepdims=True)) / np.where(row_sd > 0, row_sd, 1)


def _load_interaction_mat(dtype):
  """Return the feature-interaction matrix of `dtype` as a float tensor on `device`,
  divided by its (Frobenius) norm."""
  m = torch.from_numpy(feature_interaction_mat_dict[dtype]).float().to(device)
  return m / m.norm()


if isinstance(data_type, str):
  print(f'Only use one data type: {data_type}')
  num_data_types = 1
  mat = feature_mat_dict[data_type][sample_idx_sel_dict[data_type]]
  # Data preprocessing: make each row have mean 0 and sd 1.
  x = _standardize_rows(mat)
  interaction_mat = _load_interaction_mat(data_type)
else:
  mat = []
  interaction_mats = []
  in_dims = []
  num_data_types = len(data_type)
  # The special case of a one-element list is deliberately not handled.
  assert num_data_types > 1
  for dtype in data_type:
    # Normalize each view separately so no single data type dominates the concatenation.
    m = _standardize_rows(feature_mat_dict[dtype][sample_idx_sel_dict[dtype]])
    mat.append(m)
    in_dims.append(m.shape[1])
    # Used by the graph-Laplacian regularizer of the neural network models.
    interaction_mat = _load_interaction_mat(dtype)
    interaction_mats.append(interaction_mat)
    # '.3g' keeps significant digits; after norm scaling the entries can be very small.
    print(f'{dtype}: {m.shape}; '
          f'interaction_mat: mean={interaction_mat.mean().item():.3g}, '
          f'std={interaction_mat.std().item():.3g}, {interaction_mat.shape[0]}')
  # One interaction matrix per view; later passed to Loss_feature_interaction.
  interaction_mat = interaction_mats
  mat = np.concatenate(mat, axis=1)
  # Machine-learning methods see the concatenated features without knowing the views,
  # so re-normalize each row of the concatenation as well.
  x = _standardize_rows(mat)

X = x if normal_transform_feature else mat

In [None]:
# Target labels for the selected patients, converted to a numpy array.
y_targets = get_target_variable(target_variable, clinical_dict, sel_patient_ids)
y_true = target_to_numpy(y_targets, target_variable_type, target_variable_range)
# Additional clinical covariates (e.g. age, gender) are fetched but not used yet.
if additional_vars:
  additional_variables = get_target_variable(additional_vars, clinical_dict, sel_patient_ids)
  # TODO: feed additional_variables into the models

### TODO: handle multiple inputs and multiple targets

In [None]:
# sklearn classifiers also accept torch.Tensor
X = torch.tensor(X).float().to(device)
y_true = torch.tensor(y_true).long().to(device)
num_cls = len(torch.unique(y_true))

x_train, y_train = X[train_idx], y_true[train_idx]
x_val, y_val = X[val_idx], y_true[val_idx]
x_test, y_test = X[test_idx], y_true[test_idx]
print(x_train.shape, x_val.shape, x_test.shape, y_train.shape, y_val.shape, y_test.shape)

label_prob_train = get_label_prob(y_train, verbose=False)
label_probs = [label_prob_train]
if len(y_val)>0:
  label_prob_val = get_label_prob(y_val, verbose=False)
  assert len(label_prob_train) == len(label_prob_val)
  label_probs.append(label_prob_val)
if len(y_test)>0:
  label_prob_test = get_label_prob(y_test, verbose=False)
  assert len(label_prob_train) == len(label_prob_test)
  label_probs.append(label_prob_test)
if isinstance(label_probs, torch.Tensor):
  print('label distribution:\n', torch.stack(label_probs, dim=1))
else:
  print('label distribution:\n', np.stack(label_probs, axis=1))

### Optionally randomize true class labels

In [None]:
if randomize_labels:
  # Sanity check: redraw every label independently from its split's label distribution
  # (label_prob_*), so the features carry no information about the labels.
  print('Randomize class labels!')
  y_train = torch.multinomial(label_prob_train, len(y_train), replacement=True)
  y_val = (torch.multinomial(label_prob_val, len(y_val), replacement=True)
           if len(y_val) > 0 else y_val)
  y_test = (torch.multinomial(label_prob_test, len(y_test), replacement=True)
            if len(y_test) > 0 else y_test)

## Neural network models

In [None]:
# Classification loss and reconstruction (MSE) loss for the two model heads.
loss_fn_cls = torch.nn.CrossEntropyLoss()
loss_fn_reg = torch.nn.MSELoss()
loss_fns = [loss_fn_cls, loss_fn_reg]

# Graph-Laplacian regularizer on the decoder weights, built from the feature interaction
# matrix (one matrix per view when there are several data types).
feat_interact_loss_type = 'graph_laplacian'
weight_path = (['decoders', range(num_data_types), 'weight'] if num_data_types > 1
               else ['decoder', 'weight'])
loss_feat_interact = Loss_feature_interaction(interaction_mat=interaction_mat,
                                              weight_path=weight_path,
                                              loss_type=feat_interact_loss_type,
                                              normalize=True)
other_loss_fns = [loss_feat_interact]

if num_data_types > 1:
  # View-similarity loss across the per-view encodings; all view encoders share the same
  # hidden_dim in these experiments, so each section has hidden_dim[-1] units.
  view_sim_loss_type, explicit_target, cal_target = 'hub', True, 'mean-feature'
  loss_view_sim = Loss_view_similarity(sections=hidden_dim[-1], loss_type=view_sim_loss_type,
                                       explicit_target=explicit_target, cal_target=cal_target,
                                       target=None)
  loss_fns.append(loss_view_sim)

In [None]:
split_names = ['train', 'val', 'test']
metric_names = ['acc', 'precision', 'recall', 'f1_score', 'adjusted_mutual_info', 'auc',
                'average_precision']
# Per-model accumulators; each model run appends one entry to every list.
model_names, metric_all, confusion_mat_all = [], [], []
loss_his_all, acc_his_all = [], []

In [None]:
def get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, 
               x_test, y_test, batch_size, multi_heads, show_results_in_notebook=True, 
               loss_idx=0, acc_idx=0, loss_his=None, acc_his=None):
  """Evaluate a trained classifier on all splits and optionally display diagnostics.

  Args:
    model: the model after the last training epoch.
    best_model: the model selected by the trainer on the validation split.
    best_val_acc, best_epoch: validation accuracy / epoch of best_model (only printed,
      and only when a validation split exists).
    x_train, y_train, x_val, y_val, x_test, y_test: the data splits; the validation split
      may be empty.
    batch_size: evaluation batch size.
    multi_heads: whether the models have several output heads (the scatter plot then
      uses 2 heads).
    show_results_in_notebook: if True, also evaluate `model` and plot the loss/accuracy
      histories and the data of `best_model`.
    loss_idx, acc_idx: forwarded as `idx` to plot_history_multi_splits.
    loss_his, acc_his: optional [train, val, test] history lists to plot. When None, the
      module-level loss_*_his / acc_*_his lists are plotted.

  Returns:
    The result of eval_classification_multi_splits for best_model, one element per split;
    callers read element[0] as the metrics and element[1] as the confusion matrix.
  """
  xs = [x_train, x_val, x_test]
  ys = [y_train, y_val, y_test]
  if len(x_val) > 0:
    print(f'Best model on validation set: best_val_acc={best_val_acc:.2f}, epoch={best_epoch}')
  # Computed unconditionally so metrics are returned even without a validation split
  # and when nothing is displayed.
  metric = eval_classification_multi_splits(best_model, xs=xs, ys=ys, batch_size=batch_size,
                                            multi_heads=multi_heads)

  if show_results_in_notebook:
    if loss_his is None:
      loss_his = [loss_train_his, loss_val_his, loss_test_his]
    if acc_his is None:
      acc_his = [acc_train_his, acc_val_his, acc_test_his]
    print('\nModel after the last training epoch:')
    eval_classification_multi_splits(model, xs=xs, ys=ys, batch_size=batch_size,
                                     multi_heads=multi_heads, return_result=False)
    plot_history_multi_splits(loss_his, title='Loss', idx=loss_idx)
    plot_history_multi_splits(acc_his, title='Acc', idx=acc_idx)
    # scatter plot
    plot_data_multi_splits(best_model, xs, ys, num_heads=2 if multi_heads else 1,
                           titles=['Training', 'Validation', 'Test'], batch_size=batch_size)
  return metric

## Plain deep learning model

In [None]:
# Training-loop settings shared by every model trained below.
batch_size, print_every, eval_every = 1000, 100, 1

In [None]:
in_dim = x_train.shape[1]
print('Plain deep learning model')
model_names.append('NN')
# Single-head network; the last layer has one output per class.
model = DenseLinear(in_dim, hidden_dim + [num_cls], dense=dense, residual=residual).to(device)
multi_heads = False

# Fresh history lists for this model (passed to the trainer and recorded afterwards).
loss_train_his, loss_val_his, loss_test_his = [], [], []
acc_train_his, acc_val_his, acc_test_his = [], [], []
best_model, best_val_acc, best_epoch = model, 0, 0

In [None]:
# Train with the classification loss only. With return_best_val=True the trainer also returns
# the model chosen on the validation split, plus its validation accuracy and epoch.
best_model, best_val_acc, best_epoch = train_single_loss(
    model, x_train, y_train, x_val, y_val, x_test, y_test,
    loss_fn=loss_fn_cls, return_best_val=True, verbose=False,
    lr=lr, weight_decay=weight_decay, amsgrad=True,
    batch_size=batch_size, num_epochs=num_epochs, reduce_every=reduce_every,
    eval_every=eval_every, print_every=print_every,
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his,
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his)

In [None]:
# Evaluate the plain network on all splits (and plot its curves when enabled).
metric = get_result(model, best_model, best_val_acc, best_epoch,
                    x_train, y_train, x_val, y_val, x_test, y_test,
                    batch_size, multi_heads, show_results_in_notebook,
                    acc_idx=0, loss_idx=0)

In [None]:
# Record this model's histories and per-split results; each element of `metric` is read
# as (metrics, confusion matrix).
loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])
acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])
metric_all.append([split_result[0] for split_result in metric])
confusion_mat_all.append([split_result[1] for split_result in metric])

## Factorization AutoEncoder

In [None]:
def run_one_model(model, loss_weights, other_loss_weights, 
                  loss_his_all=None, acc_his_all=None, metric_all=None, confusion_mat_all=None,
                  heads=[0,1], multi_heads=True, return_results=False, 
                  loss_fns=loss_fns, other_loss_fns=other_loss_fns, 
                  lr=lr, weight_decay=weight_decay, batch_size=batch_size, 
                  num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, 
                  print_every=print_every, x_train=x_train, y_train=y_train,
                  x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,
                  show_results_in_notebook=show_results_in_notebook):
  """Train a multi-loss model, evaluate/plot it, and record its results.

  The model is trained with train_multiloss on targets [y, x] for every split, i.e. the class
  labels plus the input itself (reconstruction). Most defaults are captured from the notebook
  state when this cell runs, so pass them explicitly if those variables changed afterwards.

  Args:
    model: the network to train.
    loss_weights, other_loss_weights: weights for `loss_fns` / `other_loss_fns`, forwarded to
      train_multiloss.
    loss_his_all, acc_his_all, metric_all, confusion_mat_all: lists to which this run's
      histories, per-split metrics and per-split confusion matrices are appended; a new list
      is created for each one left as None.
    heads: model heads trained by train_multiloss (not mutated here).
    multi_heads: whether the model has several output heads (forwarded to evaluation/plots).
    return_results: if True, return the four accumulator lists.
    Remaining arguments: optimizer/training settings and data splits, forwarded unchanged.

  Returns:
    (loss_his_all, acc_his_all, metric_all, confusion_mat_all) if return_results, else None.
  """
  # None defaults avoid the shared-mutable-default pitfall across calls.
  if loss_his_all is None:
    loss_his_all = []
  if acc_his_all is None:
    acc_his_all = []
  if metric_all is None:
    metric_all = []
  if confusion_mat_all is None:
    confusion_mat_all = []

  loss_train_his, loss_val_his, loss_test_his = [], [], []
  acc_train_his, acc_val_his, acc_test_his = [], [], []

  best_model, best_val_acc, best_epoch = train_multiloss(model, x_train, [y_train, x_train], 
    x_val, [y_val, x_val], x_test, [y_test, x_test], heads=heads, loss_fns=loss_fns, 
    loss_weights=loss_weights, other_loss_fns=other_loss_fns, 
    other_loss_weights=other_loss_weights, 
    lr=lr, weight_decay=weight_decay, batch_size=batch_size, num_epochs=num_epochs, 
    reduce_every=reduce_every, eval_every=eval_every, print_every=print_every,
    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, 
    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, 
    return_best_val=True, amsgrad=True, verbose=False)

  loss_his = [loss_train_his, loss_val_his, loss_test_his]
  acc_his = [acc_train_his, acc_val_his, acc_test_his]
  xs = [x_train, x_val, x_test]
  ys = [y_train, y_val, y_test]

  # Evaluate and plot this run's own histories here; get_result plots the module-level
  # history lists, which belong to a different model.
  if len(x_val) > 0:
    print(f'Best model on validation set: best_val_acc={best_val_acc:.2f}, epoch={best_epoch}')
  metric = eval_classification_multi_splits(best_model, xs=xs, ys=ys, batch_size=batch_size,
                                            multi_heads=multi_heads)
  if show_results_in_notebook:
    print('\nModel after the last training epoch:')
    eval_classification_multi_splits(model, xs=xs, ys=ys, batch_size=batch_size,
                                     multi_heads=multi_heads, return_result=False)
    plot_history_multi_splits(loss_his, title='Loss', idx=0)
    plot_history_multi_splits(acc_his, title='Acc', idx=0)
    plot_data_multi_splits(best_model, xs, ys, num_heads=2 if multi_heads else 1,
                           titles=['Training', 'Validation', 'Test'], batch_size=batch_size)

  loss_his_all.append(loss_his)
  acc_his_all.append(acc_his)
  metric_all.append([v[0] for v in metric])
  confusion_mat_all.append([v[1] for v in metric])

  if return_results:
    return loss_his_all, acc_his_all, metric_all, confusion_mat_all

In [None]:
# Baseline autoencoder with classification + reconstruction heads; the feature-interaction
# loss is present in other_loss_fns but weighted 0.
decoder_norm, uniform_decoder_norm = False, False
print('Plain AutoEncoder model')
model_names.append('AE')
model = AutoEncoder(in_dim, hidden_dim, num_cls, residual=residual, dense=dense,
                    uniform_decoder_norm=uniform_decoder_norm,
                    decoder_norm=decoder_norm).to(device)
loss_weights, other_loss_weights = [1, 1], [0]
# heads=None should also work for the runs below; the explicit list is kept for clarity.
heads = [0, 1]
run_one_model(model=model, loss_weights=loss_weights, other_loss_weights=other_loss_weights,
              loss_his_all=loss_his_all, acc_his_all=acc_his_all, metric_all=metric_all,
              confusion_mat_all=confusion_mat_all,
              heads=heads, multi_heads=True, return_results=False,
              loss_fns=loss_fns, other_loss_fns=other_loss_fns,
              lr=lr, weight_decay=weight_decay, batch_size=batch_size,
              num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every,
              print_every=print_every,
              x_train=x_train, y_train=y_train, x_val=x_val, y_val=y_val,
              x_test=x_test, y_test=y_test,
              show_results_in_notebook=show_results_in_notebook)

## Add feature interaction network regularizer

In [None]:
if num_data_types <= 1:
  print('AutoEncoder with feature interaction network regularizer')
  model_names.append('AE + feat_int')
  model = AutoEncoder(in_dim, hidden_dim, num_cls, residual=residual, dense=dense,
                      uniform_decoder_norm=uniform_decoder_norm,
                      decoder_norm=decoder_norm).to(device)
else:
  # View encodings are fused by summation in every MultiviewAE run.
  fuse_type = 'sum'
  print('MultiviewAE with feature interaction network regularizer')
  model_names.append('MultiviewAE + feat_int')
  model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls,
                      fuse_type=fuse_type, residual=residual, dense=dense,
                      residual_layers='all', decoder_norm=decoder_norm,
                      uniform_decoder_norm=uniform_decoder_norm, decoder_norm_dim=0,
                      nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)

# Same two heads as the plain AutoEncoder, now with the feature-interaction loss switched on.
loss_weights, other_loss_weights, heads = [1, 1], [1], [0, 1]
run_one_model(model=model, loss_weights=loss_weights, other_loss_weights=other_loss_weights,
              loss_his_all=loss_his_all, acc_his_all=acc_his_all, metric_all=metric_all,
              confusion_mat_all=confusion_mat_all,
              heads=heads, multi_heads=True, return_results=False,
              loss_fns=loss_fns, other_loss_fns=other_loss_fns,
              lr=lr, weight_decay=weight_decay, batch_size=batch_size,
              num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every,
              print_every=print_every,
              x_train=x_train, y_train=y_train, x_val=x_val, y_val=y_val,
              x_test=x_test, y_test=y_test,
              show_results_in_notebook=show_results_in_notebook)

## For multi-view data, add view similarity network regularizer

In [None]:
if num_data_types > 1:
  # Plain MultiviewAE (no regularizers). Comparing it with the plain AutoEncoder shows whether
  # keeping the views separate in the lower layers beats mixing them all the way through.
  print('Run plain MultiviewAE model')
  model_names.append('MultiviewAE')
  model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls,
                      fuse_type=fuse_type, residual=residual, dense=dense,
                      residual_layers='all', decoder_norm=decoder_norm,
                      uniform_decoder_norm=uniform_decoder_norm, decoder_norm_dim=0,
                      nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)

  loss_weights, other_loss_weights, heads = [1, 1], [0], [0, 1]
  run_one_model(model=model, loss_weights=loss_weights, other_loss_weights=other_loss_weights,
                loss_his_all=loss_his_all, acc_his_all=acc_his_all, metric_all=metric_all,
                confusion_mat_all=confusion_mat_all,
                heads=heads, multi_heads=True, return_results=False,
                loss_fns=loss_fns, other_loss_fns=other_loss_fns,
                lr=lr, weight_decay=weight_decay, batch_size=batch_size,
                num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every,
                print_every=print_every,
                x_train=x_train, y_train=y_train, x_val=x_val, y_val=y_val,
                x_test=x_test, y_test=y_test,
                show_results_in_notebook=show_results_in_notebook)

In [None]:
if num_data_types > 1:
  # Adds the third head (view-similarity loss); the feature-interaction loss stays off.
  print('MultiviewAE with view similarity regularizers')
  model_names.append('MultiviewAE + view_sim')
  model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls,
                      fuse_type=fuse_type, residual=residual, dense=dense,
                      residual_layers='all', decoder_norm=decoder_norm,
                      uniform_decoder_norm=uniform_decoder_norm, decoder_norm_dim=0,
                      nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)
  loss_weights, other_loss_weights, heads = [1, 1, 1], [0], [0, 1, 2]
  run_one_model(model=model, loss_weights=loss_weights, other_loss_weights=other_loss_weights,
                loss_his_all=loss_his_all, acc_his_all=acc_his_all, metric_all=metric_all,
                confusion_mat_all=confusion_mat_all,
                heads=heads, multi_heads=True, return_results=False,
                loss_fns=loss_fns, other_loss_fns=other_loss_fns,
                lr=lr, weight_decay=weight_decay, batch_size=batch_size,
                num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every,
                print_every=print_every,
                x_train=x_train, y_train=y_train, x_val=x_val, y_val=y_val,
                x_test=x_test, y_test=y_test,
                show_results_in_notebook=show_results_in_notebook)

In [None]:
if num_data_types > 1:
  # Full model: three heads (incl. view similarity) plus the feature-interaction loss.
  print('MultiviewAE with both feature interaction and view similarity regularizers')
  model_names.append('MultiviewAE + feat_int + view_sim')
  model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls,
                      fuse_type=fuse_type, residual=residual, dense=dense,
                      residual_layers='all', decoder_norm=decoder_norm,
                      uniform_decoder_norm=uniform_decoder_norm, decoder_norm_dim=0,
                      nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)
  loss_weights, other_loss_weights, heads = [1, 1, 1], [1], [0, 1, 2]
  run_one_model(model=model, loss_weights=loss_weights, other_loss_weights=other_loss_weights,
                loss_his_all=loss_his_all, acc_his_all=acc_his_all, metric_all=metric_all,
                confusion_mat_all=confusion_mat_all,
                heads=heads, multi_heads=True, return_results=False,
                loss_fns=loss_fns, other_loss_fns=other_loss_fns,
                lr=lr, weight_decay=weight_decay, batch_size=batch_size,
                num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every,
                print_every=print_every,
                x_train=x_train, y_train=y_train, x_val=x_val, y_val=y_val,
                x_test=x_test, y_test=y_test,
                show_results_in_notebook=show_results_in_notebook)

In [None]:
# Persist every model's histories, metrics and confusion matrices for later comparison.
# The folder may not exist on a fresh checkout, so create it first.
os.makedirs(result_folder, exist_ok=True)
with open(f'{result_folder}/{res_file}', 'wb') as f:
  print(f'Write result to file {result_folder}/{res_file}')
  pickle.dump({'loss_his_all': loss_his_all,
               'acc_his_all': acc_his_all,
               'metric_all': metric_all,
               'confusion_mat_all': confusion_mat_all,
               'model_names': model_names,
               'split_names': split_names,
               'metric_names': metric_names
              }, f)