[4807fa]: / workflow-v2.ipynb

Download this file

393 lines (392 with data), 18.5 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import subprocess\n",
    "import collections\n",
    "import time\n",
    "import nbformat\n",
    "import socket\n",
    "import re\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import sklearn.metrics\n",
    "\n",
    "import torch\n",
    "\n",
    "lib_path = 'I:/code'\n",
    "if not os.path.exists(lib_path):\n",
    "  lib_path = '/media/6T/.tianle/.lib'\n",
    "if not os.path.exists(lib_path):\n",
    "  lib_path = '/projects/academic/azhang/tianlema/lib'\n",
    "if os.path.exists(lib_path) and lib_path not in sys.path:\n",
    "  sys.path.append(lib_path)\n",
    "  \n",
    "from dl.utils.visualization.visualization import *\n",
    "from dl.utils.train import eval_classification, get_label_prob\n",
    "from dl.utils.utils import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def submit_job(model_type='nn', dense=False, residual=True, hidden_dim=[500, 500], \n",
    "               train_portion=0.7, val_portion=0.1, test_portion=0.2, \n",
    "               num_sets=10, num_folds=10, sel_set_idx=0,\n",
    "               num_train_types=-1, \n",
    "               num_val_types=-1,\n",
    "               num_test_types=-1,\n",
    "               cv_type='instance-shuffle',\n",
    "               sel_disease_types='all', \n",
    "               min_num_samples_per_type_cls=[100, 0],\n",
    "               predefined_sample_set_file='auto-search',\n",
    "               target_variable='PFI',\n",
    "               target_variable_type='discrete',\n",
    "               target_variable_range=[0, 1],\n",
    "               data_type=['gene', 'mirna', 'methy', 'rppa'], \n",
    "               additional_vars=[],#['age_at_initial_pathologic_diagnosis', 'gender']\n",
    "               additional_var_types=[],#['continuous', 'discrete']\n",
    "               additional_var_ranges=[],\n",
    "               normal_transform_feature=True, \n",
    "               randomize_labels=False,\n",
    "               lr=5e-4,\n",
    "               weight_decay=1e-4,\n",
    "               num_epochs=1000,\n",
    "               reduce_every=500,\n",
    "               show_results_in_notebook=True, \n",
    "               idx_folder='results/data_split_idx', # no longer used\n",
    "               notebook_folder='.', \n",
    "               template_file='exp_template.ipynb', \n",
    "               slurm_script='../gpu-slurm', \n",
    "               new_file=True, submit=True,\n",
    "               cell_idx=2, gpu_id=3):\n",
    "  \"\"\"Create notebook and run it on dlm or submit to ccr slurm\n",
    "  \"\"\"\n",
    "  # This is for filename\n",
    "  if sel_disease_types == 'all':\n",
    "    sel_disease_type_str = 'all' \n",
    "  else:\n",
    "    sel_disease_type_str = '-'.join(sorted(sel_disease_types))\n",
    "  if isinstance(data_type, str):\n",
    "    data_type_str = data_type\n",
    "  else:\n",
    "    data_type_str = '-'.join(sorted(data_type))\n",
    "  if model_type == 'nn': # model_type, dense, residual are dependent\n",
    "    assert not (residual and dense)\n",
    "    if residual:\n",
    "      model_type = 'resnet' \n",
    "    if dense:\n",
    "      model_type = 'densenet'\n",
    "  \n",
    "  args = {'model_type': model_type, # model_type may be different from the argument\n",
    "          'dense': dense,\n",
    "          'residual': residual,\n",
    "          'hidden_dim': hidden_dim,\n",
    "          'train_portion': train_portion,\n",
    "          'val_portion': val_portion,\n",
    "          'test_portion': test_portion,\n",
    "          'num_sets': num_sets,\n",
    "          'num_folds': num_folds,\n",
    "          'num_train_types': num_train_types, \n",
    "          'num_val_types': num_val_types,\n",
    "          'num_test_types': num_test_types,\n",
    "          'cv_type': cv_type,\n",
    "          'sel_set_idx': sel_set_idx,\n",
    "          'sel_disease_types': sel_disease_types,\n",
    "          'min_num_samples_per_type_cls': min_num_samples_per_type_cls,\n",
    "          'predefined_sample_set_file': predefined_sample_set_file,\n",
    "          'target_variable': target_variable,\n",
    "          'target_variable_type': target_variable_type,\n",
    "          'target_variable_range': target_variable_range,\n",
    "          'data_type': data_type,\n",
    "          'additional_vars': additional_vars,#['age_at_initial_pathologic_diagnosis', 'gender']\n",
    "          'additional_var_types': additional_var_types,#['continuous', 'discrete']\n",
    "          'additional_var_ranges': additional_var_ranges,\n",
    "          'normal_transform_feature': normal_transform_feature,\n",
    "          'randomize_labels': randomize_labels,\n",
    "          'lr': lr,\n",
    "          'weight_decay': weight_decay,\n",
    "          'num_epochs': num_epochs,\n",
    "          'reduce_every': reduce_every,\n",
    "          'show_results_in_notebook': show_results_in_notebook\n",
    "         }\n",
    "  \n",
    "  predefined_sample_set_filename = (target_variable if isinstance(target_variable,str) \n",
    "                                else '-'.join(target_variable))\n",
    "  predefined_sample_set_filename += f'_{cv_type}'\n",
    "  if len(additional_vars) > 0:\n",
    "    predefined_sample_set_filename += f\"_{'-'.join(sorted(additional_vars))}\"\n",
    "  predefined_sample_set_filename += (f\"_{data_type_str}_{sel_disease_type_str}_\"\n",
    "                                     f\"{'-'.join(map(str, min_num_samples_per_type_cls))}\")\n",
    "  predefined_sample_set_filename += f\"_{'-'.join(map(str, [train_portion, val_portion, test_portion]))}\"\n",
    "  if cv_type == 'group-shuffle' and num_train_types > 0:\n",
    "    predefined_sample_set_filename += f\"_{'-'.join(map(str, [num_train_types, num_val_types, num_test_types]))}\"\n",
    "  predefined_sample_set_filename += f'_{num_sets}sets'\n",
    "  filename_prefix = f\"{predefined_sample_set_filename}_{sel_set_idx}_{'-'.join(map(str, hidden_dim))}_{model_type}\"\n",
    "  filename = f'{filename_prefix}.ipynb'\n",
    "  nb = nbformat.read(f'{notebook_folder}/{template_file}', 4)\n",
    "  nb['cells'][0]['source'] = (\"import socket\\nif socket.gethostname() == 'dlm':\\n\"\n",
    "                              \"  %env CUDA_DEVICE_ORDER=PCI_BUS_ID\\n\"\n",
    "                              f\"  %env CUDA_VISIBLE_DEVICES={gpu_id}\")\n",
    "  nb['cells'][cell_idx]['source'] = '\\n'.join(\n",
    "    [f\"{k} = '{v}'\" if isinstance(v, str) else f'{k} = {v}' for k, v in args.items()])\n",
    "  if os.path.exists(f'{notebook_folder}/{filename}'):\n",
    "    print(f'To overwrite file {notebook_folder}/{filename}')\n",
    "  else:\n",
    "    print(f'To create file {notebook_folder}/{filename}')\n",
    "  if new_file:\n",
    "    nbformat.write(nb, f'{notebook_folder}/{filename}')\n",
    "  \n",
    "  if submit: # sometimes I just want to create files\n",
    "    if re.search('ccr.buffalo.edu$', socket.gethostname()):\n",
    "      command = f'sbatch {slurm_script} {notebook_folder}/{filename} {filename}'\n",
    "      subprocess.run(command, shell=True)\n",
    "      print(command)\n",
    "    else:\n",
    "      command = ['jupyter nbconvert', '--ExecutePreprocessor.timeout=360000',\n",
    "               '--ExecutePreprocessor.allow_errors=True', '--to notebook', '--execute']\n",
    "      command.append(f'{notebook_folder}/{filename} --output {filename}')\n",
    "      command = ' '.join(command)\n",
    "      start_time = time.time()\n",
    "      tmp = subprocess.run(command, shell=True)\n",
    "      end_time = time.time()\n",
    "      print(f'Time spent: {end_time-start_time:.2f}')\n",
    "  return filename_prefix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLCA 0 [191, 147]\n",
      "BRCA 1 [740, 107]\n",
      "KIRC 6 [295, 151]\n",
      "LGG 8 [273, 150]\n",
      "LUAD 10 [144, 211]\n",
      "LUSC 11 [102, 202]\n",
      "SARC 16 [115, 102]\n",
      "STAD 17 [106, 224]\n"
     ]
    }
   ],
   "source": [
    "data_folder = '../../pan-can-atlas/data/processed'\n",
    "if not os.path.exists(data_folder):\n",
    "  data_folder = 'F:/TCGA/Pan-Cancer-Atlas/data/processed'\n",
    "with open(f'{data_folder}/sel_patient_clinical.pkl', 'rb') as f:\n",
    "  data = pickle.load(f)\n",
    "  disease_types = data['disease_types']\n",
    "  disease_type_dict = data['disease_type_dict']\n",
    "  pfi = data['pfi']\n",
    "disease_stats = {}\n",
    "for idx, name in disease_type_dict.items():\n",
    "  cnt = list(collections.Counter(pfi[disease_types==idx]).values())\n",
    "  if cnt[0] > 100 and cnt[1] > 100:\n",
    "    disease_stats[idx] = f'{name}: {cnt}'\n",
    "    print(name, idx, cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# additional_vars=[],#['age_at_initial_pathologic_diagnosis', 'gender']\n",
    "# additional_var_types=[],#['continuous', 'discrete']\n",
    "# additional_var_ranges=[],\n",
    "\n",
    "additional_vars = ['age_at_initial_pathologic_diagnosis', 'gender', 'ajcc_pathologic_tumor_stage']\n",
    "additional_var_types = ['continuous', 'discrete', 'discrete']\n",
    "additional_var_ranges = [[0, 100], ['MALE', 'FEMALE'], \n",
    "                         ['I/II NOS', 'IS', 'Stage 0', 'Stage I', 'Stage IA', 'Stage IB', \n",
    "                          'Stage II', 'Stage IIA', 'Stage IIB', 'Stage IIC', 'Stage III',\n",
    "                          'Stage IIIA', 'Stage IIIB', 'Stage IIIC', 'Stage IV', 'Stage IVA',\n",
    "                          'Stage IVB', 'Stage IVC', 'Stage X']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_0_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_1_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_2_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_3_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_4_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_5_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_6_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_7_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_8_100-100_nn.ipynb\n",
      "To create file ./DFI_instance-shuffle_age_at_initial_pathologic_diagnosis-ajcc_pathologic_tumor_stage-gender_gene-methy-mirna-rppa_all_100-0_0.7-0.1-0.2_10sets_9_100-100_nn.ipynb\n"
     ]
    }
   ],
   "source": [
    "for i in ['all']:#[0, 1, 6, 8, 10, 11, 16, 17]:\n",
    "  for j in range(10):\n",
    "      for dtype in [['gene', 'mirna', 'rppa', 'methy']]:\n",
    "        submit_job(model_type='nn', dense=False, residual=False, hidden_dim=[100,100], \n",
    "               train_portion=0.7, val_portion=0.1, test_portion=0.2, \n",
    "               num_sets=10, num_folds=10, sel_set_idx=j,\n",
    "               num_train_types=-1, \n",
    "               num_val_types=-1,\n",
    "               num_test_types=-1,\n",
    "               cv_type='instance-shuffle',\n",
    "               sel_disease_types=i, \n",
    "               min_num_samples_per_type_cls=[100, 0],\n",
    "               predefined_sample_set_file='auto-search',\n",
    "               target_variable='DFI',\n",
    "               target_variable_type='discrete',\n",
    "               target_variable_range=[0,1],\n",
    "               data_type=dtype, \n",
    "               additional_vars=additional_vars,\n",
    "               additional_var_types=additional_var_types,\n",
    "               additional_var_ranges=additional_var_ranges,\n",
    "               normal_transform_feature=True, \n",
    "               randomize_labels=False,\n",
    "               lr=5e-4,\n",
    "               weight_decay=1e-4,\n",
    "               num_epochs=100,\n",
    "               reduce_every=500,\n",
    "               show_results_in_notebook=True, \n",
    "               idx_folder='results/data_split_idx', # no longer used\n",
    "               notebook_folder='.', \n",
    "               template_file='exp_template-mv-nn-v3.ipynb', \n",
    "               slurm_script='../run-slurm', \n",
    "               new_file=False, submit=False,\n",
    "               cell_idx=2, gpu_id=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 199,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_results(disease_type_str = '0', #0-1-6-8-10-11-16-17\n",
    "                  model_name = 'ml',\n",
    "                  sel_set_idx = 0,\n",
    "                  data_type_str = 'gene-mirna-rppa-methy',\n",
    "                  data_split_str = '70-10-20',\n",
    "                  hidden_dim_str = '100-100',\n",
    "                  filefolder = 'results',\n",
    "                  target_variable = 'pfi',\n",
    "                  return_variable='metric_all',\n",
    "                  filename=None, plot_acc=True, plot_loss=True):\n",
    "  if filename is None:\n",
    "    filename = (f'{filefolder}/{disease_type_str}_{data_type_str}_set{sel_set_idx}' \n",
    "                f'_{data_split_str}_{target_variable}_{hidden_dim_str}_{model_name}.pkl')\n",
    "    \n",
    "  with open(filename, 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "  if return_variable in data:\n",
    "    return np.array(data[return_variable])\n",
    "  metric = np.array(data['metric_all'])\n",
    "  confusion_mat = np.array(data['confusion_mat_all'])\n",
    "  model_names, split_names, metric_names = (data['model_names'], data['split_names'], \n",
    "                                            data['metric_names'])\n",
    "  # sanity check\n",
    "  assert metric.shape == (len(model_names), len(split_names), len(metric_names))\n",
    "  assert confusion_mat.shape[:2] == (len(model_names), len(split_names))\n",
    "  loss_his = data['loss_his_all']\n",
    "  acc_his = np.array(data['acc_his_all'])\n",
    "  title =  disease_type_str if len(disease_type_str)>2 else disease_stats[int(disease_type_str)]\n",
    "  if plot_acc and len(acc_his)>0:\n",
    "    for i, n in enumerate(split_names):\n",
    "      plot_history(acc_his[:, i].T, title=f'{title} {n} acc', \n",
    "                   indices=None, colors='rgbkmc', markers='ov+*,<',\n",
    "                       labels=model_names, linestyles=['']*6, markersize=3)\n",
    "    for i, n in enumerate(model_names):\n",
    "      plot_history(acc_his[i].T, title=f'{title} {n} acc', \n",
    "                   indices=None, colors='rgbkmc', markers='ov+*,<',\n",
    "                       labels=split_names, linestyles=['']*6, markersize=3)\n",
    "  if plot_loss and len(loss_his)>0:\n",
    "    for i, n in enumerate(model_names):\n",
    "      history = np.array(loss_his[i])\n",
    "      if history.ndim == 2:\n",
    "        plot_history(history.T, title=f'{title} {n} loss', indices=None, colors='rgbkmc', \n",
    "                     markers='ov+*,<',\n",
    "                       labels=split_names, linestyles=['']*6, markersize=3)\n",
    "      elif history.ndim == 3:\n",
    "        for j in range(history.shape[2]):\n",
    "           plot_history(history[:,:,j].T, title=f'{title} {n} loss{j}', indices=None, \n",
    "                        colors='rgbkmc', markers='ov+*,<',\n",
    "                       labels=split_names, linestyles=['']*6, markersize=3)\n",
    "      else:\n",
    "        raise ValueError(f'{filename} {n} loss has unexpected shape')\n",
    "  if return_variable == 'all':\n",
    "    return metric, confusion_mat, model_names, split_names, metric_names, acc_his, loss_his"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['acc', 'precision', 'recall', 'f1_score', 'adjusted_mutual_info',\n",
       "       'auc', 'average_precision'], dtype='<U20')"
      ]
     },
     "execution_count": 206,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "load_results(return_variable='metric_names', \n",
    "  filename=(f'results/PFI_instance-shuffle_gene-methy-mirna-rppa_all'\n",
    "            f'_100-0_0.7-0.1-0.2_10sets_0_100-100_nn.pkl'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}