Diff of /workflow-v2.ipynb [000000] .. [4807fa]

--- a
+++ b/workflow-v2.ipynb
@@ -0,0 +1,392 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "import subprocess\n",
+    "import collections\n",
+    "import time\n",
+    "import nbformat\n",
+    "import socket\n",
+    "import re\n",
+    "import pickle\n",
+    "\n",
+    "import numpy as np\n",
+    "import sklearn.metrics\n",
+    "\n",
+    "import torch\n",
+    "\n",
+    "lib_path = 'I:/code'\n",
+    "if not os.path.exists(lib_path):\n",
+    "  lib_path = '/media/6T/.tianle/.lib'\n",
+    "if not os.path.exists(lib_path):\n",
+    "  lib_path = '/projects/academic/azhang/tianlema/lib'\n",
+    "if os.path.exists(lib_path) and lib_path not in sys.path:\n",
+    "  sys.path.append(lib_path)\n",
+    "  \n",
+    "from dl.utils.visualization.visualization import *\n",
+    "from dl.utils.train import eval_classification, get_label_prob\n",
+    "from dl.utils.utils import *\n",
+    "\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def submit_job(model_type='nn', dense=False, residual=True, hidden_dim=[500, 500], \n",
+    "               train_portion=0.7, val_portion=0.1, test_portion=0.2, \n",
+    "               num_sets=10, num_folds=10, sel_set_idx=0,\n",
+    "               num_train_types=-1, \n",
+    "               num_val_types=-1,\n",
+    "               num_test_types=-1,\n",
+    "               cv_type='instance-shuffle',\n",
+    "               sel_disease_types='all', \n",
+    "               min_num_samples_per_type_cls=[100, 0],\n",
+    "               predefined_sample_set_file='auto-search',\n",
+    "               target_variable='PFI',\n",
+    "               target_variable_type='discrete',\n",
+    "               target_variable_range=[0, 1],\n",
+    "               data_type=['gene', 'mirna', 'methy', 'rppa'], \n",
+    "               additional_vars=[],#['age_at_initial_pathologic_diagnosis', 'gender']\n",
+    "               additional_var_types=[],#['continuous', 'discrete']\n",
+    "               additional_var_ranges=[],\n",
+    "               normal_transform_feature=True, \n",
+    "               randomize_labels=False,\n",
+    "               lr=5e-4,\n",
+    "               weight_decay=1e-4,\n",
+    "               num_epochs=1000,\n",
+    "               reduce_every=500,\n",
+    "               show_results_in_notebook=True, \n",
+    "               idx_folder='results/data_split_idx', # no longer used\n",
+    "               notebook_folder='.', \n",
+    "               template_file='exp_template.ipynb', \n",
+    "               slurm_script='../gpu-slurm', \n",
+    "               new_file=True, submit=True,\n",
+    "               cell_idx=2, gpu_id=3):\n",
+    "  \"\"\"Create notebook and run it on dlm or submit to ccr slurm\n",
+    "  \"\"\"\n",
+    "  # This is for filename\n",
+    "  if sel_disease_types == 'all':\n",
+    "    sel_disease_type_str = 'all' \n",
+    "  else:\n",
+    "    sel_disease_type_str = '-'.join(sorted(sel_disease_types))\n",
+    "  if isinstance(data_type, str):\n",
+    "    data_type_str = data_type\n",
+    "  else:\n",
+    "    data_type_str = '-'.join(sorted(data_type))\n",
+    "  if model_type == 'nn': # model_type, dense, residual are dependent\n",
+    "    assert not (residual and dense)\n",
+    "    if residual:\n",
+    "      model_type = 'resnet' \n",
+    "    if dense:\n",
+    "      model_type = 'densenet'\n",
+    "  \n",
+    "  args = {'model_type': model_type, # model_type may be different from the argument\n",
+    "          'dense': dense,\n",
+    "          'residual': residual,\n",
+    "          'hidden_dim': hidden_dim,\n",
+    "          'train_portion': train_portion,\n",
+    "          'val_portion': val_portion,\n",
+    "          'test_portion': test_portion,\n",
+    "          'num_sets': num_sets,\n",
+    "          'num_folds': num_folds,\n",
+    "          'num_train_types': num_train_types, \n",
+    "          'num_val_types': num_val_types,\n",
+    "          'num_test_types': num_test_types,\n",
+    "          'cv_type': cv_type,\n",
+    "          'sel_set_idx': sel_set_idx,\n",
+    "          'sel_disease_types': sel_disease_types,\n",
+    "          'min_num_samples_per_type_cls': min_num_samples_per_type_cls,\n",
+    "          'predefined_sample_set_file': predefined_sample_set_file,\n",
+    "          'target_variable': target_variable,\n",
+    "          'target_variable_type': target_variable_type,\n",
+    "          'target_variable_range': target_variable_range,\n",
+    "          'data_type': data_type,\n",
+    "          'additional_vars': additional_vars,#['age_at_initial_pathologic_diagnosis', 'gender']\n",
+    "          'additional_var_types': additional_var_types,#['continuous', 'discrete']\n",
+    "          'additional_var_ranges': additional_var_ranges,\n",
+    "          'normal_transform_feature': normal_transform_feature,\n",
+    "          'randomize_labels': randomize_labels,\n",
+    "          'lr': lr,\n",
+    "          'weight_decay': weight_decay,\n",
+    "          'num_epochs': num_epochs,\n",
+    "          'reduce_every': reduce_every,\n",
+    "          'show_results_in_notebook': show_results_in_notebook\n",
+    "         }\n",
+    "  \n",
+    "  predefined_sample_set_filename = (target_variable if isinstance(target_variable,str) \n",
+    "                                else '-'.join(target_variable))\n",
+    "  predefined_sample_set_filename += f'_{cv_type}'\n",
+    "  if len(additional_vars) > 0:\n",
+    "    predefined_sample_set_filename += f\"_{'-'.join(sorted(additional_vars))}\"\n",
+    "  predefined_sample_set_filename += (f\"_{data_type_str}_{sel_disease_type_str}_\"\n",
+    "                                     f\"{'-'.join(map(str, min_num_samples_per_type_cls))}\")\n",
+    "  predefined_sample_set_filename += f\"_{'-'.join(map(str, [train_portion, val_portion, test_portion]))}\"\n",
+    "  if cv_type == 'group-shuffle' and num_train_types > 0:\n",
+    "    predefined_sample_set_filename += f\"_{'-'.join(map(str, [num_train_types, num_val_types, num_test_types]))}\"\n",
+    "  predefined_sample_set_filename += f'_{num_sets}sets'\n",
+    "  filename_prefix = f\"{predefined_sample_set_filename}_{sel_set_idx}_{'-'.join(map(str, hidden_dim))}_{model_type}\"\n",
+    "  filename = f'{filename_prefix}.ipynb'\n",
+    "  nb = nbformat.read(f'{notebook_folder}/{template_file}', 4)\n",
+    "  nb['cells'][0]['source'] = (\"import socket\\nif socket.gethostname() == 'dlm':\\n\"\n",
+    "                              \"  %env CUDA_DEVICE_ORDER=PCI_BUS_ID\\n\"\n",
+    "                              f\"  %env CUDA_VISIBLE_DEVICES={gpu_id}\")\n",
+    "  nb['cells'][cell_idx]['source'] = '\\n'.join(\n",
+    "    [f\"{k} = '{v}'\" if isinstance(v, str) else f'{k} = {v}' for k, v in args.items()])\n",
+    "  if os.path.exists(f'{notebook_folder}/{filename}'):\n",
+    "    print(f'To overwrite file {notebook_folder}/{filename}')\n",
+    "  else:\n",
+    "    print(f'To create file {notebook_folder}/{filename}')\n",
+    "  if new_file:\n",
+    "    nbformat.write(nb, f'{notebook_folder}/{filename}')\n",
+    "  \n",
+    "  if submit: # sometimes I just want to create files\n",
+    "    if re.search('ccr.buffalo.edu$', socket.gethostname()):\n",
+    "      command = f'sbatch {slurm_script} {notebook_folder}/{filename} {filename}'\n",
+    "      subprocess.run(command, shell=True)\n",
+    "      print(command)\n",
+    "    else:\n",
+    "      command = ['jupyter nbconvert', '--ExecutePreprocessor.timeout=360000',\n",
+    "               '--ExecutePreprocessor.allow_errors=True', '--to notebook', '--execute']\n",
+    "      command.append(f'{notebook_folder}/{filename} --output {filename}')\n",
+    "      command = ' '.join(command)\n",
+    "      start_time = time.time()\n",
+    "      tmp = subprocess.run(command, shell=True)\n",
+    "      end_time = time.time()\n",
+    "      print(f'Time spent: {end_time-start_time:.2f}')\n",
+    "  return filename_prefix"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BLCA 0 [191, 147]\n",
+      "BRCA 1 [740, 107]\n",
+      "KIRC 6 [295, 151]\n",
+      "LGG 8 [273, 150]\n",
+      "LUAD 10 [144, 211]\n",
+      "LUSC 11 [102, 202]\n",
+      "SARC 16 [115, 102]\n",
+      "STAD 17 [106, 224]\n"
+     ]
+    }
+   ],
+   "source": [
+    "data_folder = '../../pan-can-atlas/data/processed'\n",
+    "if not os.path.exists(data_folder):\n",
+    "  data_folder = 'F:/TCGA/Pan-Cancer-Atlas/data/processed'\n",
+    "with open(f'{data_folder}/sel_patient_clinical.pkl', 'rb') as f:\n",
+    "  data = pickle.load(f)\n",
+    "  disease_types = data['disease_types']\n",
+    "  disease_type_dict = data['disease_type_dict']\n",
+    "  pfi = data['pfi']\n",
+    "disease_stats = {}\n",
+    "for idx, name in disease_type_dict.items():\n",
+    "  cnt = list(collections.Counter(pfi[disease_types==idx]).values())\n",
+    "  if cnt[0] > 100 and cnt[1] > 100:\n",
+    "    disease_stats[idx] = f'{name}: {cnt}'\n",
+    "    print(name, idx, cnt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# additional_vars=[],#['age_at_initial_pathologic_diagnosis', 'gender']\n",
+    "# additional_var_types=[],#['continuous', 'discrete']\n",
+    "# additional_var_ranges=[],\n",
+    "\n",
+    "additional_vars = ['age_at_initial_pathologic_diagnosis', 'gender', 'ajcc_pathologic_tumor_stage']\n",
+    "additional_var_types = ['continuous', 'discrete', 'discrete']\n",
+    "additional_var_ranges = [[0, 100], ['MALE', 'FEMALE'], \n",
+    "                         ['I/II NOS', 'IS', 'Stage 0', 'Stage I', 'Stage IA', 'Stage IB', \n",
+    "                          'Stage II', 'Stage IIA', 'Stage IIB', 'Stage IIC', 'Stage III',\n",
+    "                          'Stage IIIA', 'Stage IIIB', 'Stage IIIC', 'Stage IV', 'Stage IVA',\n",
+    "                          'Stage IVB', 'Stage IVC', 'Stage X']]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_0_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_1_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_2_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_3_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_4_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_5_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_6_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_7_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_8_100-100_nn.ipynb\n",
+      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_9_100-100_nn.ipynb\n"
+     ]
+    }
+   ],
+   "source": [
+    "for i in ['all']:#[0, 1, 6, 8, 10, 11, 16, 17]:\n",
+    "  for j in range(10):\n",
+    "      for dtype in [['gene', 'mirna', 'rppa', 'methy']]:\n",
+    "        submit_job(model_type='nn', dense=False, residual=False, hidden_dim=[100,100], \n",
+    "               train_portion=0.7, val_portion=0.1, test_portion=0.2, \n",
+    "               num_sets=10, num_folds=10, sel_set_idx=j,\n",
+    "               num_train_types=-1, \n",
+    "               num_val_types=-1,\n",
+    "               num_test_types=-1,\n",
+    "               cv_type='instance-shuffle',\n",
+    "               sel_disease_types=i, \n",
+    "               min_num_samples_per_type_cls=[100, 0],\n",
+    "               predefined_sample_set_file='auto-search',\n",
+    "               target_variable='DFI',\n",
+    "               target_variable_type='discrete',\n",
+    "               target_variable_range=[0,1],\n",
+    "               data_type=dtype, \n",
+    "               additional_vars=additional_vars,\n",
+    "               additional_var_types=additional_var_types,\n",
+    "               additional_var_ranges=additional_var_ranges,\n",
+    "               normal_transform_feature=True, \n",
+    "               randomize_labels=False,\n",
+    "               lr=5e-4,\n",
+    "               weight_decay=1e-4,\n",
+    "               num_epochs=100,\n",
+    "               reduce_every=500,\n",
+    "               show_results_in_notebook=True, \n",
+    "               idx_folder='results/data_split_idx', # no longer used\n",
+    "               notebook_folder='.', \n",
+    "               template_file='exp_template-mv-nn-v3.ipynb', \n",
+    "               slurm_script='../run-slurm', \n",
+    "               new_file=False, submit=False,\n",
+    "               cell_idx=2, gpu_id=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 199,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def load_results(disease_type_str = '0', #0-1-6-8-10-11-16-17\n",
+    "                  model_name = 'ml',\n",
+    "                  sel_set_idx = 0,\n",
+    "                  data_type_str = 'gene-mirna-rppa-methy',\n",
+    "                  data_split_str = '70-10-20',\n",
+    "                  hidden_dim_str = '100-100',\n",
+    "                  filefolder = 'results',\n",
+    "                  target_variable = 'pfi',\n",
+    "                  return_variable='metric_all',\n",
+    "                  filename=None, plot_acc=True, plot_loss=True):\n",
+    "  if filename is None:\n",
+    "    filename = (f'{filefolder}/{disease_type_str}_{data_type_str}_set{sel_set_idx}' \n",
+    "                f'_{data_split_str}_{target_variable}_{hidden_dim_str}_{model_name}.pkl')\n",
+    "    \n",
+    "  with open(filename, 'rb') as f:\n",
+    "    data = pickle.load(f)\n",
+    "  if return_variable in data:\n",
+    "    return np.array(data[return_variable])\n",
+    "  metric = np.array(data['metric_all'])\n",
+    "  confusion_mat = np.array(data['confusion_mat_all'])\n",
+    "  model_names, split_names, metric_names = (data['model_names'], data['split_names'], \n",
+    "                                            data['metric_names'])\n",
+    "  # sanity check\n",
+    "  assert metric.shape == (len(model_names), len(split_names), len(metric_names))\n",
+    "  assert confusion_mat.shape[:2] == (len(model_names), len(split_names))\n",
+    "  loss_his = data['loss_his_all']\n",
+    "  acc_his = np.array(data['acc_his_all'])\n",
+    "  title =  disease_type_str if len(disease_type_str)>2 else disease_stats[int(disease_type_str)]\n",
+    "  if plot_acc and len(acc_his)>0:\n",
+    "    for i, n in enumerate(split_names):\n",
+    "      plot_history(acc_his[:, i].T, title=f'{title} {n} acc', \n",
+    "                   indices=None, colors='rgbkmc', markers='ov+*,<',\n",
+    "                       labels=model_names, linestyles=['']*6, markersize=3)\n",
+    "    for i, n in enumerate(model_names):\n",
+    "      plot_history(acc_his[i].T, title=f'{title} {n} acc', \n",
+    "                   indices=None, colors='rgbkmc', markers='ov+*,<',\n",
+    "                       labels=split_names, linestyles=['']*6, markersize=3)\n",
+    "  if plot_loss and len(loss_his)>0:\n",
+    "    for i, n in enumerate(model_names):\n",
+    "      history = np.array(loss_his[i])\n",
+    "      if history.ndim == 2:\n",
+    "        plot_history(history.T, title=f'{title} {n} loss', indices=None, colors='rgbkmc', \n",
+    "                     markers='ov+*,<',\n",
+    "                       labels=split_names, linestyles=['']*6, markersize=3)\n",
+    "      elif history.ndim == 3:\n",
+    "        for j in range(history.shape[2]):\n",
+    "           plot_history(history[:,:,j].T, title=f'{title} {n} loss{j}', indices=None, \n",
+    "                        colors='rgbkmc', markers='ov+*,<',\n",
+    "                       labels=split_names, linestyles=['']*6, markersize=3)\n",
+    "      else:\n",
+    "        raise ValueError(f'{filename} {n} loss has unexpected shape')\n",
+    "  if return_variable == 'all':\n",
+    "    return metric, confusion_mat, model_names, split_names, metric_names, acc_his, loss_his"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 206,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array(['acc', 'precision', 'recall', 'f1_score', 'adjusted_mutual_info',\n",
+       "       'auc', 'average_precision'], dtype='<U20')"
+      ]
+     },
+     "execution_count": 206,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "load_results(return_variable='metric_names', \n",
+    "  filename=(f'results/PFI_instance-shuffle_gene-methy-mirna-rppa_all'\n",
+    "            f'_100-0_0.7-0.1-0.2_10sets_0_100-100_nn.pkl'))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}