1522 lines (1521 with data), 66.5 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import socket\n",
"if socket.gethostname() == 'dlm':\n",
" %env CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
" %env CUDA_VISIBLE_DEVICES=3"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using CPU:(\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"import re\n",
"import collections\n",
"import functools\n",
"import itertools\n",
"import requests, zipfile, io\n",
"import pickle\n",
"import copy\n",
"\n",
"import pandas\n",
"import numpy as np\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import sklearn\n",
"import sklearn.decomposition\n",
"import sklearn.metrics\n",
"import networkx\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"lib_path = 'I:/code'\n",
"if not os.path.exists(lib_path):\n",
" lib_path = '/media/6T/.tianle/.lib'\n",
"if not os.path.exists(lib_path):\n",
" lib_path = '/projects/academic/azhang/tianlema/lib'\n",
"if os.path.exists(lib_path) and lib_path not in sys.path:\n",
" sys.path.append(lib_path)\n",
" \n",
"from dl.models.basic_models import *\n",
"from dl.utils.visualization.visualization import *\n",
"from dl.utils.outlier import *\n",
"from dl.utils.train import *\n",
"from autoencoder.autoencoder import *\n",
"from vin.vin import *\n",
"from dl.utils.utils import get_overlap_samples, filter_clinical_dict, get_target_variable\n",
"from dl.utils.utils import get_shuffled_data, target_to_numpy, discrete_to_id, get_mi_acc\n",
"from dl.utils.utils import get_label_distribution, normalize_continuous_variable\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"\n",
"use_gpu = True\n",
"if use_gpu and torch.cuda.is_available():\n",
" device = torch.device('cuda')\n",
" print('Using GPU:)')\n",
"else:\n",
" device = torch.device('cpu')\n",
" print('Using CPU:(')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# neural net models include nn (mlp), resnet, densenet; another choice is ml (machine learning)\n",
"# model_type, dense, residual are dependent\n",
"model_type = 'resnet'\n",
"dense = False\n",
"residual = True\n",
"hidden_dim = [100, 100]\n",
"train_portion = 0.7\n",
"val_portion = 0.1\n",
"test_portion = 0.2\n",
"num_train_types = -1 # -1 means not used\n",
"num_val_types = -1\n",
"num_test_types = -1 # this will almost never be used \n",
"num_sets = 10\n",
"num_folds = 10 # no longer used anymore\n",
"sel_set_idx = 0\n",
"cv_type = 'instance-shuffle' # or 'group-shuffle'; cross validation shuffle method\n",
"sel_disease_types = 'all'\n",
"# The number of total samples and the numbers for each class in selected disease types must >=\n",
"min_num_samples_per_type_cls = [100, 0]\n",
"# if 'auto-search', will search for the file first; if not exist, then generate random data split\n",
"# and write to the file;\n",
"# if string other than 'auto-search' is provided, assume the string is a proper file name, \n",
"# and read the file;\n",
"# if False, will generate a random data split, but not write to file \n",
"# if True will generate a random data split, and write to file\n",
"predefined_sample_set_file = 'auto-search' \n",
"target_variable = ['PFI', 'DFI', 'PFI.time'] # To do: target variable can be a list (partially handled)\n",
"target_variable_type = ['discrete', 'discrete', 'continuous'] # or 'continuous' real numbers\n",
"target_variable_range = [[0,1],[0,1],[0,float('Inf')]]\n",
"data_type = ['gene', 'methy', 'rppa', 'mirna']\n",
"normal_transform_feature = True\n",
"additional_vars = ['age_at_initial_pathologic_diagnosis', 'gender', 'ajcc_pathologic_tumor_stage']\n",
"additional_var_types = ['continuous', 'discrete', 'discrete']\n",
"additional_var_ranges = [[0, 100], ['MALE', 'FEMALE'], \n",
" ['I/II NOS', 'IS', 'Stage 0', 'Stage I', 'Stage IA', 'Stage IB', \n",
" 'Stage II', 'Stage IIA', 'Stage IIB', 'Stage IIC', 'Stage III',\n",
" 'Stage IIIA', 'Stage IIIB', 'Stage IIIC', 'Stage IV', 'Stage IVA',\n",
" 'Stage IVB', 'Stage IVC', 'Stage X']]\n",
"randomize_labels = False\n",
"lr = 5e-4\n",
"weight_decay = 1e-4\n",
"num_epochs = 100\n",
"reduce_every = 500\n",
"show_results_in_notebook = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feature_mat: rppa, max=14.141, min=-7.869, mean=0.095, 0.516\n",
"feature_mat: mirna, max=11.813, min=0.000, mean=3.743, 1.000\n",
"feature_mat: gene, max=16.311, min=0.000, mean=8.412, 1.000\n",
"feature_mat: methy, max=1.000, min=0.000, mean=0.553, 1.000\n",
"feature_interaction_mat: rppa, max=1.000, min=0.000, mean=0.198, 0.277\n",
"feature_interaction_mat: mirna, max=1.000, min=0.010, mean=0.492, 1.000\n",
"feature_interaction_mat: gene, max=1.000, min=0.000, mean=0.050, 0.146\n",
"feature_interaction_mat: methy, max=1.000, min=0.000, mean=0.057, 0.127\n",
"rppa (189,) X1433EPSILON\n",
"mirna (662,) hsa-let-7a-2-3p\n",
"gene (4942,) A1BG\n",
"methy (4753,) cg00005847\n",
"rppa (7480,) TCGA-OR-A5J2-01A-21-A39K-20\n",
"mirna (9554,) TCGA-C4-A0F6-01A-11R-A10V-13\n",
"gene (9702,) TCGA-OR-A5J1-01A-11R-A29S-07\n",
"methy (10268,) TCGA-02-0001-01C-01D-0186-05\n"
]
}
],
"source": [
"result_folder = 'results'\n",
"data_split_idx_folder = f'{result_folder}/data_split_idx'\n",
"project_folder = '../../pan-can-atlas' # on dlm or ccr\n",
"print_stats = True\n",
"if not os.path.exists(project_folder):\n",
" project_folder = 'F:/TCGA/Pan-Cancer-Atlas' # on my own desktop\n",
"filepath = f'{project_folder}/data/processed/combined2.pkl'\n",
"with open(filepath, 'rb') as f:\n",
" data = pickle.load(f)\n",
" patient_clinical = data['patient_clinical']\n",
" feature_mat_dict = data['feature_mat_dict']\n",
" feature_interaction_mat_dict = data['feature_interaction_mat_dict']\n",
" feature_id_dict = data['feature_id_dict']\n",
" aliquot_id_dict = data['aliquot_id_dict']\n",
"# sel_patient_ids = data['sample_id_sel']\n",
"# sample_idx_sel_dict = data['sample_idx_sel_dict']\n",
"# for k, v in sample_idx_sel_dict.items():\n",
"# assert [i[:12] for i in aliquot_id_dict[k][v]] == sel_patient_ids\n",
"\n",
"if print_stats:\n",
" for k, v in feature_mat_dict.items():\n",
" print(f'feature_mat: {k}, max={v.max():.3f}, min={v.min():.3f}, '\n",
" f'mean={v.mean():.3f}, {np.mean(v>0):.3f}') \n",
" for k, v in feature_interaction_mat_dict.items():\n",
" print(f'feature_interaction_mat: {k}, max={v.max():.3f}, min={v.min():.3f}, '\n",
" f'mean={v.mean():.3f}, {np.mean(v>0):.3f}') \n",
" for k, v in feature_id_dict.items():\n",
" print(k, v.shape, v[0])\n",
" for k, v in aliquot_id_dict.items():\n",
" print(k, v.shape, v[0])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"THCA {0.0: 240, 1.0: 24}\n",
"BLCA {0.0: 125, 1.0: 25}\n",
"BRCA {0.0: 666, 1.0: 62}\n",
"KIRP {0.0: 101, 1.0: 22}\n",
"STAD {1.0: 37, 0.0: 149}\n",
"LIHC {0.0: 62, 1.0: 68}\n",
"LUAD {1.0: 56, 0.0: 129}\n",
"COAD {0.0: 108, 1.0: 17}\n",
"LUSC {0.0: 136, 1.0: 56}\n",
"Selected 2083 patients from 9 disease_types\n"
]
}
],
"source": [
"# select samples with required clinical variables\n",
"clinical_dict = filter_clinical_dict(target_variable, target_variable_type=target_variable_type, \n",
" target_variable_range=target_variable_range, \n",
" clinical_dict=patient_clinical)\n",
"if len(additional_vars) > 0:\n",
" clinical_dict = filter_clinical_dict(additional_vars, target_variable_type=additional_var_types, \n",
" target_variable_range=additional_var_ranges, \n",
" clinical_dict=clinical_dict)\n",
"\n",
"# select samples with feature matrix of given type(s)\n",
"if isinstance(data_type, str):\n",
" sample_list = {s[:12] for s in aliquot_id_dict[data_type]}\n",
" data_type_str = data_type\n",
"elif isinstance(data_type, (list, tuple)):\n",
" sample_list = get_overlap_samples([aliquot_id_dict[dtype] for dtype in data_type], \n",
" common_list=None, start=0, end=12, return_common_list=True)\n",
" data_type_str = '-'.join(sorted(data_type))\n",
"else:\n",
" raise ValueError(f'data_type must be str or list/tuple, but is {type(data_type)}')\n",
"sample_list = set(sample_list).intersection(clinical_dict)\n",
"\n",
"# select samples with given disease types\n",
"sel_disease_type_str = sel_disease_types # will be overwritten if it is a list\n",
"if isinstance(sel_disease_types, (list, tuple)):\n",
" sample_list = [s for s in sample_list if clinical_dict[s]['type'] in sel_disease_types]\n",
" sel_disease_type_str = '-'.join(sorted(sel_disease_types))\n",
"elif isinstance(sel_disease_types, str) and sel_disease_types!='all':\n",
" sample_list = [s for s in sample_list if clinical_dict[s]['type'] == sel_disease_types]\n",
"else:\n",
" assert sel_disease_types == 'all'\n",
" \n",
"# For classification tasks with given min_num_samples_per_type_cls,\n",
"# only keep disease types that have a minimal number of samples per type and per class\n",
"# Reflection: it might be better to use collections.defaultdict(list) to store samples in each type\n",
"type_cnt = collections.Counter([clinical_dict[s]['type'] for s in sample_list])\n",
"if sum(min_num_samples_per_type_cls)>0 and (target_variable_type=='discrete' \n",
" or target_variable_type[0]=='discrete'):\n",
" # the number of samples in each disease type >= min_num_samples_per_type_cls[0]\n",
" type_cnt = {k: v for k, v in type_cnt.items() if v >= min_num_samples_per_type_cls[0]}\n",
" disease_type_cnt = {}\n",
" for k in type_cnt:\n",
" # collections.Counter can accept generator\n",
" cls_cnt = collections.Counter(clinical_dict[s][target_variable] \n",
" if isinstance(target_variable, str) \n",
" else clinical_dict[s][target_variable[0]] \n",
" for s in sample_list if clinical_dict[s]['type']==k)\n",
" if all([v >= min_num_samples_per_type_cls[1] for v in cls_cnt.values()]):\n",
" # the number of samples in each class >= min_num_samples_per_type_cls[1]\n",
" disease_type_cnt[k] = dict(cls_cnt)\n",
" print(k, disease_type_cnt[k])\n",
" sample_list = [s for s in sample_list if clinical_dict[s]['type'] in disease_type_cnt]\n",
"sel_patient_ids = sorted(sample_list)\n",
"print(f'Selected {len(sel_patient_ids)} patients from {len(disease_type_cnt)} disease_types')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Split data into training, validation, and test sets"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Read predefined_sample_set_file: results/data_split_idx/PFI-DFI-PFI.time_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets.pkl\n"
]
}
],
"source": [
"predefined_sample_set_filename = (target_variable if isinstance(target_variable,str) \n",
" else '-'.join(target_variable))\n",
"predefined_sample_set_filename += f'_{cv_type}'\n",
"if len(additional_vars) > 0:\n",
" predefined_sample_set_filename += f\"_{'-'.join(sorted(additional_vars))}\"\n",
"\n",
"predefined_sample_set_filename += (f\"_{data_type_str}_{sel_disease_type_str}_\"\n",
" f\"{'-'.join(map(str, min_num_samples_per_type_cls))}\")\n",
"predefined_sample_set_filename += f\"_{'-'.join(map(str, [train_portion, val_portion, test_portion]))}\"\n",
"if cv_type == 'group-shuffle' and num_train_types > 0:\n",
" predefined_sample_set_filename += f\"_{'-'.join(map(str, [num_train_types, num_val_types, num_test_types]))}\"\n",
"predefined_sample_set_filename += f'_{num_sets}sets'\n",
"res_file = f\"{predefined_sample_set_filename}_{sel_set_idx}_{'-'.join(map(str, hidden_dim))}_{model_type}.pkl\"\n",
"predefined_sample_set_filename += '.pkl'\n",
"# This will be overwritten if predefined_sample_set_file == 'auto-search' or filepath, and the file exists\n",
"predefined_sample_sets = [get_shuffled_data(sel_patient_ids, clinical_dict, cv_type=cv_type, \n",
" instance_portions=[train_portion, val_portion, test_portion], \n",
" group_sizes=[num_train_types, num_val_types, num_test_types],\n",
" group_variable_name='type', seed=None, verbose=False) for i in range(num_sets)]\n",
"if predefined_sample_set_file == 'auto-search':\n",
" if os.path.exists(f'{data_split_idx_folder}/{predefined_sample_set_filename}'):\n",
" with open(f'{data_split_idx_folder}/{predefined_sample_set_filename}', 'rb') as f:\n",
" print(f'Read predefined_sample_set_file: '\n",
" f'{data_split_idx_folder}/{predefined_sample_set_filename}')\n",
" tmp = pickle.load(f)\n",
" # overwrite calculated predefined_sample_sets\n",
" predefined_sample_sets = tmp['predefined_sample_sets'] \n",
"elif isinstance(predefined_sample_set_file, str): # but not 'auto-search'; assume it's a file name\n",
" if os.path.exists(predefined_sample_set_file):\n",
" with open(f'{data_split_idx_folder}/{predefined_sample_set_file}', 'rb') as f:\n",
" print(f'Read predefined_sample_set_file: {data_split_idx_folder}/{predefined_sample_set_file}')\n",
" tmp = pickle.load(f)\n",
" predefined_sample_sets = tmp['predefined_sample_sets']\n",
" else:\n",
" raise ValueError(f'predefined_sample_set_file: {data_split_idx_folder}/{predefined_sample_set_file} does not exist!')\n",
"\n",
"if (not os.path.exists(f'{data_split_idx_folder}/{predefined_sample_set_filename}') \n",
" and predefined_sample_set_file == 'auto-search') or predefined_sample_set_file is True:\n",
" with open(f'{data_split_idx_folder}/{predefined_sample_set_filename}', 'wb') as f:\n",
" print(f'Write predefined_sample_set_file: {data_split_idx_folder}/{predefined_sample_set_filename}')\n",
" pickle.dump({'predefined_sample_sets': predefined_sample_sets}, f)\n",
" \n",
"sel_patient_ids, idx_splits = predefined_sample_sets[sel_set_idx]\n",
"train_idx, val_idx, test_idx = idx_splits"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"if isinstance(data_type, str):\n",
" sample_lists = [aliquot_id_dict[data_type]]\n",
"else:\n",
" assert isinstance(data_type, (list, tuple))\n",
" sample_lists = [aliquot_id_dict[dtype] for dtype in data_type]\n",
"idx_lists = get_overlap_samples(sample_lists=sample_lists, common_list=sel_patient_ids, \n",
" start=0, end=12, return_common_list=False)\n",
"sample_idx_sel_dict = {}\n",
"if isinstance(data_type, str):\n",
" sample_idx_sel_dict = {data_type: idx_lists[0]}\n",
"else:\n",
" sample_idx_sel_dict = {dtype: idx_list for dtype, idx_list in zip(data_type, idx_lists)}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gene: (2083, 4942); interaction_mat: mean=0.000057, std=0.000194, 4942\n",
"methy: (2083, 4753); interaction_mat: mean=0.000069, std=0.000199, 4753\n",
"rppa: (2083, 189); interaction_mat: mean=0.002668, std=0.004569, 189\n",
"mirna: (2083, 662); interaction_mat: mean=0.001408, std=0.000547, 662\n"
]
}
],
"source": [
"if isinstance(data_type, str):\n",
" print(f'Only use one data type: {data_type}')\n",
" num_data_types = 1\n",
" mat = feature_mat_dict[data_type][sample_idx_sel_dict[data_type]]\n",
" # Data preprocessing: make each row have mean 0 and sd 1.\n",
" x = (mat - mat.mean(axis=1, keepdims=True)) / mat.std(axis=1, keepdims=True)\n",
" interaction_mat = feature_interaction_mat_dict[data_type]\n",
" interaction_mat = torch.from_numpy(interaction_mat).float().to(device)\n",
" # Normalize these interaction mat\n",
" interaction_mat = interaction_mat / interaction_mat.norm()\n",
"else:\n",
" mat = []\n",
" interaction_mats = []\n",
" in_dims = []\n",
" num_data_types = len(data_type)\n",
" # do not handle the special case of [data_type] to avoid too much code complexity\n",
" assert num_data_types > 1 \n",
" for dtype in data_type: # multiple data types\n",
" m = feature_mat_dict[dtype][sample_idx_sel_dict[dtype]]\n",
" #When there are multiple data types, make sure each type is normalized to have mean 0 and std 1\n",
" m = (m - m.mean(axis=1, keepdims=True)) / m.std(axis=1, keepdims=True)\n",
" mat.append(m)\n",
" in_dims.append(m.shape[1])\n",
" # For neural network model graph laplacian regularizer\n",
" interaction_mat = feature_interaction_mat_dict[dtype]\n",
" interaction_mat = torch.from_numpy(interaction_mat).float().to(device)\n",
" # Normalize these interaction mat\n",
" interaction_mat = interaction_mat / interaction_mat.norm()\n",
" interaction_mats.append(interaction_mat)\n",
" print(f'{dtype}: {m.shape}; '\n",
" f'interaction_mat: mean={interaction_mat.mean().item():2f}, '\n",
" f'std={interaction_mat.std().item():2f}, {interaction_mat.shape[0]}')\n",
" # Later interaction_mat will be passed to Loss_feature_interaction\n",
" interaction_mat = interaction_mats\n",
" mat = np.concatenate(mat, axis=1)\n",
" # For machine learing methods that use concatenated features without knowing underlying views,\n",
" # it might be good to make each row have mean 0 and sd 1.\n",
" x = (mat - mat.mean(axis=1, keepdims=True)) / mat.std(axis=1, keepdims=True)\n",
"\n",
"if normal_transform_feature:\n",
" X = x\n",
"else:\n",
" X = mat"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1458, 10546]) torch.Size([208, 10546]) torch.Size([417, 10546])\n"
]
}
],
"source": [
"# sklearn classifiers also accept torch.Tensor\n",
"X = torch.tensor(X).float().to(device)\n",
"x_train = X[train_idx]\n",
"x_val = X[val_idx]\n",
"x_test = X[test_idx]\n",
"print(x_train.shape, x_val.shape, x_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Changed class labels for the model: {0.0: 0, 1.0: 1}\n",
"Changed class labels for the model: {0.0: 0, 1.0: 1}\n",
"PFI:\n",
"train:torch.Size([1458]), val:torch.Size([208]), test:torch.Size([417])\n",
"label distribution:\n",
" [[0.82304525 0.84615386 0.81534773]\n",
" [0.17695473 0.15384616 0.18465228]]\n",
"DFI:\n",
"train:torch.Size([1458]), val:torch.Size([208]), test:torch.Size([417])\n",
"label distribution:\n",
" [[0.8360768 0.84615386 0.82254195]\n",
" [0.16392319 0.15384616 0.17745803]]\n",
"PFI.time:\n",
"train:torch.Size([1458, 1]), val:torch.Size([208, 1]), test:torch.Size([417, 1])\n"
]
}
],
"source": [
"y_targets = get_target_variable(target_variable, clinical_dict, sel_patient_ids)\n",
"y_targets = normalize_continuous_variable(y_targets, target_variable_type, transform=True, \n",
" forced=False, threshold=10, rm_outlier=True, whis=1.5, \n",
" only_positive=True, max_val=1)\n",
"y_true = target_to_numpy(y_targets, target_variable_type, target_variable_range)\n",
"if len(additional_vars) > 0:\n",
" additional_variables = get_target_variable(additional_vars, clinical_dict, sel_patient_ids)\n",
" # to do handle additional variables such as age and gender\n",
"\n",
"# should have written a recursive function instead\n",
"if isinstance(target_variable_type, list):\n",
" y_targets = []\n",
" num_cls = []\n",
" y_train = []\n",
" y_val = []\n",
" y_test = []\n",
" for i, var_type in enumerate(target_variable_type):\n",
" y = torch.tensor(y_true[i]).to(device)\n",
" if var_type == 'discrete':\n",
" y = y.long()\n",
" elif var_type == 'continuous':\n",
" y = y.float()\n",
" if y.dim()==1:\n",
" y = y.unsqueeze(-1)\n",
" else:\n",
" raise ValueError(f'target type should be either discrete or continuous but is {var_type}')\n",
" y_targets.append(y)\n",
" num_cls.append(len(torch.unique(y))) # include continous target variables\n",
" y_train.append(y[train_idx])\n",
" y_val.append(y[val_idx])\n",
" y_test.append(y[test_idx])\n",
" print(f'{target_variable[i]}:\\ntrain:{y_train[-1].shape}, val:{y_val[-1].shape}, '\n",
" f'test:{y_test[-1].shape}')\n",
" if var_type == 'discrete':\n",
" label_probs = get_label_distribution([y_train[-1], y_val[-1], y_test[-1]])\n",
" if randomize_labels: # Optionally randomize true class labels\n",
" print('Randomize class labels!')\n",
" y_train[-1] = torch.multinomial(label_probs[0], len(y_train[-1]), replacement=True)\n",
" if len(y_val) > 0:\n",
" y_val[-1] = torch.multinomial(label_probs[1], len(y_val[-1]), replacement=True)\n",
" if len(y_test) > 0:\n",
" y_test[-1] = torch.multinomial(label_probs[2], len(y_test[-1]), replacement=True)\n",
" get_label_distribution([y_train[-1], y_val[-1], y_test[-1]])\n",
" y_true = y_targets\n",
"elif isinstance(target_variable_type, str):\n",
" y = torch.tensor(y_true).to(device)\n",
" if var_type == 'discrete':\n",
" y = y.long()\n",
" elif var_type == 'continuous':\n",
" y = y.float()\n",
" if y.dim()==1:\n",
" y = y.unsqueeze(-1)\n",
" else:\n",
" raise ValueError(f'target type should be either discrete or continuous but is {var_type}')\n",
" y_true = y\n",
" num_cls = len(torch.unique(y_true))\n",
" y_train = y_true[train_idx]\n",
" y_val = y_true[val_idx]\n",
" y_test = y_true[test_idx]\n",
" print(f'{target_variable}:\\ntrain:{y_train.shape}, val:{y_val.shape}, '\n",
" f'test:{y_test.shape}')\n",
" label_probs = get_label_distribution([y_train, y_val, y_test])\n",
" if randomize_labels: # Optionally randomize true class labels\n",
" print('Randomize class labels!')\n",
" y_train = torch.multinomial(label_probs[0], len(y_train), replacement=True)\n",
" if len(y_val) > 0:\n",
" y_val = torch.multinomial(label_probs[1], len(y_val), replacement=True)\n",
" if len(y_test) > 0:\n",
" y_test = torch.multinomial(label_probs[2], len(y_test), replacement=True)\n",
" get_label_distribution([y_train, y_val, y_test])\n",
"else:\n",
" raise ValueError(f'target_variable_type should be str or list, but is {type(target_variable_type)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use additional variables for prediction"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"age_at_initial_pathologic_diagnosis\n",
"Counter({5: 577, 4: 468, 6: 429, 3: 285, 7: 137, 2: 130, 1: 48, 0: 9})\n",
"gender\n",
"Counter({0: 1322, 1: 761})\n",
"ajcc_pathologic_tumor_stage\n",
"Counter({0: 398, 4: 379, 5: 276, 8: 220, 3: 185, 1: 164, 7: 162, 2: 144, 10: 77, 9: 76, 11: 1, 6: 1})\n"
]
}
],
"source": [
"embedding_dim = 50\n",
"input_list = []\n",
"xs = []\n",
"for v, n, t in zip(additional_variables, additional_vars, additional_var_types):\n",
" if n.startswith('age'): \n",
" bins = [0, 20, 30, 40, 50, 60, 70, 80, 100]\n",
" v = np.digitize(v, bins)\n",
" t = 'discrete'\n",
" if t=='discrete':\n",
" target_ids, cls_id_dict = discrete_to_id(v, start=0, sort=True)\n",
" # some target_ids may have very few instances\n",
" print(n)\n",
" print(collections.Counter(target_ids))\n",
" xs.append(torch.tensor(target_ids, device=device).long())\n",
" # did not handle missing value yet\n",
" input_list.append({'in_dim': len(cls_id_dict), 'in_type': 'discrete', 'padding_idx':None, \n",
" 'embedding_dim':embedding_dim, 'hidden_dim': hidden_dim})\n",
" else: # t=='continuous'\n",
" xs.append(torch.tensor(v, device=device).float())\n",
" input_list.append({'in_dim': len(v[0]), 'in_type': 'continuous', 'hidden_dim': hidden_dim})"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"target: PFI\n",
" Variable \t MI \tAdj_MI\tBayes_ACC\n",
"age_at_initial_pathologic_diagnosis\t0.003\t0.001 \t 0.824 \n",
" gender \t0.009\t0.013 \t 0.824 \n",
" ajcc_pathologic_tumor_stage \t0.012\t0.004 \t 0.824 \n",
" 0-1 \t0.011\t0.003 \t 0.824 \n",
" 0-2 \t0.029\t0.003 \t 0.825 \n",
" 1-2 \t0.033\t0.010 \t 0.831 \n",
" 0-1-2 \t0.061\t0.006 \t 0.836 \n",
"target: DFI\n",
" Variable \t MI \tAdj_MI\tBayes_ACC\n",
"age_at_initial_pathologic_diagnosis\t0.002\t0.000 \t 0.834 \n",
" gender \t0.009\t0.013 \t 0.834 \n",
" ajcc_pathologic_tumor_stage \t0.011\t0.004 \t 0.835 \n",
" 0-1 \t0.011\t0.003 \t 0.834 \n",
" 0-2 \t0.028\t0.003 \t 0.836 \n",
" 1-2 \t0.029\t0.009 \t 0.838 \n",
" 0-1-2 \t0.056\t0.005 \t 0.843 \n"
]
}
],
"source": [
"if isinstance(target_variable, list):\n",
" for i, var_name in enumerate(target_variable):\n",
" if target_variable_type[i]=='discrete':\n",
" print(f'target: {var_name}')\n",
" get_mi_acc(xs, y_true=y_true[i], var_names=additional_vars, var_name_length=35)\n",
"elif isinstance(target_variable, str):\n",
" if target_variable_type=='discrete':\n",
" print(f'target: {target_variable}')\n",
" get_mi_acc(xs, y_true, var_names=additional_vars, var_name_length=35)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"xs_train = [x[train_idx] for x in xs]\n",
"xs_val = [x[val_idx] for x in xs]\n",
"xs_test = [x[test_idx] for x in xs]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1458])\n",
"torch.Size([1458])\n",
"torch.Size([1458])\n",
"torch.Size([1458, 10546])\n"
]
}
],
"source": [
"last_nonlinearity = True\n",
"input_list.append({'in_dim': x_train.size(1), 'in_type': 'continuous', 'hidden_dim': hidden_dim,\n",
" 'last_nonlinearity':last_nonlinearity, })\n",
"xs_train.append(x_train)\n",
"xs_val.append(x_val)\n",
"xs_test.append(x_test)\n",
"\n",
"for i in xs_train:\n",
" print(i.shape)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"hidden_dim = [100]\n",
"output_info = [{'hidden_dim': hidden_dim+[2]}, {'hidden_dim': hidden_dim+[2]},\n",
" {'hidden_dim': hidden_dim+[1]}]\n",
" \n",
"fusion_lists = [[{'fusion_type': 'repr-weighted-avg_repr', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info}, \n",
" {'fusion_type': 'repr-loss-avg_repr', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info},\n",
" {'fusion_type': 'repr-avg_repr', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info},\n",
" {'fusion_type': 'repr-cat_repr', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info},\n",
" {'fusion_type': 'out-weighted-avg', 'output_info': output_info},\n",
" {'fusion_type': 'out-loss-avg', 'output_info': output_info},\n",
" {'fusion_type': 'out-avg', 'output_info': output_info}],\n",
" [{'fusion_type': 'repr-weighted-avg_repr', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info}, \n",
" {'fusion_type': 'repr-loss-avg_repr', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info},\n",
" {'fusion_type': 'repr-avg_repr', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info},\n",
" {'fusion_type': 'repr-cat_repr', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info},\n",
" {'fusion_type': 'out-weighted-avg', 'output_info': output_info},\n",
" {'fusion_type': 'out-loss-avg', 'output_info': output_info},\n",
" {'fusion_type': 'out-avg', 'output_info': output_info}],\n",
" [{'fusion_type': 'repr0', 'hidden_dim': hidden_dim, \n",
" 'last_nonlinearity':last_nonlinearity, 'output_info': output_info}]\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"model = VIN(input_list, output_info, fusion_lists, nonlinearity=nn.ReLU())\n",
"if target_variable_type=='discrete':\n",
" loss_fn = nn.CrossEntropyLoss()\n",
"elif target_variable_type=='dontinuous':\n",
" loss_fn = nn.MSELoss()\n",
"else:\n",
" loss_fn = []\n",
" for var_type in target_variable_type:\n",
" if var_type == 'discrete':\n",
" loss_fn.append(nn.CrossEntropyLoss())\n",
" elif var_type == 'continuous':\n",
" loss_fn.append(nn.MSELoss())\n",
" else:\n",
" raise ValueError(f'target type should be either discrete or continous, but is {var_type}')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), \n",
" lr=1e-2, weight_decay=weight_decay, amsgrad=True)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"for n, p in model.named_parameters():\n",
"# print(n, p.size())\n",
" if p.grad is not None and p.grad.norm()==0:\n",
" print(n, p.grad if p.grad is None else p.grad.norm())"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"target_loss_weight = [1., 1., 1.]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 6.476935386657715\n",
"99 1.0573850870132446\n"
]
}
],
"source": [
"loss_his = []\n",
"num_iters = 100\n",
"print_every = 100\n",
"for i in range(num_iters):\n",
" pred = model(xs_test)\n",
" losses = get_vin_loss(pred, y_test, loss_fn, model, valid_loc=None, target_id=None, \n",
" level_weight=None)\n",
" loss = sum(losses[j][0]*target_loss_weight[j] for j in range(len(losses)))\n",
" losses = losses[0]\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" loss_his.append([losses[0].item()] + \n",
" [[[v.item() for v in losses[1][i][0]], losses[1][i][1].item()] for i in range(2)])\n",
" if i%print_every == 0 or i==num_iters-1:\n",
" print(i, loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"acc=1.000, precision=1.000, recall=1.000, fl=1.000, adj_MI=1.000, auc=1.000, ap=1.000, confusion_mat=\n",
"[[340 0]\n",
" [ 0 77]]\n",
"report precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 340\n",
" 1 1.00 1.00 1.00 77\n",
"\n",
"avg / total 1.00 1.00 1.00 417\n",
"\n"
]
},
{
"data": {
"text/plain": [
"[(tensor(0.3750, grad_fn=<ThAddBackward>),\n",
" [[[tensor(0.4638, grad_fn=<NllLossBackward>),\n",
" tensor(0.4475, grad_fn=<NllLossBackward>),\n",
" tensor(0.4651, grad_fn=<NllLossBackward>),\n",
" tensor(0.0441, grad_fn=<NllLossBackward>)],\n",
" tensor(0.3546, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0055, grad_fn=<NllLossBackward>),\n",
" tensor(0.0074, grad_fn=<NllLossBackward>),\n",
" tensor(0.0078, grad_fn=<NllLossBackward>),\n",
" tensor(0.0036, grad_fn=<NllLossBackward>),\n",
" tensor(0.0263, grad_fn=<NllLossBackward>),\n",
" tensor(0.0265, grad_fn=<NllLossBackward>),\n",
" tensor(0.0267, grad_fn=<NllLossBackward>)],\n",
" tensor(0.0128, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0036, grad_fn=<NllLossBackward>),\n",
" tensor(0.0036, grad_fn=<NllLossBackward>),\n",
" tensor(0.0037, grad_fn=<NllLossBackward>),\n",
" tensor(0.0031, grad_fn=<NllLossBackward>),\n",
" tensor(0.0083, grad_fn=<NllLossBackward>),\n",
" tensor(0.0080, grad_fn=<NllLossBackward>),\n",
" tensor(0.0091, grad_fn=<NllLossBackward>)],\n",
" tensor(0.0055, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0021, grad_fn=<NllLossBackward>)],\n",
" tensor(0.0021, grad_fn=<AddBackward>)]]),\n",
" (tensor(0.3739, grad_fn=<ThAddBackward>),\n",
" [[[tensor(0.4564, grad_fn=<NllLossBackward>),\n",
" tensor(0.4383, grad_fn=<NllLossBackward>),\n",
" tensor(0.4547, grad_fn=<NllLossBackward>),\n",
" tensor(0.0552, grad_fn=<NllLossBackward>)],\n",
" tensor(0.3481, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0104, grad_fn=<NllLossBackward>),\n",
" tensor(0.0119, grad_fn=<NllLossBackward>),\n",
" tensor(0.0157, grad_fn=<NllLossBackward>),\n",
" tensor(0.0020, grad_fn=<NllLossBackward>),\n",
" tensor(0.0307, grad_fn=<NllLossBackward>),\n",
" tensor(0.0302, grad_fn=<NllLossBackward>),\n",
" tensor(0.0317, grad_fn=<NllLossBackward>)],\n",
" tensor(0.0178, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0038, grad_fn=<NllLossBackward>),\n",
" tensor(0.0047, grad_fn=<NllLossBackward>),\n",
" tensor(0.0053, grad_fn=<NllLossBackward>),\n",
" tensor(0.0008, grad_fn=<NllLossBackward>),\n",
" tensor(0.0114, grad_fn=<NllLossBackward>),\n",
" tensor(0.0114, grad_fn=<NllLossBackward>),\n",
" tensor(0.0120, grad_fn=<NllLossBackward>)],\n",
" tensor(0.0068, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0013, grad_fn=<NllLossBackward>)],\n",
" tensor(0.0013, grad_fn=<AddBackward>)]]),\n",
" (tensor(0.3023, grad_fn=<ThAddBackward>),\n",
" [[[tensor(0.0833, grad_fn=<MseLossBackward>),\n",
" tensor(0.0855, grad_fn=<MseLossBackward>),\n",
" tensor(0.0839, grad_fn=<MseLossBackward>),\n",
" tensor(0.0756, grad_fn=<MseLossBackward>)],\n",
" tensor(0.0823, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0728, grad_fn=<MseLossBackward>),\n",
" tensor(0.0724, grad_fn=<MseLossBackward>),\n",
" tensor(0.0717, grad_fn=<MseLossBackward>),\n",
" tensor(0.0723, grad_fn=<MseLossBackward>),\n",
" tensor(0.0770, grad_fn=<MseLossBackward>),\n",
" tensor(0.0770, grad_fn=<MseLossBackward>),\n",
" tensor(0.0767, grad_fn=<MseLossBackward>)],\n",
" tensor(0.0741, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0739, grad_fn=<MseLossBackward>),\n",
" tensor(0.0723, grad_fn=<MseLossBackward>),\n",
" tensor(0.0734, grad_fn=<MseLossBackward>),\n",
" tensor(0.0714, grad_fn=<MseLossBackward>),\n",
" tensor(0.0717, grad_fn=<MseLossBackward>),\n",
" tensor(0.0717, grad_fn=<MseLossBackward>),\n",
" tensor(0.0718, grad_fn=<MseLossBackward>)],\n",
" tensor(0.0724, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0735, grad_fn=<MseLossBackward>)],\n",
" tensor(0.0735, grad_fn=<AddBackward>)]])]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred = model(xs_test)\n",
"eval_classification(y_test[0], pred[0][-1][0])\n",
"get_vin_loss(pred, y_test, loss_fn, model, valid_loc=None, target_id=None, \n",
" level_weight=None)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"acc=0.765, precision=0.736, recall=0.765, fl=0.749, adj_MI=0.007, auc=0.627, ap=0.253, confusion_mat=\n",
"[[1067 133]\n",
" [ 209 49]]\n",
"report precision recall f1-score support\n",
"\n",
" 0 0.84 0.89 0.86 1200\n",
" 1 0.27 0.19 0.22 258\n",
"\n",
"avg / total 0.74 0.77 0.75 1458\n",
"\n"
]
},
{
"data": {
"text/plain": [
"[(tensor(6.2221, grad_fn=<ThAddBackward>),\n",
" [[[tensor(0.4945, grad_fn=<NllLossBackward>),\n",
" tensor(0.4714, grad_fn=<NllLossBackward>),\n",
" tensor(0.4632, grad_fn=<NllLossBackward>),\n",
" tensor(2.8643, grad_fn=<NllLossBackward>)],\n",
" tensor(1.0757, grad_fn=<ThAddBackward>)],\n",
" [[tensor(1.6178, grad_fn=<NllLossBackward>),\n",
" tensor(1.8301, grad_fn=<NllLossBackward>),\n",
" tensor(1.5669, grad_fn=<NllLossBackward>),\n",
" tensor(1.8302, grad_fn=<NllLossBackward>),\n",
" tensor(0.8695, grad_fn=<NllLossBackward>),\n",
" tensor(0.8655, grad_fn=<NllLossBackward>),\n",
" tensor(0.8631, grad_fn=<NllLossBackward>)],\n",
" tensor(1.4329, grad_fn=<ThAddBackward>)],\n",
" [[tensor(1.7378, grad_fn=<NllLossBackward>),\n",
" tensor(1.7598, grad_fn=<NllLossBackward>),\n",
" tensor(1.7488, grad_fn=<NllLossBackward>),\n",
" tensor(2.0182, grad_fn=<NllLossBackward>),\n",
" tensor(1.3899, grad_fn=<NllLossBackward>),\n",
" tensor(1.4158, grad_fn=<NllLossBackward>),\n",
" tensor(1.3304, grad_fn=<NllLossBackward>)],\n",
" tensor(1.6415, grad_fn=<ThAddBackward>)],\n",
" [[tensor(2.0721, grad_fn=<NllLossBackward>)],\n",
" tensor(2.0721, grad_fn=<AddBackward>)]]),\n",
" (tensor(9.2298, grad_fn=<ThAddBackward>),\n",
" [[[tensor(0.4708, grad_fn=<NllLossBackward>),\n",
" tensor(0.4488, grad_fn=<NllLossBackward>),\n",
" tensor(0.4466, grad_fn=<NllLossBackward>),\n",
" tensor(2.9583, grad_fn=<NllLossBackward>)],\n",
" tensor(1.1003, grad_fn=<ThAddBackward>)],\n",
" [[tensor(2.5054, grad_fn=<NllLossBackward>),\n",
" tensor(2.8771, grad_fn=<NllLossBackward>),\n",
" tensor(2.2349, grad_fn=<NllLossBackward>),\n",
" tensor(2.5306, grad_fn=<NllLossBackward>),\n",
" tensor(0.9175, grad_fn=<NllLossBackward>),\n",
" tensor(0.9234, grad_fn=<NllLossBackward>),\n",
" tensor(0.9039, grad_fn=<NllLossBackward>)],\n",
" tensor(1.9370, grad_fn=<ThAddBackward>)],\n",
" [[tensor(2.1666, grad_fn=<NllLossBackward>),\n",
" tensor(2.5449, grad_fn=<NllLossBackward>),\n",
" tensor(2.4155, grad_fn=<NllLossBackward>),\n",
" tensor(4.3624, grad_fn=<NllLossBackward>),\n",
" tensor(1.9025, grad_fn=<NllLossBackward>),\n",
" tensor(1.9214, grad_fn=<NllLossBackward>),\n",
" tensor(1.8252, grad_fn=<NllLossBackward>)],\n",
" tensor(2.4909, grad_fn=<ThAddBackward>)],\n",
" [[tensor(3.7016, grad_fn=<NllLossBackward>)],\n",
" tensor(3.7016, grad_fn=<AddBackward>)]]),\n",
" (tensor(0.3275, grad_fn=<ThAddBackward>),\n",
" [[[tensor(0.0827, grad_fn=<MseLossBackward>),\n",
" tensor(0.0782, grad_fn=<MseLossBackward>),\n",
" tensor(0.0821, grad_fn=<MseLossBackward>),\n",
" tensor(0.0795, grad_fn=<MseLossBackward>)],\n",
" tensor(0.0807, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0871, grad_fn=<MseLossBackward>),\n",
" tensor(0.0904, grad_fn=<MseLossBackward>),\n",
" tensor(0.0796, grad_fn=<MseLossBackward>),\n",
" tensor(0.0926, grad_fn=<MseLossBackward>),\n",
" tensor(0.0761, grad_fn=<MseLossBackward>),\n",
" tensor(0.0761, grad_fn=<MseLossBackward>),\n",
" tensor(0.0760, grad_fn=<MseLossBackward>)],\n",
" tensor(0.0827, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0821, grad_fn=<MseLossBackward>),\n",
" tensor(0.0815, grad_fn=<MseLossBackward>),\n",
" tensor(0.0803, grad_fn=<MseLossBackward>),\n",
" tensor(0.0861, grad_fn=<MseLossBackward>),\n",
" tensor(0.0802, grad_fn=<MseLossBackward>),\n",
" tensor(0.0802, grad_fn=<MseLossBackward>),\n",
" tensor(0.0800, grad_fn=<MseLossBackward>)],\n",
" tensor(0.0815, grad_fn=<ThAddBackward>)],\n",
" [[tensor(0.0826, grad_fn=<MseLossBackward>)],\n",
" tensor(0.0826, grad_fn=<AddBackward>)]])]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred = model(xs_train)\n",
"eval_classification(y_train[0], pred[0][-1][0])\n",
"get_vin_loss(pred, y_train, loss_fn, model, valid_loc=None, target_id=None, \n",
" level_weight=None)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train\n",
"acc=0.765, precision=0.736, recall=0.765, fl=0.749, adj_MI=0.007, auc=0.627, ap=0.253, confusion_mat=\n",
"[[1067 133]\n",
" [ 209 49]]\n",
"report precision recall f1-score support\n",
"\n",
" 0 0.84 0.89 0.86 1200\n",
" 1 0.27 0.19 0.22 258\n",
"\n",
"avg / total 0.74 0.77 0.75 1458\n",
"\n",
"Validataion\n",
"acc=0.808, precision=0.798, recall=0.808, fl=0.802, adj_MI=0.041, auc=0.675, ap=0.315, confusion_mat=\n",
"[[158 18]\n",
" [ 22 10]]\n",
"report precision recall f1-score support\n",
"\n",
" 0 0.88 0.90 0.89 176\n",
" 1 0.36 0.31 0.33 32\n",
"\n",
"avg / total 0.80 0.81 0.80 208\n",
"\n",
"Test\n",
"acc=1.000, precision=1.000, recall=1.000, fl=1.000, adj_MI=1.000, auc=1.000, ap=1.000, confusion_mat=\n",
"[[340 0]\n",
" [ 0 77]]\n",
"report precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 340\n",
" 1 1.00 1.00 1.00 77\n",
"\n",
"avg / total 1.00 1.00 1.00 417\n",
"\n"
]
},
{
"data": {
"text/plain": [
"[(array([0.7654321 , 0.73587779, 0.7654321 , 0.74877395, 0.00738225,\n",
" 0.62744186, 0.25277694]), array([[1067, 133],\n",
" [ 209, 49]], dtype=int64)),\n",
" (array([0.80769231, 0.7976801 , 0.80769231, 0.80236243, 0.04149792,\n",
" 0.67542614, 0.31500486]), array([[158, 18],\n",
" [ 22, 10]], dtype=int64)),\n",
" (array([1., 1., 1., 1., 1., 1., 1.]), array([[340, 0],\n",
" [ 0, 77]], dtype=int64))]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eval_classification_multi_splits(model, [xs_train, xs_val, xs_test], [y_train[0], y_val[0], y_test[0]], \n",
" batch_size=None, multi_heads=False, cls_head=0, average='weighted', return_result=True, \n",
" split_names=['Train', 'Validataion', 'Test'], verbose=True, \n",
" predict_func=predict_func, pred_kwargs={'target_idx':0, 'level':-1, 'loc':0, 'train':False})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Neural network models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# loss_fn_cls = torch.nn.CrossEntropyLoss(weight=torch.tensor([0.3, 0.6], device=device))\n",
"loss_fn_cls = torch.nn.CrossEntropyLoss()\n",
"loss_fn_reg = torch.nn.MSELoss()\n",
"loss_fns = [loss_fn_cls, loss_fn_reg]\n",
"# For multiple data types, there are multiple interaction mats\n",
"feat_interact_loss_type = 'graph_laplacian'\n",
"if num_data_types > 1:\n",
" weight_path = ['decoders', range(num_data_types), 'weight'] \n",
"else:\n",
" weight_path = ['decoder', 'weight']\n",
"loss_feat_interact = Loss_feature_interaction(interaction_mat=interaction_mat, \n",
" loss_type=feat_interact_loss_type, \n",
" weight_path=weight_path, \n",
" normalize=True)\n",
"other_loss_fns = [loss_feat_interact]\n",
"if num_data_types > 1:\n",
" view_sim_loss_type = 'hub'\n",
" explicit_target = True\n",
" cal_target='mean-feature'\n",
" # In this set of experiments, the encoders for all views will have the same hidden_dim\n",
" loss_view_sim = Loss_view_similarity(sections=hidden_dim[-1], loss_type=view_sim_loss_type, \n",
" explicit_target=explicit_target, cal_target=cal_target, target=None)\n",
" loss_fns.append(loss_view_sim)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_names = []\n",
"split_names = ['train', 'val', 'test']\n",
"metric_names = ['acc', 'precision', 'recall', 'f1_score', 'adjusted_mutual_info', 'auc', \n",
" 'average_precision']\n",
"metric_all = []\n",
"confusion_mat_all = []\n",
"loss_his_all = []\n",
"acc_his_all = []"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, \n",
" x_test, y_test, batch_size, multi_heads, show_results_in_notebook=True, \n",
" loss_idx=0, acc_idx=0):\n",
" if len(x_val) > 0:\n",
" print(f'Best model on validation set: best_val_acc={best_val_acc:.2f}, epoch={best_epoch}')\n",
" metric = eval_classification_multi_splits(best_model, xs=[x_train, x_val, x_test], \n",
" ys=[y_train, y_val, y_test], batch_size=batch_size, multi_heads=multi_heads)\n",
"\n",
" if show_results_in_notebook:\n",
" print('\\nModel after the last training epoch:')\n",
" eval_classification_multi_splits(model, xs=[x_train, x_val, x_test], \n",
" ys=[y_train, y_val, y_test], batch_size=batch_size, \n",
" multi_heads=multi_heads, return_result=False)\n",
"\n",
" plot_history_multi_splits([loss_train_his, loss_val_his, loss_test_his], title='Loss', \n",
" idx=loss_idx)\n",
" plot_history_multi_splits([acc_train_his, acc_val_his, acc_test_his], title='Acc', idx=acc_idx)\n",
" # scatter plot\n",
" plot_data_multi_splits(best_model, [x_train, x_val, x_test], [y_train, y_val, y_test], \n",
" num_heads=2 if multi_heads else 1, \n",
" titles=['Training', 'Validation', 'Test'], batch_size=batch_size)\n",
" return metric"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Plain deep learning model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 1000\n",
"print_every = 100\n",
"eval_every = 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"in_dim = x_train.shape[1]\n",
"print('Plain deep learning model')\n",
"model_names.append('NN')\n",
"model = DenseLinear(in_dim, hidden_dim+[num_cls], dense=dense, residual=residual).to(device)\n",
"multi_heads = False\n",
"\n",
"loss_train_his = []\n",
"loss_val_his = []\n",
"loss_test_his = []\n",
"acc_train_his = []\n",
"acc_val_his = []\n",
"acc_test_his = []\n",
"best_model = model\n",
"best_val_acc = 0\n",
"best_epoch = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"best_model, best_val_acc, best_epoch = train_single_loss(model, x_train, y_train, \n",
" x_val, y_val, x_test, y_test, loss_fn=loss_fn_cls, lr=lr, weight_decay=weight_decay, \n",
" amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, \n",
" reduce_every=reduce_every, eval_every=eval_every, print_every=print_every, verbose=False, \n",
" loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, \n",
" acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, \n",
" return_best_val=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, \n",
" x_test, y_test, batch_size, multi_heads, show_results_in_notebook, \n",
" loss_idx=0, acc_idx=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])\n",
"acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])\n",
"metric_all.append([v[0] for v in metric])\n",
"confusion_mat_all.append([v[1] for v in metric])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Factorization AutoEncoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def run_one_model(model, loss_weights, other_loss_weights, \n",
" loss_his_all=[], acc_his_all=[], metric_all=[], confusion_mat_all=[],\n",
" heads=[0,1], multi_heads=True, return_results=False, \n",
" loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
" lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
" num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
" print_every=print_every, x_train=x_train, y_train=y_train,\n",
" x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
" show_results_in_notebook=show_results_in_notebook):\n",
" \"\"\"Train a model and get results \n",
" Most of the parameters are from the context; handle it properly\n",
" \"\"\"\n",
" loss_train_his = []\n",
" loss_val_his = []\n",
" loss_test_his = []\n",
" acc_train_his = []\n",
" acc_val_his = []\n",
" acc_test_his = []\n",
" best_model = model\n",
" best_val_acc = 0\n",
" best_epoch = 0\n",
"\n",
" best_model, best_val_acc, best_epoch = train_multiloss(model, x_train, [y_train, x_train], \n",
" x_val, [y_val, x_val], x_test, [y_test, x_test], heads=heads, loss_fns=loss_fns, \n",
" loss_weights=loss_weights, other_loss_fns=other_loss_fns, \n",
" other_loss_weights=other_loss_weights, \n",
" lr=lr, weight_decay=weight_decay, batch_size=batch_size, num_epochs=num_epochs, \n",
" reduce_every=reduce_every, eval_every=eval_every, print_every=print_every,\n",
" loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, \n",
" acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, \n",
" return_best_val=True, amsgrad=True, verbose=False)\n",
"\n",
" metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, \n",
" x_test, y_test, batch_size, multi_heads, show_results_in_notebook, \n",
" loss_idx=0, acc_idx=0)\n",
"\n",
" loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])\n",
" acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])\n",
" metric_all.append([v[0] for v in metric])\n",
" confusion_mat_all.append([v[1] for v in metric])\n",
" \n",
" if return_results:\n",
" return loss_his_all, acc_his_all, metric_all, confusion_mat_all"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"decoder_norm = False\n",
"uniform_decoder_norm = False\n",
"print('Plain AutoEncoder model')\n",
"model_names.append('AE')\n",
"model = AutoEncoder(in_dim, hidden_dim, num_cls, dense=dense, residual=residual,\n",
" decoder_norm=decoder_norm, uniform_decoder_norm=uniform_decoder_norm).to(device)\n",
"loss_weights = [1,1]\n",
"other_loss_weights = [0]\n",
"# heads = None should work for all the following; keep this for clarity\n",
"heads = [0,1] \n",
"run_one_model(model, loss_weights, other_loss_weights,\n",
" loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
" heads=heads, multi_heads=True, return_results=False, \n",
" loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
" lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
" num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
" print_every=print_every, x_train=x_train, y_train=y_train,\n",
" x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
" show_results_in_notebook=show_results_in_notebook)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add feature interaction network regularizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if num_data_types > 1:\n",
" fuse_type = 'sum'\n",
" print('MultiviewAE with feature interaction network regularizer')\n",
" model_names.append('MultiviewAE + feat_int')\n",
" model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls, \n",
" fuse_type=fuse_type, dense=dense, residual=residual, \n",
" residual_layers='all', decoder_norm=decoder_norm, \n",
" decoder_norm_dim=0, uniform_decoder_norm=uniform_decoder_norm, \n",
" nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)\n",
"else:\n",
" print('AutoEncoder with feature interaction network regularizer')\n",
" model_names.append('AE + feat_int')\n",
" model = AutoEncoder(in_dim, hidden_dim, num_cls, dense=dense, residual=residual, \n",
" decoder_norm=decoder_norm, uniform_decoder_norm=uniform_decoder_norm).to(device)\n",
"\n",
"loss_weights = [1,1]\n",
"other_loss_weights = [1]\n",
"heads = [0,1]\n",
"run_one_model(model, loss_weights, other_loss_weights, \n",
" loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
" heads=heads, multi_heads=True, return_results=False, \n",
" loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
" lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
" num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
" print_every=print_every, x_train=x_train, y_train=y_train,\n",
" x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
" show_results_in_notebook=show_results_in_notebook)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## For multi-view data, add view similarity network regularizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if num_data_types > 1:\n",
" # plain multiviewAE; compare it with plain AutoEncoder to see \n",
" # if separating views in lower layers in MultiviewAE is better than combining them all the way\n",
" print('Run plain MultiviewAE model')\n",
" model_names.append('MultiviewAE')\n",
" model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls, \n",
" fuse_type=fuse_type, dense=dense, residual=residual, \n",
" residual_layers='all', decoder_norm=decoder_norm, \n",
" decoder_norm_dim=0, uniform_decoder_norm=uniform_decoder_norm, \n",
" nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)\n",
"\n",
" loss_weights = [1,1]\n",
" other_loss_weights = [0]\n",
" heads = [0,1]\n",
" run_one_model(model, loss_weights, other_loss_weights, \n",
" loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
" heads=heads, multi_heads=True, return_results=False, \n",
" loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
" lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
" num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
" print_every=print_every, x_train=x_train, y_train=y_train,\n",
" x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
" show_results_in_notebook=show_results_in_notebook)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if num_data_types > 1:\n",
" print('MultiviewAE with view similarity regularizers')\n",
" model_names.append('MultiviewAE + view_sim')\n",
" model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls, \n",
" fuse_type=fuse_type, dense=dense, residual=residual, \n",
" residual_layers='all', decoder_norm=decoder_norm, \n",
" decoder_norm_dim=0, uniform_decoder_norm=uniform_decoder_norm, \n",
" nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)\n",
" loss_weights = [1,1,1]\n",
" other_loss_weights = [0]\n",
" heads = [0,1,2]\n",
" run_one_model(model, loss_weights, other_loss_weights, \n",
" loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
" heads=heads, multi_heads=True, return_results=False, \n",
" loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
" lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
" num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
" print_every=print_every, x_train=x_train, y_train=y_train,\n",
" x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
" show_results_in_notebook=show_results_in_notebook)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if num_data_types > 1:\n",
" print('MultiviewAE with both feature interaction and view similarity regularizers')\n",
" model_names.append('MultiviewAE + feat_int + view_sim')\n",
" model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls, \n",
" fuse_type=fuse_type, dense=dense, residual=residual, \n",
" residual_layers='all', decoder_norm=decoder_norm, \n",
" decoder_norm_dim=0, uniform_decoder_norm=uniform_decoder_norm, \n",
" nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)\n",
" loss_weights = [1,1,1]\n",
" other_loss_weights = [1]\n",
" heads = [0,1,2]\n",
" run_one_model(model, loss_weights, other_loss_weights,\n",
" loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
" heads=heads, multi_heads=True, return_results=False, \n",
" loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
" lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
" num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
" print_every=print_every, x_train=x_train, y_train=y_train,\n",
" x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
" show_results_in_notebook=show_results_in_notebook)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(f'{result_folder}/{res_file}', 'wb') as f:\n",
" print(f'Write result to file {result_folder}/{res_file}')\n",
" pickle.dump({'loss_his_all': loss_his_all,\n",
" 'acc_his_all': acc_his_all,\n",
" 'metric_all': metric_all,\n",
" 'confusion_mat_all': confusion_mat_all,\n",
" 'model_names': model_names,\n",
" 'split_names': split_names,\n",
" 'metric_names': metric_names\n",
" }, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}