a b/exp_template-mv-nn-v2.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": null,
6
   "metadata": {},
7
   "outputs": [],
8
   "source": [
9
    "import socket\n",
10
    "if socket.gethostname() == 'dlm':\n",
11
    "  %env CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
12
    "  %env CUDA_VISIBLE_DEVICES=3"
13
   ]
14
  },
15
  {
16
   "cell_type": "code",
17
   "execution_count": null,
18
   "metadata": {},
19
   "outputs": [],
20
   "source": [
21
    "import os\n",
22
    "import sys\n",
23
    "import re\n",
24
    "import collections\n",
25
    "import functools\n",
26
    "import requests, zipfile, io\n",
27
    "import pickle\n",
28
    "import copy\n",
29
    "\n",
30
    "import pandas\n",
31
    "import numpy as np\n",
32
    "import matplotlib\n",
33
    "import matplotlib.pyplot as plt\n",
34
    "import sklearn\n",
35
    "import sklearn.decomposition\n",
36
    "import sklearn.metrics\n",
37
    "import networkx\n",
38
    "\n",
39
    "import torch\n",
40
    "import torch.nn as nn\n",
41
    "\n",
42
    "lib_path = 'I:/code'\n",
43
    "if not os.path.exists(lib_path):\n",
44
    "  lib_path = '/media/6T/.tianle/.lib'\n",
45
    "if not os.path.exists(lib_path):\n",
46
    "  lib_path = '/projects/academic/azhang/tianlema/lib'\n",
47
    "if os.path.exists(lib_path) and lib_path not in sys.path:\n",
48
    "  sys.path.append(lib_path)\n",
49
    "  \n",
50
    "from dl.models.basic_models import *\n",
51
    "from dl.utils.visualization.visualization import *\n",
52
    "from dl.utils.outlier import *\n",
53
    "from dl.utils.train import *\n",
54
    "from autoencoder.autoencoder import *\n",
55
    "from dl.utils.utils import get_overlap_samples, filter_clinical_dict, get_target_variable\n",
56
    "from dl.utils.utils import get_shuffled_data, target_to_numpy\n",
57
    "\n",
58
    "%load_ext autoreload\n",
59
    "%autoreload 2\n",
60
    "\n",
61
    "\n",
62
    "use_gpu = True\n",
63
    "if use_gpu and torch.cuda.is_available():\n",
64
    "  device = torch.device('cuda')\n",
65
    "  print('Using GPU:)')\n",
66
    "else:\n",
67
    "  device = torch.device('cpu')\n",
68
    "  print('Using CPU:(')"
69
   ]
70
  },
71
  {
72
   "cell_type": "code",
73
   "execution_count": null,
74
   "metadata": {},
75
   "outputs": [],
76
   "source": [
77
    "# neural net models include nn (mlp), resnet, densenet; another choice is ml (machine learning)\n",
78
    "# model_type, dense, residual are dependent\n",
79
    "model_type = 'resnet'\n",
80
    "dense = False\n",
81
    "residual = True\n",
82
    "hidden_dim = [100, 100]\n",
83
    "train_portion = 0.7\n",
84
    "val_portion = 0.1\n",
85
    "test_portion = 0.2\n",
86
    "num_train_types = -1 # -1 means not used\n",
87
    "num_val_types = -1\n",
88
    "num_test_types = -1 # this will almost never be used \n",
89
    "num_sets = 10\n",
90
    "num_folds = 10 # no longer used anymore\n",
91
    "sel_set_idx = 0\n",
92
    "cv_type = 'instance-shuffle' # or 'group-shuffle'; cross validation shuffle method\n",
93
    "sel_disease_types = 'all'\n",
94
    "# The number of total samples and the numbers for each class in selected disease types must >=\n",
95
    "min_num_samples_per_type_cls = [100, 0]\n",
96
    "# if 'auto-search', will search for the file first; if not exist, then generate random data split\n",
97
    "# and write to the file;\n",
98
    "# if string other than 'auto-search' is provided, assume the string is a proper file name, \n",
99
    "# and read the file;\n",
100
    "# if False, will generate a random data split, but not write to file \n",
101
    "# if True will generate a random data split, and write to file\n",
102
    "predefined_sample_set_file = 'auto-search' \n",
103
    "target_variable = 'PFI' # To do: target variable can be a list (partially handled)\n",
104
    "target_variable_type = 'discrete' # or 'continuous' real numbers\n",
105
    "target_variable_range = [0, 1]\n",
106
    "data_type = ['gene', 'methy', 'rppa', 'mirna']\n",
107
    "normal_transform_feature = True\n",
108
    "additional_vars = []#['age_at_initial_pathologic_diagnosis', 'gender']\n",
109
    "additional_var_types = []#['continuous', 'discrete']\n",
110
    "additional_var_ranges = []#[[0, 100], ['MALE', 'FEMALE']]\n",
111
    "randomize_labels = False\n",
112
    "lr = 5e-4\n",
113
    "weight_decay = 1e-4\n",
114
    "num_epochs = 1000\n",
115
    "reduce_every = 500\n",
116
    "show_results_in_notebook = True"
117
   ]
118
  },
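  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `model_type`, `dense`, and `residual` flags above must be kept consistent by hand. The helper below is a minimal sketch (a hypothetical `flags_for_model_type`, not used elsewhere in this notebook) that makes the intended mapping explicit and sanity-checks the current settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: the (dense, residual) flags implied by each neural-net model_type.\n",
    "def flags_for_model_type(model_type):\n",
    "  mapping = {'nn': (False, False), 'resnet': (False, True), 'densenet': (True, False)}\n",
    "  return mapping.get(model_type)  # None for non-neural-net choices such as 'ml'\n",
    "\n",
    "if flags_for_model_type(model_type) is not None:\n",
    "  assert flags_for_model_type(model_type) == (dense, residual), 'dense/residual inconsistent with model_type'"
   ]
  },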
119
  {
120
   "cell_type": "markdown",
121
   "metadata": {},
122
   "source": [
123
    "## Prepare data"
124
   ]
125
  },
126
  {
127
   "cell_type": "code",
128
   "execution_count": null,
129
   "metadata": {},
130
   "outputs": [],
131
   "source": [
132
    "result_folder = 'results'\n",
133
    "data_split_idx_folder = f'{result_folder}/data_split_idx'\n",
134
    "project_folder = '../../pan-can-atlas'\n",
135
    "print_stats = True\n",
136
    "if not os.path.exists(project_folder):\n",
137
    "  project_folder = 'F:/TCGA/Pan-Cancer-Atlas'\n",
138
    "filepath = f'{project_folder}/data/processed/combined2.pkl'\n",
139
    "with open(filepath, 'rb') as f:\n",
140
    "  data = pickle.load(f)\n",
141
    "  patient_clinical = data['patient_clinical']\n",
142
    "  feature_mat_dict = data['feature_mat_dict']\n",
143
    "  feature_interaction_mat_dict = data['feature_interaction_mat_dict']\n",
144
    "  feature_id_dict = data['feature_id_dict']\n",
145
    "  aliquot_id_dict = data['aliquot_id_dict']\n",
146
    "#   sel_patient_ids = data['sample_id_sel']\n",
147
    "#   sample_idx_sel_dict = data['sample_idx_sel_dict']\n",
148
    "#   for k, v in sample_idx_sel_dict.items():\n",
149
    "#     assert [i[:12] for i in aliquot_id_dict[k][v]] == sel_patient_ids\n",
150
    "\n",
151
    "if print_stats:\n",
152
    "  for k, v in feature_mat_dict.items():\n",
153
    "    print(f'feature_mat: {k}, max={v.max():.3f}, min={v.min():.3f}, '\n",
154
    "          f'mean={v.mean():.3f}, {np.mean(v>0):.3f}')  \n",
155
    "  for k, v in feature_interaction_mat_dict.items():\n",
156
    "    print(f'feature_interaction_mat: {k}, max={v.max():.3f}, min={v.min():.3f}, '\n",
157
    "          f'mean={v.mean():.3f}, {np.mean(v>0):.3f}') \n",
158
    "  for k, v in feature_id_dict.items():\n",
159
    "    print(k, v.shape, v[0])\n",
160
    "  for k, v in aliquot_id_dict.items():\n",
161
    "    print(k, v.shape, v[0])"
162
   ]
163
  },
164
  {
165
   "cell_type": "code",
166
   "execution_count": null,
167
   "metadata": {},
168
   "outputs": [],
169
   "source": [
170
    "# select samples with required clinical variables\n",
171
    "clinical_dict = filter_clinical_dict(target_variable, target_variable_type=target_variable_type, \n",
172
    "                                     target_variable_range=target_variable_range, \n",
173
    "                                     clinical_dict=patient_clinical)\n",
174
    "if len(additional_vars) > 0:\n",
175
    "  clinical_dict = filter_clinical_dict(additional_vars, target_variable_type=additional_var_types, \n",
176
    "                                       target_variable_range=additional_var_ranges, \n",
177
    "                                       clinical_dict=clinical_dict)\n",
178
    "\n",
179
    "# select samples with feature matrix of given type(s)\n",
180
    "if isinstance(data_type, str):\n",
181
    "  sample_list = {s[:12] for s in aliquot_id_dict[data_type]}\n",
182
    "  data_type_str = data_type\n",
183
    "elif isinstance(data_type, (list, tuple)):\n",
184
    "  sample_list = get_overlap_samples([aliquot_id_dict[dtype] for dtype in data_type], \n",
185
    "                                    common_list=None, start=0, end=12, return_common_list=True)\n",
186
    "  data_type_str = '-'.join(sorted(data_type))\n",
187
    "else:\n",
188
    "  raise ValueError(f'data_type must be str or list/tuple, but is {type(data_type)}')\n",
189
    "sample_list = sample_list.intersection(clinical_dict)\n",
190
    "\n",
191
    "# select samples with given disease types\n",
192
    "sel_disease_type_str = sel_disease_types # will be overwritten if it is a list\n",
193
    "if isinstance(sel_disease_types, (list, tuple)):\n",
194
    "  sample_list = [s for s in sample_list if clinical_dict[s]['type'] in sel_disease_types]\n",
195
    "  sel_disease_type_str = '-'.join(sorted(sel_disease_types))\n",
196
    "elif isinstance(sel_disease_types, str) and sel_disease_types!='all':\n",
197
    "  sample_list = [s for s in sample_list if clinical_dict[s]['type'] == sel_disease_types]\n",
198
    "else:\n",
199
    "  assert sel_disease_types == 'all'\n",
200
    " \n",
201
    "# For classification tasks with given min_num_samples_per_type_cls,\n",
202
    "# only keep disease types that have a minimal number of samples per type and per class\n",
203
    "# Reflection: it might be better to use collections.defaultdict(list) to store samples in each type\n",
204
    "type_cnt = collections.Counter([clinical_dict[s]['type'] for s in sample_list])\n",
205
    "if sum(min_num_samples_per_type_cls)>0 and (target_variable_type=='discrete' \n",
206
    "                                            or target_variable_type[0]=='discrete'):\n",
207
    "  # the number of samples in each disease type >= min_num_samples_per_type_cls[0]\n",
208
    "  type_cnt = {k: v for k, v in type_cnt.items() if v >= min_num_samples_per_type_cls[0]}\n",
209
    "  disease_type_cnt = {}\n",
210
    "  for k in type_cnt:\n",
211
    "    # collections.Counter can accept generator\n",
212
    "    cls_cnt = collections.Counter(clinical_dict[s][target_variable] \n",
213
    "                                  if isinstance(target_variable, str) \n",
214
    "                                  else clinical_dict[s][target_variable[0]] \n",
215
    "                                  for s in sample_list if clinical_dict[s]['type']==k)\n",
216
    "    if all([v >= min_num_samples_per_type_cls[1] for v in cls_cnt.values()]):\n",
217
    "      # the number of samples in each class >= min_num_samples_per_type_cls[1]\n",
218
    "      disease_type_cnt[k] = dict(cls_cnt)\n",
219
    "      print(k, disease_type_cnt[k])\n",
220
    "  sample_list = [s for s in sample_list if clinical_dict[s]['type'] in disease_type_cnt]\n",
221
    "sel_patient_ids = sorted(sample_list)\n",
222
    "print(f'Selected {len(sel_patient_ids)} patients from {len(disease_type_cnt)} disease_types')"
223
   ]
224
  },
225
  {
226
   "cell_type": "markdown",
227
   "metadata": {},
228
   "source": [
229
    "### Split data into training, validation, and test sets"
230
   ]
231
  },
232
  {
233
   "cell_type": "code",
234
   "execution_count": null,
235
   "metadata": {},
236
   "outputs": [],
237
   "source": [
238
    "predefined_sample_set_filename = (target_variable if isinstance(target_variable,str) \n",
239
    "                                else '-'.join(target_variable))\n",
240
    "predefined_sample_set_filename += f'_{cv_type}'\n",
241
    "if len(additional_vars) > 0:\n",
242
    "  predefined_sample_set_filename += f\"_{'-'.join(sorted(additional_vars))}\"\n",
243
    "\n",
244
    "predefined_sample_set_filename += (f\"_{data_type_str}_{sel_disease_type_str}_\"\n",
245
    "                                   f\"{'-'.join(map(str, min_num_samples_per_type_cls))}\")\n",
246
    "predefined_sample_set_filename += f\"_{'-'.join(map(str, [train_portion, val_portion, test_portion]))}\"\n",
247
    "if cv_type == 'group-shuffle' and num_train_types > 0:\n",
248
    "  predefined_sample_set_filename += f\"_{'-'.join(map(str, [num_train_types, num_val_types, num_test_types]))}\"\n",
249
    "predefined_sample_set_filename += f'_{num_sets}sets'\n",
250
    "res_file = f\"{predefined_sample_set_filename}_{sel_set_idx}_{'-'.join(map(str, hidden_dim))}_{model_type}.pkl\"\n",
251
    "predefined_sample_set_filename += '.pkl'\n",
252
    "# This will be overwritten if predefined_sample_set_file == 'auto-search' or filepath, and the file exists\n",
253
    "predefined_sample_sets = [get_shuffled_data(sel_patient_ids, clinical_dict, cv_type=cv_type, \n",
254
    "                  instance_portions=[train_portion, val_portion, test_portion], \n",
255
    "                  group_sizes=[num_train_types, num_val_types, num_test_types],\n",
256
    "                  group_variable_name='type', seed=None, verbose=False) for i in range(num_sets)]\n",
257
    "if predefined_sample_set_file == 'auto-search':\n",
258
    "  if os.path.exists(f'{data_split_idx_folder}/{predefined_sample_set_filename}'):\n",
259
    "    with open(f'{data_split_idx_folder}/{predefined_sample_set_filename}', 'rb') as f:\n",
260
    "      print(f'Read predefined_sample_set_file: '\n",
261
    "            f'{data_split_idx_folder}/{predefined_sample_set_filename}')\n",
262
    "      tmp = pickle.load(f)\n",
263
    "      # overwrite calculated predefined_sample_sets\n",
264
    "      predefined_sample_sets = tmp['predefined_sample_sets']    \n",
265
    "elif isinstance(predefined_sample_set_file, str): # but not 'auto-search'; assume it's a file name\n",
266
    "  if os.path.exists(predefined_sample_set_file):\n",
267
    "    with open(f'{data_split_idx_folder}/{predefined_sample_set_file}', 'rb') as f:\n",
268
    "      print(f'Read predefined_sample_set_file: {data_split_idx_folder}/{predefined_sample_set_file}')\n",
269
    "      tmp = pickle.load(f)\n",
270
    "      predefined_sample_sets = tmp['predefined_sample_sets']\n",
271
    "  else:\n",
272
    "    raise ValueError(f'predefined_sample_set_file: {data_split_idx_folder}/{predefined_sample_set_file} does not exist!')\n",
273
    "\n",
274
    "if (not os.path.exists(f'{data_split_idx_folder}/{predefined_sample_set_filename}') \n",
275
    "    and predefined_sample_set_file == 'auto-search') or predefined_sample_set_file is True:\n",
276
    "  with open(f'{data_split_idx_folder}/{predefined_sample_set_filename}', 'wb') as f:\n",
277
    "      print(f'Write predefined_sample_set_file: {data_split_idx_folder}/{predefined_sample_set_filename}')\n",
278
    "      pickle.dump({'predefined_sample_sets': predefined_sample_sets}, f)\n",
279
    "     \n",
280
    "sel_patient_ids, idx_splits = predefined_sample_sets[sel_set_idx]\n",
281
    "train_idx, val_idx, test_idx = idx_splits"
282
   ]
283
  },
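  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_shuffled_data` (imported from `dl.utils`) produces the train/val/test index splits used above; its implementation is not shown in this notebook. The cell below is a minimal stand-in, under the assumption that `instance-shuffle` simply permutes sample indices and cuts them by the configured portions, only to illustrate what `idx_splits` looks like. Everything named `demo_` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal stand-in for an 'instance-shuffle' split (assumption: permute indices, then cut\n",
    "# by the configured portions). The real splits above come from get_shuffled_data.\n",
    "def demo_instance_shuffle_split(n, portions=(train_portion, val_portion, test_portion), seed=0):\n",
    "  rng = np.random.RandomState(seed)\n",
    "  idx = rng.permutation(n)\n",
    "  n_train = int(round(portions[0] * n))\n",
    "  n_val = int(round(portions[1] * n))\n",
    "  return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]\n",
    "\n",
    "demo_train, demo_val, demo_test = demo_instance_shuffle_split(100)\n",
    "print(len(demo_train), len(demo_val), len(demo_test))  # 70 10 20"
   ]
  },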
284
  {
285
   "cell_type": "code",
286
   "execution_count": null,
287
   "metadata": {},
288
   "outputs": [],
289
   "source": [
290
    "if isinstance(data_type, str):\n",
291
    "  sample_lists = [aliquot_id_dict[data_type]]\n",
292
    "else:\n",
293
    "  assert isinstance(data_type, (list, tuple))\n",
294
    "  sample_lists = [aliquot_id_dict[dtype] for dtype in data_type]\n",
295
    "idx_lists = get_overlap_samples(sample_lists=sample_lists, common_list=sel_patient_ids, \n",
296
    "                    start=0, end=12, return_common_list=False)\n",
297
    "sample_idx_sel_dict = {}\n",
298
    "if isinstance(data_type, str):\n",
299
    "  sample_idx_sel_dict = {data_type: idx_lists[0]}\n",
300
    "else:\n",
301
    "  sample_idx_sel_dict = {dtype: idx_list for dtype, idx_list in zip(data_type, idx_lists)}"
302
   ]
303
  },
304
  {
305
   "cell_type": "code",
306
   "execution_count": null,
307
   "metadata": {},
308
   "outputs": [],
309
   "source": [
310
    "if isinstance(data_type, str):\n",
311
    "  print(f'Only use one data type: {data_type}')\n",
312
    "  num_data_types = 1\n",
313
    "  mat = feature_mat_dict[data_type][sample_idx_sel_dict[data_type]]\n",
314
    "  # Data preprocessing: make each row have mean 0 and sd 1.\n",
315
    "  x = (mat - mat.mean(axis=1, keepdims=True)) / mat.std(axis=1, keepdims=True)\n",
316
    "  interaction_mat = feature_interaction_mat_dict[data_type]\n",
317
    "  interaction_mat = torch.from_numpy(interaction_mat).float().to(device)\n",
318
    "  # Normalize these interaction mat\n",
319
    "  interaction_mat = interaction_mat / interaction_mat.norm()\n",
320
    "else:\n",
321
    "  mat = []\n",
322
    "  interaction_mats = []\n",
323
    "  in_dims = []\n",
324
    "  num_data_types = len(data_type)\n",
325
    "  # do not handle the special case of [data_type] to avoid too much code complexity\n",
326
    "  assert num_data_types > 1 \n",
327
    "  for dtype in data_type: # multiple data types\n",
328
    "    m = feature_mat_dict[dtype][sample_idx_sel_dict[dtype]]\n",
329
    "    #When there are multiple data types, make sure each type is normalized to have mean 0 and std 1\n",
330
    "    m = (m - m.mean(axis=1, keepdims=True)) / m.std(axis=1, keepdims=True)\n",
331
    "    mat.append(m)\n",
332
    "    in_dims.append(m.shape[1])\n",
333
    "    # For neural network model graph laplacian regularizer\n",
334
    "    interaction_mat = feature_interaction_mat_dict[dtype]\n",
335
    "    interaction_mat = torch.from_numpy(interaction_mat).float().to(device)\n",
336
    "    # Normalize these interaction mat\n",
337
    "    interaction_mat = interaction_mat / interaction_mat.norm()\n",
338
    "    interaction_mats.append(interaction_mat)\n",
339
    "    print(f'{dtype}: {m.shape}; '\n",
340
    "          f'interaction_mat: mean={interaction_mat.mean().item():2f}, '\n",
341
    "          f'std={interaction_mat.std().item():2f}, {interaction_mat.shape[0]}')\n",
342
    "  # Later interaction_mat will be passed to Loss_feature_interaction\n",
343
    "  interaction_mat = interaction_mats\n",
344
    "  mat = np.concatenate(mat, axis=1)\n",
345
    "  # For machine learing methods that use concatenated features without knowing underlying views,\n",
346
    "  # it might be good to make each row have mean 0 and sd 1.\n",
347
    "  x = (mat - mat.mean(axis=1, keepdims=True)) / mat.std(axis=1, keepdims=True)\n",
348
    "\n",
349
    "if normal_transform_feature:\n",
350
    "  X = x\n",
351
    "else:\n",
352
    "  X = mat"
353
   ]
354
  },
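  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy check of the per-row standardization used above: each sample (row) ends up with mean 0 and standard deviation 1, regardless of its original scale. The `demo_` names below are illustrative and unrelated to the real matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy check of the per-row standardization applied to the feature matrices above.\n",
    "demo_mat = np.array([[1., 2., 3., 4.], [10., 20., 30., 40.]])\n",
    "demo_x = (demo_mat - demo_mat.mean(axis=1, keepdims=True)) / demo_mat.std(axis=1, keepdims=True)\n",
    "print(np.allclose(demo_x.mean(axis=1), 0), np.allclose(demo_x.std(axis=1), 1))  # True True"
   ]
  },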
355
  {
356
   "cell_type": "code",
357
   "execution_count": null,
358
   "metadata": {},
359
   "outputs": [],
360
   "source": [
361
    "y_targets = get_target_variable(target_variable, clinical_dict, sel_patient_ids)\n",
362
    "y_true = target_to_numpy(y_targets, target_variable_type, target_variable_range)\n",
363
    "if len(additional_vars) > 0:\n",
364
    "  additional_variables = get_target_variable(additional_vars, clinical_dict, sel_patient_ids)\n",
365
    "  # to do handle additional variables such as age and gender"
366
   ]
367
  },
368
  {
369
   "cell_type": "markdown",
370
   "metadata": {},
371
   "source": [
372
    "### To do: handle multiple inputs, multiple targets"
373
   ]
374
  },
375
  {
376
   "cell_type": "code",
377
   "execution_count": null,
378
   "metadata": {},
379
   "outputs": [],
380
   "source": [
381
    "# sklearn classifiers also accept torch.Tensor\n",
382
    "X = torch.tensor(X).float().to(device)\n",
383
    "y_true = torch.tensor(y_true).long().to(device)\n",
384
    "num_cls = len(torch.unique(y_true))\n",
385
    "\n",
386
    "x_train, y_train = X[train_idx], y_true[train_idx]\n",
387
    "x_val, y_val = X[val_idx], y_true[val_idx]\n",
388
    "x_test, y_test = X[test_idx], y_true[test_idx]\n",
389
    "print(x_train.shape, x_val.shape, x_test.shape, y_train.shape, y_val.shape, y_test.shape)\n",
390
    "\n",
391
    "label_prob_train = get_label_prob(y_train, verbose=False)\n",
392
    "label_probs = [label_prob_train]\n",
393
    "if len(y_val)>0:\n",
394
    "  label_prob_val = get_label_prob(y_val, verbose=False)\n",
395
    "  assert len(label_prob_train) == len(label_prob_val)\n",
396
    "  label_probs.append(label_prob_val)\n",
397
    "if len(y_test)>0:\n",
398
    "  label_prob_test = get_label_prob(y_test, verbose=False)\n",
399
    "  assert len(label_prob_train) == len(label_prob_test)\n",
400
    "  label_probs.append(label_prob_test)\n",
401
    "if isinstance(label_probs, torch.Tensor):\n",
402
    "  print('label distribution:\\n', torch.stack(label_probs, dim=1))\n",
403
    "else:\n",
404
    "  print('label distribution:\\n', np.stack(label_probs, axis=1))"
405
   ]
406
  },
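  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_label_prob` is imported from `dl.utils` and is not reproduced here. The sketch below (`demo_label_prob`, an assumption about its behaviour) returns the empirical class distribution of an integer label vector, which is all the code above relies on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal stand-in for get_label_prob (assumed behaviour): empirical class frequencies.\n",
    "def demo_label_prob(y, num_cls=None):\n",
    "  y = torch.as_tensor(y).long()\n",
    "  num_cls = int(y.max()) + 1 if num_cls is None else num_cls\n",
    "  counts = torch.bincount(y, minlength=num_cls).float()\n",
    "  return counts / counts.sum()\n",
    "\n",
    "print(demo_label_prob(torch.tensor([0, 0, 1, 1, 1])))  # tensor([0.4000, 0.6000])"
   ]
  },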
407
  {
408
   "cell_type": "markdown",
409
   "metadata": {},
410
   "source": [
411
    "### Optionally randomize true class labels"
412
   ]
413
  },
414
  {
415
   "cell_type": "code",
416
   "execution_count": null,
417
   "metadata": {
418
    "scrolled": true
419
   },
420
   "outputs": [],
421
   "source": [
422
    "if randomize_labels:\n",
423
    "  print('Randomize class labels!')\n",
424
    "  y_train = torch.multinomial(label_prob_train, len(y_train), replacement=True)\n",
425
    "  if len(y_val) > 0:\n",
426
    "    y_val = torch.multinomial(label_prob_val, len(y_val), replacement=True)\n",
427
    "  if len(y_test) > 0:\n",
428
    "    y_test = torch.multinomial(label_prob_test, len(y_test), replacement=True)"
429
   ]
430
  },
431
  {
432
   "cell_type": "markdown",
433
   "metadata": {},
434
   "source": [
435
    "## Neural network models"
436
   ]
437
  },
438
  {
439
   "cell_type": "code",
440
   "execution_count": null,
441
   "metadata": {},
442
   "outputs": [],
443
   "source": [
444
    "# loss_fn_cls = torch.nn.CrossEntropyLoss(weight=torch.tensor([0.3, 0.6], device=device))\n",
445
    "loss_fn_cls = torch.nn.CrossEntropyLoss()\n",
446
    "loss_fn_reg = torch.nn.MSELoss()\n",
447
    "loss_fns = [loss_fn_cls, loss_fn_reg]\n",
448
    "# For multiple data types, there are multiple interaction mats\n",
449
    "feat_interact_loss_type = 'graph_laplacian'\n",
450
    "if num_data_types > 1:\n",
451
    "  weight_path = ['decoders', range(num_data_types), 'weight']  \n",
452
    "else:\n",
453
    "  weight_path = ['decoder', 'weight']\n",
454
    "loss_feat_interact = Loss_feature_interaction(interaction_mat=interaction_mat, \n",
455
    "                                              loss_type=feat_interact_loss_type, \n",
456
    "                                              weight_path=weight_path, \n",
457
    "                                              normalize=True)\n",
458
    "other_loss_fns = [loss_feat_interact]\n",
459
    "if num_data_types > 1:\n",
460
    "  view_sim_loss_type = 'hub'\n",
461
    "  explicit_target = True\n",
462
    "  cal_target='mean-feature'\n",
463
    "  # In this set of experiments, the encoders for all views will have the same hidden_dim\n",
464
    "  loss_view_sim = Loss_view_similarity(sections=hidden_dim[-1], loss_type=view_sim_loss_type, \n",
465
    "    explicit_target=explicit_target, cal_target=cal_target, target=None)\n",
466
    "  loss_fns.append(loss_view_sim)"
467
   ]
468
  },
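  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Loss_feature_interaction` above is configured with `loss_type='graph_laplacian'`; its exact implementation lives in the imported autoencoder library. The sketch below shows the standard graph-Laplacian penalty this kind of regularizer is built on: for an adjacency matrix A and a weight matrix W whose columns correspond to features, tr(W L W^T) with L = D - A is small when connected features receive similar weights. `demo_graph_laplacian_penalty` is illustrative, not the library function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generic graph-Laplacian penalty (illustrative; not the library's exact implementation).\n",
    "def demo_graph_laplacian_penalty(W, A):\n",
    "  # W: (out_dim, num_features); A: symmetric non-negative (num_features, num_features)\n",
    "  L = torch.diag(A.sum(dim=1)) - A\n",
    "  return torch.trace(W @ L @ W.t())\n",
    "\n",
    "demo_A = torch.tensor([[0., 1.], [1., 0.]])\n",
    "demo_W = torch.tensor([[1., 1.], [2., 2.]])  # identical columns of W -> zero penalty\n",
    "print(demo_graph_laplacian_penalty(demo_W, demo_A))  # tensor(0.)"
   ]
  },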
469
  {
470
   "cell_type": "code",
471
   "execution_count": null,
472
   "metadata": {},
473
   "outputs": [],
474
   "source": [
475
    "model_names = []\n",
476
    "split_names = ['train', 'val', 'test']\n",
477
    "metric_names = ['acc', 'precision', 'recall', 'f1_score', 'adjusted_mutual_info', 'auc', \n",
478
    "                'average_precision']\n",
479
    "metric_all = []\n",
480
    "confusion_mat_all = []\n",
481
    "loss_his_all = []\n",
482
    "acc_his_all = []"
483
   ]
484
  },
485
  {
486
   "cell_type": "code",
487
   "execution_count": null,
488
   "metadata": {},
489
   "outputs": [],
490
   "source": [
491
    "def get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, \n",
492
    "               x_test, y_test, batch_size, multi_heads, show_results_in_notebook=True, \n",
493
    "               loss_idx=0, acc_idx=0):\n",
494
    "  if len(x_val) > 0:\n",
495
    "    print(f'Best model on validation set: best_val_acc={best_val_acc:.2f}, epoch={best_epoch}')\n",
496
    "    metric = eval_classification_multi_splits(best_model, xs=[x_train, x_val, x_test], \n",
497
    "      ys=[y_train, y_val, y_test], batch_size=batch_size, multi_heads=multi_heads)\n",
498
    "\n",
499
    "  if show_results_in_notebook:\n",
500
    "    print('\\nModel after the last training epoch:')\n",
501
    "    eval_classification_multi_splits(model, xs=[x_train, x_val, x_test], \n",
502
    "                                     ys=[y_train, y_val, y_test], batch_size=batch_size, \n",
503
    "                                     multi_heads=multi_heads, return_result=False)\n",
504
    "\n",
505
    "    plot_history_multi_splits([loss_train_his, loss_val_his, loss_test_his], title='Loss', \n",
506
    "                              idx=loss_idx)\n",
507
    "    plot_history_multi_splits([acc_train_his, acc_val_his, acc_test_his], title='Acc', idx=acc_idx)\n",
508
    "    # scatter plot\n",
509
    "    plot_data_multi_splits(best_model, [x_train, x_val, x_test], [y_train, y_val, y_test], \n",
510
    "                           num_heads=2 if multi_heads else 1, \n",
511
    "                           titles=['Training', 'Validation', 'Test'], batch_size=batch_size)\n",
512
    "    return metric"
513
   ]
514
  },
515
  {
516
   "cell_type": "markdown",
517
   "metadata": {},
518
   "source": [
519
    "# Plain deep learning model"
520
   ]
521
  },
522
  {
523
   "cell_type": "code",
524
   "execution_count": null,
525
   "metadata": {},
526
   "outputs": [],
527
   "source": [
528
    "batch_size = 1000\n",
529
    "print_every = 100\n",
530
    "eval_every = 1"
531
   ]
532
  },
533
  {
534
   "cell_type": "code",
535
   "execution_count": null,
536
   "metadata": {},
537
   "outputs": [],
538
   "source": [
539
    "in_dim = x_train.shape[1]\n",
540
    "print('Plain deep learning model')\n",
541
    "model_names.append('NN')\n",
542
    "model = DenseLinear(in_dim, hidden_dim+[num_cls], dense=dense, residual=residual).to(device)\n",
543
    "multi_heads = False\n",
544
    "\n",
545
    "loss_train_his = []\n",
546
    "loss_val_his = []\n",
547
    "loss_test_his = []\n",
548
    "acc_train_his = []\n",
549
    "acc_val_his = []\n",
550
    "acc_test_his = []\n",
551
    "best_model = model\n",
552
    "best_val_acc = 0\n",
553
    "best_epoch = 0"
554
   ]
555
  },
556
  {
557
   "cell_type": "code",
558
   "execution_count": null,
559
   "metadata": {},
560
   "outputs": [],
561
   "source": [
562
    "best_model, best_val_acc, best_epoch = train_single_loss(model, x_train, y_train, \n",
563
    "    x_val, y_val, x_test, y_test, loss_fn=loss_fn_cls, lr=lr, weight_decay=weight_decay, \n",
564
    "    amsgrad=True, batch_size=batch_size, num_epochs=num_epochs, \n",
565
    "    reduce_every=reduce_every, eval_every=eval_every, print_every=print_every, verbose=False, \n",
566
    "    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, \n",
567
    "    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, \n",
568
    "    return_best_val=True)"
569
   ]
570
  },
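  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`train_single_loss` comes from `dl.utils.train` and is not reproduced in this notebook. The sketch below is a minimal stand-alone loop in the same spirit (Adam with amsgrad, weight decay, a step learning-rate drop every `reduce_every` epochs), shown only to make the call above easier to read; everything named `demo_` is an assumption, not the library's code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal full-batch training loop in the spirit of train_single_loss (illustrative only).\n",
    "def demo_train_classifier(model, x, y, loss_fn, lr=1e-3, weight_decay=0., num_epochs=10,\n",
    "                          reduce_every=5, lr_decay=0.1):\n",
    "  optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=True)\n",
    "  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=reduce_every, gamma=lr_decay)\n",
    "  for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    loss = loss_fn(model(x), y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "  return model\n",
    "\n",
    "# Example on random data (CPU), not on the TCGA matrices:\n",
    "demo_model = nn.Linear(5, 2)\n",
    "demo_x, demo_y = torch.randn(20, 5), torch.randint(0, 2, (20,))\n",
    "demo_train_classifier(demo_model, demo_x, demo_y, nn.CrossEntropyLoss(), num_epochs=3)"
   ]
  },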
571
  {
572
   "cell_type": "code",
573
   "execution_count": null,
574
   "metadata": {},
575
   "outputs": [],
576
   "source": [
577
    "metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, \n",
578
    "                    x_test, y_test, batch_size, multi_heads, show_results_in_notebook, \n",
579
    "                    loss_idx=0, acc_idx=0)"
580
   ]
581
  },
582
  {
583
   "cell_type": "code",
584
   "execution_count": null,
585
   "metadata": {},
586
   "outputs": [],
587
   "source": [
588
    "loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])\n",
589
    "acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])\n",
590
    "metric_all.append([v[0] for v in metric])\n",
591
    "confusion_mat_all.append([v[1] for v in metric])"
592
   ]
593
  },
594
  {
595
   "cell_type": "markdown",
596
   "metadata": {},
597
   "source": [
598
    "# Factorization AutoEncoder"
599
   ]
600
  },
601
  {
602
   "cell_type": "code",
603
   "execution_count": null,
604
   "metadata": {},
605
   "outputs": [],
606
   "source": [
607
    "def run_one_model(model, loss_weights, other_loss_weights, \n",
608
    "                  loss_his_all=[], acc_his_all=[], metric_all=[], confusion_mat_all=[],\n",
609
    "                  heads=[0,1], multi_heads=True, return_results=False, \n",
610
    "                  loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
611
    "                  lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
612
    "                  num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
613
    "                  print_every=print_every, x_train=x_train, y_train=y_train,\n",
614
    "                  x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
615
    "                  show_results_in_notebook=show_results_in_notebook):\n",
616
    "  \"\"\"Train a model and get results  \n",
617
    "    Most of the parameters are from the context; handle it properly\n",
618
    "  \"\"\"\n",
619
    "  loss_train_his = []\n",
620
    "  loss_val_his = []\n",
621
    "  loss_test_his = []\n",
622
    "  acc_train_his = []\n",
623
    "  acc_val_his = []\n",
624
    "  acc_test_his = []\n",
625
    "  best_model = model\n",
626
    "  best_val_acc = 0\n",
627
    "  best_epoch = 0\n",
628
    "\n",
629
    "  best_model, best_val_acc, best_epoch = train_multiloss(model, x_train, [y_train, x_train], \n",
630
    "    x_val, [y_val, x_val], x_test, [y_test, x_test], heads=heads, loss_fns=loss_fns, \n",
631
    "    loss_weights=loss_weights, other_loss_fns=other_loss_fns, \n",
632
    "    other_loss_weights=other_loss_weights, \n",
633
    "    lr=lr, weight_decay=weight_decay, batch_size=batch_size, num_epochs=num_epochs, \n",
634
    "    reduce_every=reduce_every, eval_every=eval_every, print_every=print_every,\n",
635
    "    loss_train_his=loss_train_his, loss_val_his=loss_val_his, loss_test_his=loss_test_his, \n",
636
    "    acc_train_his=acc_train_his, acc_val_his=acc_val_his, acc_test_his=acc_test_his, \n",
637
    "    return_best_val=True, amsgrad=True, verbose=False)\n",
638
    "\n",
639
    "  metric = get_result(model, best_model, best_val_acc, best_epoch, x_train, y_train, x_val, y_val, \n",
640
    "                      x_test, y_test, batch_size, multi_heads, show_results_in_notebook, \n",
641
    "                      loss_idx=0, acc_idx=0)\n",
642
    "\n",
643
    "  loss_his_all.append([loss_train_his, loss_val_his, loss_test_his])\n",
644
    "  acc_his_all.append([acc_train_his, acc_val_his, acc_test_his])\n",
645
    "  metric_all.append([v[0] for v in metric])\n",
646
    "  confusion_mat_all.append([v[1] for v in metric])\n",
647
    "  \n",
648
    "  if return_results:\n",
649
    "    return loss_his_all, acc_his_all, metric_all, confusion_mat_all"
650
   ]
651
  },
652
  {
653
   "cell_type": "code",
654
   "execution_count": null,
655
   "metadata": {},
656
   "outputs": [],
657
   "source": [
658
    "decoder_norm = False\n",
659
    "uniform_decoder_norm = False\n",
660
    "print('Plain AutoEncoder model')\n",
661
    "model_names.append('AE')\n",
662
    "model = AutoEncoder(in_dim, hidden_dim, num_cls, dense=dense, residual=residual,\n",
663
    "          decoder_norm=decoder_norm, uniform_decoder_norm=uniform_decoder_norm).to(device)\n",
664
    "loss_weights = [1,1]\n",
665
    "other_loss_weights = [0]\n",
666
    "# heads = None should work for all the following; keep this for clarity\n",
667
    "heads = [0,1] \n",
668
    "run_one_model(model, loss_weights, other_loss_weights,\n",
669
    "              loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
670
    "              heads=heads, multi_heads=True, return_results=False, \n",
671
    "              loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
672
    "              lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
673
    "              num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
674
    "              print_every=print_every, x_train=x_train, y_train=y_train,\n",
675
    "              x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
676
    "              show_results_in_notebook=show_results_in_notebook)"
677
   ]
678
  },
679
  {
680
   "cell_type": "markdown",
681
   "metadata": {},
682
   "source": [
683
    "## Add feature interaction network regularizer"
684
   ]
685
  },
686
  {
687
   "cell_type": "code",
688
   "execution_count": null,
689
   "metadata": {},
690
   "outputs": [],
691
   "source": [
692
    "if num_data_types > 1:\n",
693
    "  fuse_type = 'sum'\n",
694
    "  print('MultiviewAE with feature interaction network regularizer')\n",
695
    "  model_names.append('MultiviewAE + feat_int')\n",
696
    "  model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls, \n",
697
    "                      fuse_type=fuse_type, dense=dense, residual=residual, \n",
698
    "                      residual_layers='all', decoder_norm=decoder_norm, \n",
699
    "                      decoder_norm_dim=0, uniform_decoder_norm=uniform_decoder_norm, \n",
700
    "                      nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)\n",
701
    "else:\n",
702
    "  print('AutoEncoder with feature interaction network regularizer')\n",
703
    "  model_names.append('AE + feat_int')\n",
704
    "  model = AutoEncoder(in_dim, hidden_dim, num_cls, dense=dense, residual=residual, \n",
705
    "          decoder_norm=decoder_norm, uniform_decoder_norm=uniform_decoder_norm).to(device)\n",
706
    "\n",
707
    "loss_weights = [1,1]\n",
708
    "other_loss_weights = [1]\n",
709
    "heads = [0,1]\n",
710
    "run_one_model(model, loss_weights, other_loss_weights, \n",
711
    "              loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
712
    "              heads=heads, multi_heads=True, return_results=False, \n",
713
    "              loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
714
    "              lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
715
    "              num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
716
    "              print_every=print_every, x_train=x_train, y_train=y_train,\n",
717
    "              x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
718
    "              show_results_in_notebook=show_results_in_notebook)"
719
   ]
720
  },
721
  {
722
   "cell_type": "markdown",
723
   "metadata": {},
724
   "source": [
725
    "## For multi-view data, add view similarity network regularizer"
726
   ]
727
  },
728
  {
729
   "cell_type": "code",
730
   "execution_count": null,
731
   "metadata": {},
732
   "outputs": [],
733
   "source": [
734
    "if num_data_types > 1:\n",
735
    "  # plain multiviewAE; compare it with plain AutoEncoder to see \n",
736
    "  # if separating views in lower layers in MultiviewAE is better than combining them all the way\n",
737
    "  print('Run plain MultiviewAE model')\n",
738
    "  model_names.append('MultiviewAE')\n",
739
    "  model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls, \n",
740
    "                    fuse_type=fuse_type, dense=dense, residual=residual, \n",
741
    "                    residual_layers='all', decoder_norm=decoder_norm, \n",
742
    "                    decoder_norm_dim=0, uniform_decoder_norm=uniform_decoder_norm, \n",
743
    "                    nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)\n",
744
    "\n",
745
    "  loss_weights = [1,1]\n",
746
    "  other_loss_weights = [0]\n",
747
    "  heads = [0,1]\n",
748
    "  run_one_model(model, loss_weights, other_loss_weights, \n",
749
    "                loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
750
    "                heads=heads, multi_heads=True, return_results=False, \n",
751
    "                loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
752
    "                lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
753
    "                num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
754
    "                print_every=print_every, x_train=x_train, y_train=y_train,\n",
755
    "                x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
756
    "                show_results_in_notebook=show_results_in_notebook)"
757
   ]
758
  },
759
  {
760
   "cell_type": "code",
761
   "execution_count": null,
762
   "metadata": {},
763
   "outputs": [],
764
   "source": [
765
    "if num_data_types > 1:\n",
766
    "  print('MultiviewAE with view similarity regularizers')\n",
767
    "  model_names.append('MultiviewAE + view_sim')\n",
768
    "  model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls, \n",
769
    "                      fuse_type=fuse_type, dense=dense, residual=residual, \n",
770
    "                      residual_layers='all', decoder_norm=decoder_norm, \n",
771
    "                      decoder_norm_dim=0, uniform_decoder_norm=uniform_decoder_norm, \n",
772
    "                      nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)\n",
773
    "  loss_weights = [1,1,1]\n",
774
    "  other_loss_weights = [0]\n",
775
    "  heads = [0,1,2]\n",
776
    "  run_one_model(model, loss_weights, other_loss_weights, \n",
777
    "                loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
778
    "                heads=heads, multi_heads=True, return_results=False, \n",
779
    "                loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
780
    "                lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
781
    "                num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
782
    "                print_every=print_every, x_train=x_train, y_train=y_train,\n",
783
    "                x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
784
    "                show_results_in_notebook=show_results_in_notebook)"
785
   ]
786
  },
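  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Loss_view_similarity` (with `loss_type='hub'`) encourages the per-view sections of the hidden code to agree; its exact formulation is in the imported library. The sketch below is a generic view-agreement penalty of that flavour, pulling each view's section toward the across-view mean embedding; `demo_view_agreement_loss` is an illustration, not the library's 'hub' loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generic view-agreement penalty (illustrative; not the library's exact 'hub' loss):\n",
    "# split the hidden code into equally sized per-view sections and penalize the mean\n",
    "# squared distance of each section from the across-view mean ('hub') embedding.\n",
    "def demo_view_agreement_loss(hidden, section_size):\n",
    "  views = torch.split(hidden, section_size, dim=1)  # tuple of (batch, section_size)\n",
    "  stacked = torch.stack(views, dim=0)               # (num_views, batch, section_size)\n",
    "  hub = stacked.mean(dim=0, keepdim=True)           # mean embedding across views\n",
    "  return ((stacked - hub) ** 2).mean()\n",
    "\n",
    "demo_hidden = torch.randn(8, 4 * hidden_dim[-1])    # e.g. 4 views concatenated\n",
    "print(demo_view_agreement_loss(demo_hidden, hidden_dim[-1]))"
   ]
  },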
787
  {
788
   "cell_type": "code",
789
   "execution_count": null,
790
   "metadata": {},
791
   "outputs": [],
792
   "source": [
793
    "if num_data_types > 1:\n",
794
    "  print('MultiviewAE with both feature interaction and view similarity regularizers')\n",
795
    "  model_names.append('MultiviewAE + feat_int + view_sim')\n",
796
    "  model = MultiviewAE(in_dims=in_dims, hidden_dims=hidden_dim, out_dim=num_cls, \n",
797
    "                      fuse_type=fuse_type, dense=dense, residual=residual, \n",
798
    "                      residual_layers='all', decoder_norm=decoder_norm, \n",
799
    "                      decoder_norm_dim=0, uniform_decoder_norm=uniform_decoder_norm, \n",
800
    "                      nonlinearity=nn.ReLU(), last_nonlinearity=True, bias=True).to(device)\n",
801
    "  loss_weights = [1,1,1]\n",
802
    "  other_loss_weights = [1]\n",
803
    "  heads = [0,1,2]\n",
804
    "  run_one_model(model, loss_weights, other_loss_weights,\n",
805
    "                loss_his_all, acc_his_all, metric_all, confusion_mat_all,\n",
806
    "                heads=heads, multi_heads=True, return_results=False, \n",
807
    "                loss_fns=loss_fns, other_loss_fns=other_loss_fns, \n",
808
    "                lr=lr, weight_decay=weight_decay, batch_size=batch_size, \n",
809
    "                num_epochs=num_epochs, reduce_every=reduce_every, eval_every=eval_every, \n",
810
    "                print_every=print_every, x_train=x_train, y_train=y_train,\n",
811
    "                x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test,\n",
812
    "                show_results_in_notebook=show_results_in_notebook)"
813
   ]
814
  },
815
  {
816
   "cell_type": "code",
817
   "execution_count": null,
818
   "metadata": {},
819
   "outputs": [],
820
   "source": [
821
    "with open(f'{result_folder}/{res_file}', 'wb') as f:\n",
822
    "  print(f'Write result to file {result_folder}/{res_file}')\n",
823
    "  pickle.dump({'loss_his_all': loss_his_all,\n",
824
    "               'acc_his_all': acc_his_all,\n",
825
    "               'metric_all': metric_all,\n",
826
    "               'confusion_mat_all': confusion_mat_all,\n",
827
    "               'model_names': model_names,\n",
828
    "               'split_names': split_names,\n",
829
    "               'metric_names': metric_names\n",
830
    "              }, f)"
831
   ]
832
  }
833
 ],
834
 "metadata": {
835
  "kernelspec": {
836
   "display_name": "Python 3",
837
   "language": "python",
838
   "name": "python3"
839
  },
840
  "language_info": {
841
   "codemirror_mode": {
842
    "name": "ipython",
843
    "version": 3
844
   },
845
   "file_extension": ".py",
846
   "mimetype": "text/x-python",
847
   "name": "python",
848
   "nbconvert_exporter": "python",
849
   "pygments_lexer": "ipython3",
850
   "version": "3.6.5"
851
  }
852
 },
853
 "nbformat": 4,
854
 "nbformat_minor": 2
855
}