--- a +++ b/datasets_csv/.ipynb_checkpoints/Preprocessing-checkpoint.ipynb @@ -0,0 +1,4203 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from os.path import join\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "label_col = 'survival_months'\n", + "n_bins = 4\n", + "eps = 1e-6" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def add_bins(slide_data):\n", + " assert 'case_id' in slide_data.columns and 'censorship' in slide_data.columns\n", + " \n", + " patients_df = slide_data.drop_duplicates(['case_id']).copy()\n", + " uncensored_df = patients_df[patients_df['censorship'] < 1]\n", + " disc_labels, q_bins = pd.qcut(uncensored_df[label_col], q=n_bins, retbins=True, labels=False)\n", + " q_bins[-1] = slide_data[label_col].max() + eps\n", + " q_bins[0] = slide_data[label_col].min() - eps\n", + "\n", + " disc_labels, q_bins = pd.cut(patients_df[label_col], bins=q_bins, retbins=True, labels=False, right=False, include_lowest=True)\n", + " patients_df.insert(2, 'label', disc_labels.values.astype(int))\n", + "\n", + " patient_dict = {}\n", + " slide_data = slide_data.set_index('case_id')\n", + " for patient in patients_df['case_id']:\n", + " slide_ids = slide_data.loc[patient, 'slide_id']\n", + " if isinstance(slide_ids, str):\n", + " slide_ids = np.array(slide_ids).reshape(-1)\n", + " else:\n", + " slide_ids = slide_ids.values\n", + " patient_dict.update({patient:slide_ids})\n", + " \n", + " return q_bins, patient_dict, patients_df" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "slide_data = pd.read_csv('./tcga_gbmlgg_all_clean.csv.zip', compression='zip', header=0, index_col=0, sep=',', low_memory=False)\n", + "\n", + "n_bins = 4\n", + "eps = 1e-6\n", + "\n", + "### Asserts that 'case_id' is a column, not an index.\n", + "if 'case_id' not in slide_data:\n", + " slide_data.index = slide_data.index.str[:12]\n", + " slide_data['case_id'] = slide_data.index\n", + " slide_data = slide_data.reset_index(drop=True)\n", + "\n", + "q_bins, patients_dict, slide_data = add_bins(slide_data)\n", + "\n", + "slide_data.reset_index(drop=True, inplace=True)\n", + "slide_data = slide_data.assign(slide_id=slide_data['case_id'])\n", + "\n", + "label_dict = {}\n", + "key_count = 0\n", + "for i in range(len(q_bins)-1):\n", + " for c in [0, 1]:\n", + " label_dict.update({(i, c):key_count})\n", + " key_count+=1\n", + "\n", + "for i in slide_data.index:\n", + " key = slide_data.loc[i, 'label']\n", + " slide_data.at[i, 'disc_label'] = key\n", + " censorship = slide_data.loc[i, 'censorship']\n", + " key = (key, int(censorship))\n", + " slide_data.at[i, 'label'] = label_dict[key]\n", + "\n", + "bins = q_bins\n", + "num_classes=len(label_dict)\n", + "patients_df = slide_data.drop_duplicates(['case_id'])\n", + "patient_data = {'case_id':patients_df['case_id'].values, 'label':patients_df['label'].values}\n", + "\n", + "new_cols = list(slide_data.columns[-2:]) + list(slide_data.columns[:-2])\n", + "slide_data = slide_data[new_cols]\n", + "metadata = slide_data.columns[:11]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.pipeline import Pipeline\n", + "from sklearn.decomposition import PCA\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "\n", + "def series_intersection(s1, s2):\n", + " return pd.Series(list(set(s1) & set(s2)))\n", + "\n", + "genomic_features = slide_data.drop(metadata, axis=1)\n", + "scaler_omic = StandardScaler().fit(genomic_features)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mahmoodlab/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3071: DtypeWarning: Columns (2) have mixed types.Specify dtype option on import or set low_memory=False.\n", + " has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n" + ] + } + ], + "source": [ + "signatures = pd.read_csv('./signatures.csv')\n", + "slide_df = pd.read_csv('./tcga_gbmlgg_all_clean.csv.zip')" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "omic_from_signatures = []\n", + "for col in signatures.columns:\n", + " omic = signatures[col].dropna().unique()\n", + " omic_from_signatures.append(omic)\n", + "\n", + "omic_from_signatures = np.concatenate(omic_from_signatures)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mahmoodlab/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3071: DtypeWarning: Columns (2) have mixed types.Specify dtype option on import or set low_memory=False.\n", + " has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n", + "/home/mahmoodlab/anaconda3/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3071: DtypeWarning: Columns (4) have mixed types.Specify dtype option on import or set low_memory=False.\n", + " has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n" + ] + } + ], + "source": [ + "for fname in os.listdir('./'):\n", + " if fname.endswith('.csv.zip'):\n", + " slide_df = pd.read_csv(fname)\n", + " omic_overlap = np.concatenate([omic_from_signatures+mode for mode in ['_mut', '_cnv', '_rnaseq']])\n", + " omic_overlap = sorted(series_intersection(omic_overlap, slide_df.columns))\n", + " slide_df[list(slide_df.columns[:9]) + omic_overlap].to_csv('../dataset_csv/%s' % fname)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "omic_from_signatures = []\n", + "for col in signatures.columns:\n", + " omic = signatures[col].dropna().unique()\n", + " omic_from_signatures.append(omic)\n", + "\n", + "omic_from_signatures = np.concatenate(omic_from_signatures)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tumor Suppressor Genes Embedding Size: 84\n", + "Oncogenes Embedding Size: 314\n", + "Protein Kinases Embedding Size: 498\n", + "Cell Differentiation Markers Embedding Size: 415\n", + "Transcription Factors Embedding Size: 1396\n", + "Cytokines and Growth Factors Embedding Size: 428\n" + ] + } + ], + "source": [ + "\n", + "def series_intersection(s1, s2):\n", + " return pd.Series(list(set(s1) & set(s2)))\n", + "\n", + "sig_names = []\n", + "for col in signatures.columns:\n", + " sig = signatures[col].dropna().unique()\n", + " sig = np.concatenate([sig+mode for mode in ['_mut', '_cnv', '_rnaseq']])\n", + " sig = sorted(series_intersection(sig, genomic_features.columns))\n", + " sig_names.append(sig)\n", + " print('%s Embedding Size: %d' % (col, len(sig)))\n", + "sig_sizes = [len(sig) for sig in sig_names]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['IFNA10_cnv',\n", + " 'IFNA13_cnv',\n", + " 'IFNA14_cnv',\n", + " 'IFNA16_cnv',\n", + " 'IFNA17_cnv',\n", + " 'IFNA1_cnv',\n", + " 'IFNA21_cnv',\n", + " 'IFNA2_cnv',\n", + " 'IFNA4_cnv',\n", + " 'IFNA5_cnv',\n", + " 'IFNA6_cnv',\n", + " 'IFNA7_cnv',\n", + " 'IFNA8_cnv',\n", + " 'IFNB1_cnv',\n", + " 'IFNE_cnv',\n", + " 'IFNW1_cnv',\n", + " 'PDGFRA_cnv']" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sig" + ] + }, + { + "cell_type": "code", + "execution_count": 434, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>NDUFS5_cnv</th>\n", + " <th>MACF1_cnv</th>\n", + " <th>RNA5SP44_cnv</th>\n", + " <th>KIAA0754_cnv</th>\n", + " <th>BMP8A_cnv</th>\n", + " <th>PABPC4_cnv</th>\n", + " <th>SNORA55_cnv</th>\n", + " <th>HEYL_cnv</th>\n", + " <th>HPCAL4_cnv</th>\n", + " <th>NT5C1A_cnv</th>\n", + " <th>...</th>\n", + " <th>ZWINT_rnaseq</th>\n", + " <th>ZXDA_rnaseq</th>\n", + " <th>ZXDB_rnaseq</th>\n", + " <th>ZXDC_rnaseq</th>\n", + " <th>ZYG11A_rnaseq</th>\n", + " <th>ZYG11B_rnaseq</th>\n", + " <th>ZYX_rnaseq</th>\n", + " <th>ZZEF1_rnaseq</th>\n", + " <th>ZZZ3_rnaseq</th>\n", + " <th>TPTEP1_rnaseq</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>...</td>\n", + " <td>-0.8388</td>\n", + " <td>4.1375</td>\n", + " <td>3.9664</td>\n", + " <td>1.8437</td>\n", + " <td>-0.3959</td>\n", + " <td>-0.2561</td>\n", + " <td>-0.2866</td>\n", + " <td>1.8770</td>\n", + " <td>-0.3179</td>\n", + " <td>-0.3633</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>-0.1083</td>\n", + " <td>0.3393</td>\n", + " <td>0.2769</td>\n", + " <td>1.7320</td>\n", + " <td>-0.0975</td>\n", + " <td>2.6955</td>\n", + " <td>-0.6741</td>\n", + " <td>1.0323</td>\n", + " <td>1.2766</td>\n", + " <td>-0.3982</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>-0.4155</td>\n", + " <td>1.6846</td>\n", + " <td>0.7711</td>\n", + " <td>-0.3061</td>\n", + " <td>-0.5016</td>\n", + " <td>2.8548</td>\n", + " <td>-0.6171</td>\n", + " <td>-0.8608</td>\n", + " <td>-0.0486</td>\n", + " <td>-0.3962</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>-0.8143</td>\n", + " <td>0.8344</td>\n", + " <td>1.5075</td>\n", + " <td>3.6068</td>\n", + " <td>-0.5004</td>\n", + " <td>-0.0747</td>\n", + " <td>-0.2185</td>\n", + " <td>-0.4379</td>\n", + " <td>1.6913</td>\n", + " <td>1.7748</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>0.0983</td>\n", + " <td>-0.7908</td>\n", + " <td>-0.0053</td>\n", + " <td>-0.0643</td>\n", + " <td>-0.3706</td>\n", + " <td>0.3870</td>\n", + " <td>-0.5589</td>\n", + " <td>-0.5979</td>\n", + " <td>0.0047</td>\n", + " <td>-0.3548</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>368</th>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>-0.0291</td>\n", + " <td>-0.1058</td>\n", + " <td>-0.6721</td>\n", + " <td>0.2802</td>\n", + " <td>1.9504</td>\n", + " <td>-0.8784</td>\n", + " <td>0.9506</td>\n", + " <td>0.0607</td>\n", + " <td>1.1883</td>\n", + " <td>-0.3521</td>\n", + " </tr>\n", + " <tr>\n", + " <th>369</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>0.0497</td>\n", + " <td>0.3673</td>\n", + " <td>-0.2208</td>\n", + " <td>0.3034</td>\n", + " <td>3.2580</td>\n", + " <td>-0.2089</td>\n", + " <td>1.6053</td>\n", + " <td>-0.8746</td>\n", + " <td>-0.4491</td>\n", + " <td>-0.3450</td>\n", + " </tr>\n", + " <tr>\n", + " <th>370</th>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>...</td>\n", + " <td>0.3822</td>\n", + " <td>-0.7003</td>\n", + " <td>-0.7661</td>\n", + " <td>-1.7035</td>\n", + " <td>-0.5423</td>\n", + " <td>-0.3488</td>\n", + " <td>1.3713</td>\n", + " <td>-0.4365</td>\n", + " <td>2.3456</td>\n", + " <td>-0.3866</td>\n", + " </tr>\n", + " <tr>\n", + " <th>371</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>-0.6853</td>\n", + " <td>-1.0240</td>\n", + " <td>-1.2890</td>\n", + " <td>-1.5666</td>\n", + " <td>-0.1270</td>\n", + " <td>-1.4662</td>\n", + " <td>0.3981</td>\n", + " <td>-0.5976</td>\n", + " <td>-1.3822</td>\n", + " <td>-0.4157</td>\n", + " </tr>\n", + " <tr>\n", + " <th>372</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>0.0517</td>\n", + " <td>-0.3570</td>\n", + " <td>-0.4843</td>\n", + " <td>-0.3792</td>\n", + " <td>-0.1964</td>\n", + " <td>0.4200</td>\n", + " <td>3.2547</td>\n", + " <td>-0.1232</td>\n", + " <td>3.4519</td>\n", + " <td>-0.1962</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>373 rows × 20395 columns</p>\n", + "</div>" + ], + "text/plain": [ + " NDUFS5_cnv MACF1_cnv RNA5SP44_cnv KIAA0754_cnv BMP8A_cnv PABPC4_cnv \\\n", + "0 -1 -1 -1 -1 -1 -1 \n", + "1 2 2 2 2 2 2 \n", + "2 0 0 0 0 0 0 \n", + "3 0 0 0 0 0 0 \n", + "4 0 0 0 0 0 0 \n", + ".. ... ... ... ... ... ... \n", + "368 2 2 2 2 2 2 \n", + "369 0 0 0 0 0 0 \n", + "370 1 1 1 1 1 1 \n", + "371 0 0 0 0 0 0 \n", + "372 0 0 0 0 0 0 \n", + "\n", + " SNORA55_cnv HEYL_cnv HPCAL4_cnv NT5C1A_cnv ... ZWINT_rnaseq \\\n", + "0 -1 -1 -1 -1 ... -0.8388 \n", + "1 2 2 2 2 ... -0.1083 \n", + "2 0 0 0 0 ... -0.4155 \n", + "3 0 0 0 0 ... -0.8143 \n", + "4 0 0 0 0 ... 0.0983 \n", + ".. ... ... ... ... ... ... \n", + "368 2 2 2 2 ... -0.0291 \n", + "369 0 0 0 0 ... 0.0497 \n", + "370 1 1 1 1 ... 0.3822 \n", + "371 0 0 0 0 ... -0.6853 \n", + "372 0 0 0 0 ... 0.0517 \n", + "\n", + " ZXDA_rnaseq ZXDB_rnaseq ZXDC_rnaseq ZYG11A_rnaseq ZYG11B_rnaseq \\\n", + "0 4.1375 3.9664 1.8437 -0.3959 -0.2561 \n", + "1 0.3393 0.2769 1.7320 -0.0975 2.6955 \n", + "2 1.6846 0.7711 -0.3061 -0.5016 2.8548 \n", + "3 0.8344 1.5075 3.6068 -0.5004 -0.0747 \n", + "4 -0.7908 -0.0053 -0.0643 -0.3706 0.3870 \n", + ".. ... ... ... ... ... \n", + "368 -0.1058 -0.6721 0.2802 1.9504 -0.8784 \n", + "369 0.3673 -0.2208 0.3034 3.2580 -0.2089 \n", + "370 -0.7003 -0.7661 -1.7035 -0.5423 -0.3488 \n", + "371 -1.0240 -1.2890 -1.5666 -0.1270 -1.4662 \n", + "372 -0.3570 -0.4843 -0.3792 -0.1964 0.4200 \n", + "\n", + " ZYX_rnaseq ZZEF1_rnaseq ZZZ3_rnaseq TPTEP1_rnaseq \n", + "0 -0.2866 1.8770 -0.3179 -0.3633 \n", + "1 -0.6741 1.0323 1.2766 -0.3982 \n", + "2 -0.6171 -0.8608 -0.0486 -0.3962 \n", + "3 -0.2185 -0.4379 1.6913 1.7748 \n", + "4 -0.5589 -0.5979 0.0047 -0.3548 \n", + ".. ... ... ... ... \n", + "368 0.9506 0.0607 1.1883 -0.3521 \n", + "369 1.6053 -0.8746 -0.4491 -0.3450 \n", + "370 1.3713 -0.4365 2.3456 -0.3866 \n", + "371 0.3981 -0.5976 -1.3822 -0.4157 \n", + "372 3.2547 -0.1232 3.4519 -0.1962 \n", + "\n", + "[373 rows x 20395 columns]" + ] + }, + "execution_count": 434, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "genomic_features" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import pdb\n", + "import numpy as np\n", + "\n", + "class MIL_Sum_FC_surv(nn.Module):\n", + " def __init__(self, size_arg = \"small\", dropout=0.25, n_classes=4):\n", + " super(MIL_Sum_FC_surv, self).__init__()\n", + "\n", + " self.size_dict = {\"small\": [1024, 512, 256], \"big\": [1024, 512, 384]}\n", + " size = self.size_dict[size_arg]\n", + " self.phi = nn.Sequential(*[nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)])\n", + " self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])\n", + " self.classifier = nn.Linear(size[2], n_classes)\n", + "\n", + " def relocate(self):\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " if torch.cuda.device_count() >= 1:\n", + " device_ids = list(range(torch.cuda.device_count()))\n", + " self.phi = nn.DataParallel(self.phi, device_ids=device_ids).to('cuda:0')\n", + "\n", + " self.rho = self.rho.to(device)\n", + " self.classifier = self.classifier.to(device)\n", + "\n", + " def forward(self, **kwargs):\n", + " h = kwargs['x_path']\n", + "\n", + " h = self.phi(h).sum(axis=0)\n", + " h = self.rho(h)\n", + " logits = self.classifier(h).unsqueeze(0)\n", + " Y_hat = torch.topk(logits, 1, dim = 1)[1]\n", + " hazards = torch.sigmoid(logits)\n", + " S = torch.cumprod(1 - hazards, dim=1)\n", + " \n", + " return hazards, S, Y_hat, None, None\n", + "\n", + "from os.path import join\n", + "from collections import OrderedDict\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import pdb\n", + "import numpy as np\n", + "\n", + "\"\"\"\n", + "A Modified Implementation of Deep Attention MIL\n", + "\"\"\"\n", + "\n", + "\n", + "\"\"\"\n", + "Attention Network without Gating (2 fc layers)\n", + "args:\n", + " L: input feature dimension\n", + " D: hidden layer dimension\n", + " dropout: whether to use dropout (p = 0.25)\n", + " n_classes: number of classes (experimental usage for multiclass MIL)\n", + "\"\"\"\n", + "class Attn_Net(nn.Module):\n", + "\n", + " def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):\n", + " super(Attn_Net, self).__init__()\n", + " self.module = [\n", + " nn.Linear(L, D),\n", + " nn.Tanh()]\n", + "\n", + " if dropout:\n", + " self.module.append(nn.Dropout(0.25))\n", + "\n", + " self.module.append(nn.Linear(D, n_classes))\n", + " \n", + " self.module = nn.Sequential(*self.module)\n", + " \n", + " def forward(self, x):\n", + " return self.module(x), x # N x n_classes\n", + "\n", + "\"\"\"\n", + "Attention Network with Sigmoid Gating (3 fc layers)\n", + "args:\n", + " L: input feature dimension\n", + " D: hidden layer dimension\n", + " dropout: whether to use dropout (p = 0.25)\n", + " n_classes: number of classes (experimental usage for multiclass MIL)\n", + "\"\"\"\n", + "class Attn_Net_Gated(nn.Module):\n", + "\n", + " def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):\n", + " super(Attn_Net_Gated, self).__init__()\n", + " self.attention_a = [\n", + " nn.Linear(L, D),\n", + " nn.Tanh()]\n", + " \n", + " self.attention_b = [nn.Linear(L, D),\n", + " nn.Sigmoid()]\n", + " if dropout:\n", + " self.attention_a.append(nn.Dropout(0.25))\n", + " self.attention_b.append(nn.Dropout(0.25))\n", + "\n", + " self.attention_a = nn.Sequential(*self.attention_a)\n", + " self.attention_b = nn.Sequential(*self.attention_b)\n", + " \n", + " self.attention_c = nn.Linear(D, n_classes)\n", + "\n", + " def forward(self, x):\n", + " a = self.attention_a(x)\n", + " b = self.attention_b(x)\n", + " A = a.mul(b)\n", + " A = self.attention_c(A) # N x n_classes\n", + " return A, x\n", + " \n", + "class MIL_Cluster_FC_surv(nn.Module):\n", + " def __init__(self, num_clusters=10, size_arg = \"small\", dropout=0.25, n_classes=4):\n", + " super(MIL_Cluster_FC_surv, self).__init__()\n", + " self.size_dict = {\"small\": [1024, 512, 256], \"big\": [1024, 512, 384]}\n", + " self.num_clusters = num_clusters\n", + " \n", + " ### Phenotype Learning\n", + " size = self.size_dict[size_arg]\n", + " phis = []\n", + " for phenotype_i in range(num_clusters):\n", + " phi = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout),\n", + " nn.Linear(size[1], size[1]), nn.ReLU(), nn.Dropout(dropout)]\n", + " phis.append(nn.Sequential(*phi))\n", + " self.phis = nn.ModuleList(phis)\n", + " self.pool1d = nn.AdaptiveAvgPool1d(1)\n", + " \n", + " \n", + " ### WSI Attention MIL Construction\n", + " fc = [nn.Linear(size[1], size[1]), nn.ReLU()]\n", + " fc.append(nn.Dropout(0.25))\n", + " attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)\n", + " fc.append(attention_net)\n", + " self.attention_net = nn.Sequential(*fc)\n", + "\n", + " \n", + " self.rho = nn.Sequential(*[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)])\n", + " self.classifier = nn.Linear(size[2], n_classes)\n", + "\n", + " def relocate(self):\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " if torch.cuda.device_count() >= 1:\n", + " device_ids = list(range(torch.cuda.device_count()))\n", + " self.phis = nn.DataParallel(self.phi, device_ids=device_ids).to('cuda:0')\n", + "\n", + " self.rho = self.rho.to(device)\n", + " self.classifier = self.classifier.to(device)\n", + "\n", + " def forward(self, **kwargs):\n", + " x_path = kwargs['x_path']\n", + " ### Phenotyping\n", + " h_phenotypes = []\n", + " from sklearn.cluster import KMeans\n", + " kmeans = KMeans(n_clusters=self.num_clusters, random_state=2021).fit(X)\n", + " #cluster_ids_x, cluster_centers = kmeans(X=x_path, num_clusters=self.num_clusters, distance='euclidean', device=torch.device('cpu'))\n", + " cluster_ids_x = KMeans(n_clusters=10, random_state=2021, max_iter=20).fit_predict(x_path)\n", + " for i in range(self.num_clusters):\n", + " h_phenotypes_i = self.phis[i](x_path[cluster_ids_x==i])\n", + " h_phenotypes.append(self.pool1d(h_phenotypes_i.T.unsqueeze(0)).squeeze(2))\n", + " h_phenotypes = torch.stack(h_phenotypes, dim=1).squeeze(0)\n", + "\n", + "\n", + " ### Attention MIL\n", + " A, h = self.attention_net(h_phenotypes) \n", + " A = torch.transpose(A, 1, 0)\n", + " if 'attention_only' in kwargs.keys():\n", + " if kwargs['attention_only']:\n", + " return A\n", + " A_raw = A \n", + " A = F.softmax(A, dim=1) \n", + " h = torch.mm(A, h_phenotypes)\n", + "\n", + " \n", + " h = self.rho(h)\n", + " logits = self.classifier(h).unsqueeze(0)\n", + " Y_hat = torch.topk(logits, 1, dim = 1)[1]\n", + " hazards = torch.sigmoid(logits)\n", + " S = torch.cumprod(1 - hazards, dim=1)\n", + " \n", + " return hazards, S, Y_hat, None, None" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "x_path = torch.randint(10, size=(500, 1024)).type(torch.cuda.FloatTensor)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.cluster import KMeans\n", + "kmeans = KMeans(n_clusters=10, random_state=2021, max_iter=20).fit_predict(x_path.cpu())" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([5, 5, 3, 5, 8, 4, 8, 7, 5, 4, 9, 1, 9, 1, 6, 1, 1, 0, 5, 0, 4, 3,\n", + " 0, 6, 3, 1, 0, 7, 9, 8, 0, 5, 5, 3, 0, 1, 5, 1, 0, 6, 6, 4, 1, 5,\n", + " 3, 0, 1, 0, 8, 5, 1, 8, 1, 0, 5, 0, 2, 5, 6, 5, 0, 0, 5, 1, 2, 7,\n", + " 4, 6, 5, 3, 0, 7, 9, 1, 3, 4, 4, 5, 7, 9, 9, 5, 0, 1, 9, 1, 2, 0,\n", + " 6, 3, 1, 1, 2, 4, 0, 5, 1, 1, 1, 0, 0, 9, 8, 1, 5, 5, 0, 9, 2, 3,\n", + " 7, 0, 1, 6, 7, 5, 3, 5, 0, 1, 6, 1, 6, 2, 8, 7, 6, 1, 6, 2, 5, 0,\n", + " 1, 6, 0, 9, 2, 1, 0, 1, 7, 7, 6, 1, 6, 0, 3, 4, 1, 3, 2, 4, 4, 5,\n", + " 4, 1, 1, 9, 6, 0, 3, 6, 4, 8, 7, 9, 6, 5, 5, 9, 0, 6, 0, 1, 9, 2,\n", + " 3, 5, 1, 9, 6, 1, 0, 6, 6, 0, 0, 6, 7, 1, 6, 1, 1, 1, 4, 0, 2, 1,\n", + " 9, 5, 7, 5, 9, 0, 1, 0, 6, 2, 2, 1, 1, 5, 3, 5, 3, 6, 5, 6, 9, 5,\n", + " 2, 2, 2, 6, 0, 0, 0, 5, 2, 6, 6, 0, 2, 5, 1, 9, 2, 4, 4, 0, 4, 7,\n", + " 4, 1, 1, 3, 6, 0, 1, 2, 4, 0, 8, 1, 8, 5, 5, 7, 4, 1, 6, 1, 0, 8,\n", + " 6, 1, 1, 4, 8, 7, 5, 2, 3, 0, 2, 9, 5, 6, 4, 3, 6, 5, 5, 4, 6, 6,\n", + " 0, 1, 5, 1, 1, 1, 1, 9, 5, 7, 3, 0, 2, 4, 0, 5, 4, 0, 5, 0, 6, 0,\n", + " 3, 1, 4, 6, 3, 7, 1, 6, 7, 0, 1, 4, 6, 1, 6, 0, 6, 0, 5, 9, 1, 1,\n", + " 3, 1, 5, 6, 1, 6, 6, 8, 2, 0, 7, 9, 9, 6, 0, 6, 2, 6, 8, 0, 8, 5,\n", + " 1, 3, 1, 9, 2, 3, 5, 8, 2, 5, 6, 6, 5, 2, 9, 0, 1, 8, 5, 9, 5, 1,\n", + " 0, 1, 0, 8, 6, 1, 7, 2, 8, 3, 1, 6, 2, 2, 1, 6, 0, 2, 6, 1, 1, 4,\n", + " 5, 6, 4, 0, 5, 0, 9, 0, 4, 8, 0, 7, 6, 5, 5, 0, 4, 1, 1, 2, 2, 0,\n", + " 0, 6, 4, 0, 7, 7, 2, 3, 1, 4, 7, 9, 4, 7, 2, 4, 5, 6, 4, 5, 7, 9,\n", + " 8, 0, 6, 2, 0, 6, 6, 3, 5, 4, 4, 0, 1, 0, 5, 3, 1, 6, 0, 7, 4, 1,\n", + " 6, 3, 6, 0, 4, 1, 5, 7, 3, 1, 4, 8, 0, 7, 0, 6, 1, 1, 0, 1, 5, 1,\n", + " 2, 3, 2, 3, 8, 8, 4, 6, 5, 6, 1, 0, 7, 6, 4, 4], dtype=int32)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kmeans" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[0.9992, 0.0000, 0.0000, 1.0000]], grad_fn=<SigmoidBackward>),\n", + " tensor([[0.0008, 0.0008, 0.0008, 0.0000]], grad_fn=<CumprodBackward>),\n", + " tensor([[3]]),\n", + " None,\n", + " None)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_path = torch.randint(10, size=(500, 1024)).type(torch.FloatTensor)\n", + "model = MIL_Sum_FC_surv()\n", + "model.forward(x_path=x_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[4.2595e-07, 1.0000e+00, 0.0000e+00, 7.2488e-12]],\n", + " grad_fn=<SigmoidBackward>),\n", + " tensor([[1.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<CumprodBackward>),\n", + " tensor([[1]]),\n", + " None,\n", + " None)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_path = torch.randint(10, size=(500, 1024)).type(torch.FloatTensor)\n", + "self = MIL_Cluster_FC_surv()\n", + "model.forward(x_path=x_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "fname = os.path.join('/media/ssd1/pan-cancer/tcga_gbm_20x_features/h5_files/TCGA-02-0001-01Z-00-DX1.83fce43e-42ac-4dcd-b156-2908e75f2e47.h5')" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "import h5py\n", + "h5 = h5py.File(fname, \"r\")\n", + "coords = np.array(h5['coords'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fm" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([43121, 29428])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.array(h5['coords'])[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([43121, 29940])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.array(h5['coords'])[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "512" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.array(h5['coords'])[1][1] - np.array(h5['coords'])[0][1]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "512" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.array(h5['coords'])[2][1] - np.array(h5['coords'])[1][1]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "import nmslib\n", + "class Hnsw:\n", + "\n", + " def __init__(self, space='cosinesimil', index_params=None,\n", + " query_params=None, print_progress=True):\n", + " self.space = space\n", + " self.index_params = index_params\n", + " self.query_params = query_params\n", + " self.print_progress = print_progress\n", + "\n", + " def fit(self, X):\n", + " index_params = self.index_params\n", + " if index_params is None:\n", + " index_params = {'M': 16, 'post': 0, 'efConstruction': 400}\n", + "\n", + " query_params = self.query_params\n", + " if query_params is None:\n", + " query_params = {'ef': 90}\n", + "\n", + " # this is the actual nmslib part, hopefully the syntax should\n", + " # be pretty readable, the documentation also has a more verbiage\n", + " # introduction: https://nmslib.github.io/nmslib/quickstart.html\n", + " index = nmslib.init(space=self.space, method='hnsw')\n", + " index.addDataPointBatch(X)\n", + " index.createIndex(index_params, print_progress=self.print_progress)\n", + " index.setQueryTimeParams(query_params)\n", + "\n", + " self.index_ = index\n", + " self.index_params_ = index_params\n", + " self.query_params_ = query_params\n", + " return self\n", + "\n", + " def query(self, vector, topn):\n", + " # the knnQuery returns indices and corresponding distance\n", + " # we will throw the distance away for now\n", + " indices, _ = self.index_.knnQuery(vector, k=topn)\n", + " return indices" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([85, 87, 88, 73, 75, 76, 63, 29], dtype=int32)" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = Hnsw(space='l2')\n", + "model.fit(coords)\n", + "model.query(coords, topn=8)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "import networkx as nx\n", + "G = nx.Graph()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([43121, 29428])" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "for" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "130" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "temp[3]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 7440, 13280])" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "coords[100]" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "indices = model.query(coords[100], topn =10)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 7440, 13280],\n", + " [ 7440, 13792],\n", + " [ 7952, 13280],\n", + " [ 6928, 13792],\n", + " [ 7952, 12768],\n", + " [ 7952, 13792],\n", + " [ 7440, 14304],\n", + " [ 8464, 13280],\n", + " [ 6928, 14304],\n", + " [ 8464, 13792]])" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "coords[indices]" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "def do_KmeansPCA(X=None, y=None, scaler=None, n_clusters=4, n_components=5):\n", + " import pandas as pd\n", + " import seaborn as sns\n", + " from sklearn.datasets import make_blobs\n", + " from sklearn import decomposition\n", + " from sklearn.decomposition import PCA, TruncatedSVD\n", + " from sklearn.preprocessing import StandardScaler, Normalizer\n", + " from sklearn.pipeline import make_pipeline\n", + " from sklearn.cluster import KMeans\n", + " ### Initialize Scaler\n", + " if scaler is None: \n", + " scaler = StandardScaler()\n", + " ### Get Random Data\n", + " X, y = make_blobs(n_features=10, n_samples=100, centers=4, random_state=4, cluster_std=7)\n", + " ### Scale Data\n", + " X = scaler.fit_transform(X)\n", + " ### Perform K-Means Clustering\n", + " cls = KMeans(n_clusters=n_clusters, init='k-means++', n_jobs=-1, n_init=1)\n", + " y_pred = cls.fit_predict(X)\n", + " ### Perform PCA\n", + " pca = PCA(n_components=n_components)\n", + " pc = pca.fit_transform(X)\n", + " ### Plot Results\n", + " columns = ['PC%d'%c for c in range(1, n_components+1)]\n", + " pc_df = pd.DataFrame(data=pc, columns=columns)\n", + " pc_df['y_pred'] = y_pred\n", + " pc_df['y'] = y\n", + " df = pd.DataFrame({'Variance Explained':pca.explained_variance_ratio_, 'Principal Components': columns})\n", + " sns.barplot(x='Principal Components',y=\"Variance Explained\", data=df, color=\"c\")\n", + " sns.lmplot( x=\"PC1\", y=\"PC2\", data=pc_df, fit_reg=False, \n", + " hue='y', legend=True, scatter_kws={\"s\": 80})\n", + " sns.lmplot( x=\"PC1\", y=\"PC2\", data=pc_df, fit_reg=False, \n", + " hue='y', legend=True, scatter_kws={\"s\": 80})" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "<Figure size 402.375x360 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "<Figure size 402.375x360 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "do_KmeansPCA()" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[1.0000, 0.0000, 1.0000, 0.9998]], grad_fn=<SigmoidBackward>),\n", + " tensor([[0., 0., 0., 0.]], grad_fn=<CumprodBackward>),\n", + " tensor([[2]]),\n", + " None,\n", + " None)" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.forward(x_path=x_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (<ipython-input-69-c543913fa78f>, line 1)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m File \u001b[0;32m\"<ipython-input-69-c543913fa78f>\"\u001b[0;36m, line \u001b[0;32m1\u001b[0m\n\u001b[0;31m import ..models\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + ] + } + ], + "source": [ + "import ..models" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [], + "source": [ + "x_path = torch.randint(10, size=(500, 1024)).type(torch.FloatTensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'MultiheadAttention' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-65-f85a99af33ee>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mself\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMM_CoAttn_Transformer_Surv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0momic_sizes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msig_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m<ipython-input-62-9e5f322e30a0>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, omic_sizes, n_classes, model_size_wsi, model_size_omic, dropout)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;31m### Multihead Attention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoattn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMultiheadAttention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membed_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_heads\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m### Transformer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'MultiheadAttention' is not defined" + ] + } + ], + "source": [ + "self = MM_CoAttn_Transformer_Surv(omic_sizes=sig_sizes)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'sig_size' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-52-097a03ed0c40>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mself\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMM_CoAttn_Surv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig_sizes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msig_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mx_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m500\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1024\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0msig_feats\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0msize\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msig_sizes\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mx_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m<ipython-input-43-4469ba9e1eea>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, sig_sizes, n_classes, model_size_wsi, model_size_omic, dropout)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize_dict_omic\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmodel_size_omic\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0msig_networks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0minput_dim\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msig_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0mfc_omic\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mSNN_Block\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim1\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim2\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'sig_size' is not defined" + ] + } + ], + "source": [ + "self = MM_CoAttn_Surv(sig_sizes=sig_sizes)\n", + "x_path = torch.randint(10, size=(500, 1024)).type(torch.FloatTensor)\n", + "sig_feats = [torch.randint(10, size=(size,)).type(torch.FloatTensor) for size in sig_sizes]\n", + "\n", + "x_path = self.attention_net(x_path).unsqueeze(1)\n", + "x_omic = torch.stack([self.sig_networks[idx].forward(sig_feat) for idx, sig_feat in enumerate(sig_feats)]).unsqueeze(1)\n", + "\n", + "out, attention_weights = self.coattn(x_omic, x_path, x_path)\n", + "out = self.transformer(out)\n", + "out = self.conv(out.squeeze(1).T.unsqueeze(0))\n", + "#out = self.classifier(out.squeeze(0).squeeze(1))" + ] + }, + { + "cell_type": "code", + "execution_count": 471, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 256, 1])" + ] + }, + "execution_count": 471, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 472, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.5998, 1.9873, -1.1435, ..., -0.0048, 0.2963, 1.1112]],\n", + "\n", + " [[-0.4201, -0.1456, 0.2057, ..., -0.2175, 0.4188, 0.4702]],\n", + "\n", + " [[ 1.0294, 3.1634, 0.4595, ..., 1.2059, 0.5845, 1.4114]],\n", + "\n", + " [[-1.1435, -1.1435, -1.1435, ..., 0.1951, -0.4378, 0.2051]],\n", + "\n", + " [[ 0.9948, 1.1596, 2.1419, ..., -0.1225, 1.3597, -0.3037]],\n", + "\n", + " [[ 0.4019, -1.1435, -0.1522, ..., -0.2058, 0.0351, -1.1435]]],\n", + " grad_fn=<UnsqueezeBackward0>)" + ] + }, + "execution_count": 472, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_omic" + ] + }, + { + "cell_type": "code", + "execution_count": 474, + "metadata": {}, + "outputs": [], + "source": [ + "self = MM_CoAttn_Surv(sig_sizes=sig_sizes)\n", + "x_path = torch.randint(10, size=(500, 1024)).type(torch.FloatTensor)\n", + "sig_feats = [torch.randint(10, size=(size,)).type(torch.FloatTensor) for size in sig_sizes]\n", + "\n", + "x_path = self.attention_net(x_path).unsqueeze(1)\n", + "x_omic = torch.stack([self.sig_networks[idx].forward(sig_feat) for idx, sig_feat in enumerate(sig_feats)]).unsqueeze(1)\n", + "out, attention_weights = self.coattn(x_omic, x_path, x_path)\n", + "\n", + "out = self.transformer(out)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 491, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1536])" + ] + }, + "execution_count": 491, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.cat([self.sig_networks[idx].forward(sig_feat) for idx, sig_feat in enumerate(sig_feats)]).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 484, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([6, 1, 512])" + ] + }, + "execution_count": 484, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.cat([out, out], axis=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 455, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([6, 1, 256])" + ] + }, + "execution_count": 455, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 452, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([6, 1, 256])" + ] + }, + "execution_count": 452, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 423, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 8, 6, 500])" + ] + }, + "execution_count": 423, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_weights.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 415, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[0.0018, 0.0020, 0.0012, ..., 0.0016, 0.0025, 0.0031],\n", + " [0.0026, 0.0015, 0.0016, ..., 0.0021, 0.0021, 0.0016],\n", + " [0.0019, 0.0014, 0.0011, ..., 0.0020, 0.0013, 0.0025],\n", + " [0.0016, 0.0013, 0.0023, ..., 0.0009, 0.0015, 0.0027],\n", + " [0.0015, 0.0013, 0.0023, ..., 0.0026, 0.0019, 0.0026],\n", + " [0.0013, 0.0019, 0.0025, ..., 0.0022, 0.0020, 0.0021]]],\n", + " grad_fn=<DivBackward0>)" + ] + }, + "execution_count": 415, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_weights_0" + ] + }, + { + "cell_type": "code", + "execution_count": 416, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[0.0018, 0.0020, 0.0012, ..., 0.0016, 0.0025, 0.0031],\n", + " [0.0026, 0.0015, 0.0016, ..., 0.0021, 0.0021, 0.0016],\n", + " [0.0019, 0.0014, 0.0011, ..., 0.0020, 0.0013, 0.0025],\n", + " [0.0016, 0.0013, 0.0023, ..., 0.0009, 0.0015, 0.0027],\n", + " [0.0015, 0.0013, 0.0023, ..., 0.0026, 0.0019, 0.0026],\n", + " [0.0013, 0.0019, 0.0025, ..., 0.0022, 0.0020, 0.0021]]],\n", + " grad_fn=<DivBackward0>)" + ] + }, + "execution_count": 416, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "softmax(attention_weights_1, dim=-1).sum(axis=1) / 8" + ] + }, + { + "cell_type": "code", + "execution_count": 411, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1, 6, 500])" + ] + }, + "execution_count": 411, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "softmax(attention_weights_1, dim=-1).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 339, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.0000, grad_fn=<SumBackward0>)" + ] + }, + "execution_count": 339, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_weights_0[0][0].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 396, + "metadata": {}, + "outputs": [], + "source": [ + "test = softmax(attention_weights_2, dim=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 402, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.0024, 0.0030, 0.0019, 0.0018, 0.0038, 0.0015, 0.0020, 0.0016, 0.0015,\n", + " 0.0019, 0.0015, 0.0035, 0.0026, 0.0017, 0.0014, 0.0013, 0.0023, 0.0020,\n", + " 0.0017, 0.0010], grad_fn=<SliceBackward>)" + ] + }, + "execution_count": 402, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_weights_0[0][0][:20]" + ] + }, + { + "cell_type": "code", + "execution_count": 404, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.0028, 0.0033, 0.0019, 0.0013, 0.0042, 0.0016, 0.0024, 0.0018, 0.0019,\n", + " 0.0024, 0.0016, 0.0033, 0.0022, 0.0014, 0.0016, 0.0013, 0.0023, 0.0021,\n", + " 0.0013, 0.0013], grad_fn=<SliceBackward>)" + ] + }, + "execution_count": 404, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test[0][0][:20]" + ] + }, + { + "cell_type": "code", + "execution_count": 366, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[False, False, False, ..., False, False, False],\n", + " [False, False, False, ..., False, False, False],\n", + " [False, False, False, ..., False, False, False],\n", + " [False, False, False, ..., False, False, False],\n", + " [False, False, False, ..., False, False, False],\n", + " [False, False, False, ..., False, False, False]]])" + ] + }, + "execution_count": 366, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.eq(attention_weights_0, test)" + ] + }, + { + "cell_type": "code", + "execution_count": 320, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 8, 6, 500])" + ] + }, + "execution_count": 320, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_weights_1.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 318, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 6, 500])" + ] + }, + "execution_count": 318, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_weights_2.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 282, + "metadata": {}, + "outputs": [], + "source": [ + "out = self.classifier(out.squeeze(0).squeeze(1))" + ] + }, + { + "cell_type": "code", + "execution_count": 284, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0.2832, 0.1548, -0.0972, -0.2801], grad_fn=<AddBackward0>)" + ] + }, + "execution_count": 284, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out" + ] + }, + { + "cell_type": "code", + "execution_count": 269, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.0018, 0.0019, 0.0019, ..., 0.0019, 0.0022, 0.0018],\n", + " [0.0020, 0.0020, 0.0021, ..., 0.0021, 0.0020, 0.0020],\n", + " [0.0019, 0.0022, 0.0021, ..., 0.0019, 0.0019, 0.0020],\n", + " [0.0021, 0.0022, 0.0019, ..., 0.0018, 0.0020, 0.0021],\n", + " [0.0019, 0.0019, 0.0020, ..., 0.0020, 0.0018, 0.0019],\n", + " [0.0021, 0.0021, 0.0019, ..., 0.0019, 0.0021, 0.0021]],\n", + " grad_fn=<SelectBackward>)" + ] + }, + "execution_count": 269, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_weights[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 241, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[[-0.0504, 0.0757, -0.0366, ..., -0.0275, -0.0294, 0.1300]],\n", + " \n", + " [[-0.0500, 0.0762, -0.0352, ..., -0.0253, -0.0289, 0.1311]],\n", + " \n", + " [[-0.0497, 0.0772, -0.0321, ..., -0.0246, -0.0288, 0.1301]],\n", + " \n", + " [[-0.0491, 0.0794, -0.0337, ..., -0.0260, -0.0278, 0.1281]],\n", + " \n", + " [[-0.0483, 0.0781, -0.0343, ..., -0.0246, -0.0301, 0.1321]],\n", + " \n", + " [[-0.0499, 0.0768, -0.0305, ..., -0.0257, -0.0280, 0.1321]]],\n", + " grad_fn=<AddBackward0>),\n", + " tensor([[[0.0019, 0.0019, 0.0019, ..., 0.0020, 0.0021, 0.0021],\n", + " [0.0017, 0.0020, 0.0020, ..., 0.0019, 0.0019, 0.0018],\n", + " [0.0019, 0.0018, 0.0019, ..., 0.0019, 0.0019, 0.0021],\n", + " [0.0020, 0.0020, 0.0019, ..., 0.0020, 0.0021, 0.0019],\n", + " [0.0017, 0.0023, 0.0021, ..., 0.0019, 0.0020, 0.0020],\n", + " [0.0021, 0.0021, 0.0020, ..., 0.0021, 0.0021, 0.0020]]],\n", + " grad_fn=<DivBackward0>))" + ] + }, + "execution_count": 241, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "self.coattn(x_omic, x_path, x_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "h" + ] + }, + { + "cell_type": "code", + "execution_count": 208, + "metadata": {}, + "outputs": [], + "source": [ + "sig_feats = [torch.randn(size) for size in sig_sizes]\n", + "x_omic = torch.stack([self.sig_networks[idx].forward(sig_feat) for idx, sig_feat in enumerate(sig_feats)])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 204, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 206, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([6, 256])" + ] + }, + "execution_count": 206, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_omic.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'sig1' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-166-aea4cb4c555c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msig1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msig2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msig3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msig4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msig5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msig6\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'sig1' is not defined" + ] + } + ], + "source": [ + "sig1, sig2, sig3, sig4, sig5, sig6 = torch.randn()" + ] + }, + { + "cell_type": "code", + "execution_count": 158, + "metadata": {}, + "outputs": [], + "source": [ + "src = torch.rand(6, 1, 256)\n", + "out = transformer(src)\n", + "out = out.squeeze(1).T.unsqueeze(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 163, + "metadata": {}, + "outputs": [], + "source": [ + "conv = nn.Conv1d(in_channels=256, out_channels=256, kernel_size=4, stride=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 256, 6])" + ] + }, + "execution_count": 164, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 256, 1])" + ] + }, + "execution_count": 165, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "conv(out).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1536])" + ] + }, + "execution_count": 112, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x.reshape(-1).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3072" + ] + }, + "execution_count": 106, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "256 * 12" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [], + "source": [ + "net = Attn_Net_Gated()\n", + "wsi_feats = torch.randn(500, 1, 256)\n", + "sig_feats = torch.randn(6, 1, 256)" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [], + "source": [ + "multihead_attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "out, coattn_weights = multihead_attn(sig_feats, wsi_feats, wsi_feats)" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [], + "source": [ + "cotton = DenseCoAttn(dim1=256, dim2=256, num_attn=8, num_none=3, dropout=0.3b)" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [], + "source": [ + "from math import sqrt\n", + "wsi_feats = torch.randn(1, 500, 256)\n", + "sig_feats = torch.randn(1, 6, 256)\n", + "_ = cotton(wsi_feats, sig_feats)" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 6, 256])" + ] + }, + "execution_count": 103, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "_[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 500, 256])" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "_[1].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "def qkv_attention(query, key, value, mask=None, dropout=None):\n", + "\td_k = query.size(-1)\n", + "\tscores = torch.matmul(query, key.transpose(-2,-1)) / sqrt(d_k)\n", + "\tif mask is not None:\n", + "\t\tscores.data.masked_fill_(mask.eq(0), -65504.0)\n", + "\t\n", + "\tp_attn = F.softmax(scores, dim=-1)\n", + "\tif dropout is not None:\n", + "\t\tp_attn = dropout(p_attn)\n", + "\n", + "\treturn torch.matmul(p_attn, value), p_attn\n", + "\n", + "\n", + "class DenseCoAttn(nn.Module):\n", + "\n", + "\tdef __init__(self, dim1, dim2, num_attn, num_none, dropout, is_multi_head=False):\n", + "\t\tsuper(DenseCoAttn, self).__init__()\n", + "\t\tdim = min(dim1, dim2)\n", + "\t\tself.linears = nn.ModuleList([nn.Linear(dim1, dim, bias=False),\n", + "\t\t\t\t\t\t\t\t\t nn.Linear(dim2, dim, bias=False)])\n", + "\t\tself.nones = nn.ParameterList([nn.Parameter(nn.init.xavier_uniform_(torch.empty(num_none, dim1))),\n", + "\t\t\t\t\t\t\t\t\t nn.Parameter(nn.init.xavier_uniform_(torch.empty(num_none, dim2)))])\n", + "\t\tself.d_k = dim // num_attn\n", + "\t\tself.h = num_attn\n", + "\t\tself.num_none = num_none\n", + "\t\tself.is_multi_head = is_multi_head\n", + "\t\tself.attn = None\n", + "\t\tself.dropouts = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(2)])\n", + "\n", + "\tdef forward(self, value1, value2, mask1=None, mask2=None):\n", + "\t\tbatch = value1.size(0)\n", + "\t\tdim1, dim2 = value1.size(-1), value2.size(-1)\n", + "\t\tvalue1 = torch.cat([self.nones[0].unsqueeze(0).expand(batch, self.num_none, dim1), value1], dim=1)\n", + "\t\tvalue2 = torch.cat([self.nones[1].unsqueeze(0).expand(batch, self.num_none, dim2), value2], dim=1)\n", + "\t\tnone_mask = value1.new_ones((batch, self.num_none))\n", + "\n", + "\t\tif mask1 is not None:\n", + "\t\t\tmask1 = torch.cat([none_mask, mask1], dim=1)\n", + "\t\t\tmask1 = mask1.unsqueeze(1).unsqueeze(2)\n", + "\t\tif mask2 is not None:\n", + "\t\t\tmask2 = torch.cat([none_mask, mask2], dim=1)\n", + "\t\t\tmask2 = mask2.unsqueeze(1).unsqueeze(2)\n", + "\n", + "\t\tquery1, query2 = [l(x).view(batch, -1, self.h, self.d_k).transpose(1, 2) \n", + "\t\t\tfor l, x in zip(self.linears, (value1, value2))]\n", + "\n", + "\t\tif self.is_multi_head:\n", + "\t\t\tweighted1, attn1 = qkv_attention(query2, query1, query1, mask=mask1, dropout=self.dropouts[0])\n", + "\t\t\tweighted1 = weighted1.transpose(1, 2).contiguous()[:, self.num_none:, :]\n", + "\t\t\tweighted2, attn2 = qkv_attention(query1, query2, query2, mask=mask2, dropout=self.dropouts[1])\n", + "\t\t\tweighted2 = weighted2.transpose(1, 2).contiguous()[:, self.num_none:, :]\n", + "\t\telse:\n", + "\t\t\tweighted1, attn1 = qkv_attention(query2, query1, value1.unsqueeze(1), mask=mask1, \n", + "\t\t\t\tdropout=self.dropouts[0])\n", + "\t\t\tweighted1 = weighted1.mean(dim=1)[:, self.num_none:, :]\n", + "\t\t\tweighted2, attn2 = qkv_attention(query1, query2, value2.unsqueeze(1), mask=mask2, \n", + "\t\t\t\tdropout=self.dropouts[1])\n", + "\t\t\tweighted2 = weighted2.mean(dim=1)[:, self.num_none:, :]\n", + "\t\tself.attn = [attn1[:,:,self.num_none:,self.num_none:], attn2[:,:,self.num_none:,self.num_none:]]\n", + "\n", + "\t\treturn weighted1, weighted2\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 417, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.nn.functional import *\n", + "\n", + "def multi_head_attention_forward(\n", + " query: Tensor,\n", + " key: Tensor,\n", + " value: Tensor,\n", + " embed_dim_to_check: int,\n", + " num_heads: int,\n", + " in_proj_weight: Tensor,\n", + " in_proj_bias: Tensor,\n", + " bias_k: Optional[Tensor],\n", + " bias_v: Optional[Tensor],\n", + " add_zero_attn: bool,\n", + " dropout_p: float,\n", + " out_proj_weight: Tensor,\n", + " out_proj_bias: Tensor,\n", + " training: bool = True,\n", + " key_padding_mask: Optional[Tensor] = None,\n", + " need_weights: bool = True,\n", + " need_raw: bool = True,\n", + " attn_mask: Optional[Tensor] = None,\n", + " use_separate_proj_weight: bool = False,\n", + " q_proj_weight: Optional[Tensor] = None,\n", + " k_proj_weight: Optional[Tensor] = None,\n", + " v_proj_weight: Optional[Tensor] = None,\n", + " static_k: Optional[Tensor] = None,\n", + " static_v: Optional[Tensor] = None,\n", + ") -> Tuple[Tensor, Optional[Tensor]]:\n", + " r\"\"\"\n", + " Args:\n", + " query, key, value: map a query and a set of key-value pairs to an output.\n", + " See \"Attention Is All You Need\" for more details.\n", + " embed_dim_to_check: total dimension of the model.\n", + " num_heads: parallel attention heads.\n", + " in_proj_weight, in_proj_bias: input projection weight and bias.\n", + " bias_k, bias_v: bias of the key and value sequences to be added at dim=0.\n", + " add_zero_attn: add a new batch of zeros to the key and\n", + " value sequences at dim=1.\n", + " dropout_p: probability of an element to be zeroed.\n", + " out_proj_weight, out_proj_bias: the output projection weight and bias.\n", + " training: apply dropout if is ``True``.\n", + " key_padding_mask: if provided, specified padding elements in the key will\n", + " be ignored by the attention. This is an binary mask. When the value is True,\n", + " the corresponding value on the attention layer will be filled with -inf.\n", + " need_weights: output attn_output_weights.\n", + " attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all\n", + " the batches while a 3D mask allows to specify a different mask for the entries of each batch.\n", + " use_separate_proj_weight: the function accept the proj. weights for query, key,\n", + " and value in different forms. If false, in_proj_weight will be used, which is\n", + " a combination of q_proj_weight, k_proj_weight, v_proj_weight.\n", + " q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.\n", + " static_k, static_v: static key and value used for attention operators.\n", + " Shape:\n", + " Inputs:\n", + " - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is\n", + " the embedding dimension.\n", + " - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is\n", + " the embedding dimension.\n", + " - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is\n", + " the embedding dimension.\n", + " - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.\n", + " If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions\n", + " will be unchanged. If a BoolTensor is provided, the positions with the\n", + " value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.\n", + " - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.\n", + " 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,\n", + " S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked\n", + " positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend\n", + " while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``\n", + " are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor\n", + " is provided, it will be added to the attention weight.\n", + " - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,\n", + " N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.\n", + " - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,\n", + " N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.\n", + " Outputs:\n", + " - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,\n", + " E is the embedding dimension.\n", + " - attn_output_weights: :math:`(N, L, S)` where N is the batch size,\n", + " L is the target sequence length, S is the source sequence length.\n", + " \"\"\"\n", + " tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)\n", + " if has_torch_function(tens_ops):\n", + " return handle_torch_function(\n", + " multi_head_attention_forward,\n", + " tens_ops,\n", + " query,\n", + " key,\n", + " value,\n", + " embed_dim_to_check,\n", + " num_heads,\n", + " in_proj_weight,\n", + " in_proj_bias,\n", + " bias_k,\n", + " bias_v,\n", + " add_zero_attn,\n", + " dropout_p,\n", + " out_proj_weight,\n", + " out_proj_bias,\n", + " training=training,\n", + " key_padding_mask=key_padding_mask,\n", + " need_weights=need_weights,\n", + " need_raw=need_raw,\n", + " attn_mask=attn_mask,\n", + " use_separate_proj_weight=use_separate_proj_weight,\n", + " q_proj_weight=q_proj_weight,\n", + " k_proj_weight=k_proj_weight,\n", + " v_proj_weight=v_proj_weight,\n", + " static_k=static_k,\n", + " static_v=static_v,\n", + " )\n", + " tgt_len, bsz, embed_dim = query.size()\n", + " assert embed_dim == embed_dim_to_check\n", + " # allow MHA to have different sizes for the feature dimension\n", + " assert key.size(0) == value.size(0) and key.size(1) == value.size(1)\n", + "\n", + " head_dim = embed_dim // num_heads\n", + " assert head_dim * num_heads == embed_dim, \"embed_dim must be divisible by num_heads\"\n", + " scaling = float(head_dim) ** -0.5\n", + "\n", + " if not use_separate_proj_weight:\n", + " if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):\n", + " # self-attention\n", + " q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)\n", + "\n", + " elif key is value or torch.equal(key, value):\n", + " # encoder-decoder attention\n", + " # This is inline in_proj function with in_proj_weight and in_proj_bias\n", + " _b = in_proj_bias\n", + " _start = 0\n", + " _end = embed_dim\n", + " _w = in_proj_weight[_start:_end, :]\n", + " if _b is not None:\n", + " _b = _b[_start:_end]\n", + " q = linear(query, _w, _b)\n", + "\n", + " if key is None:\n", + " assert value is None\n", + " k = None\n", + " v = None\n", + " else:\n", + "\n", + " # This is inline in_proj function with in_proj_weight and in_proj_bias\n", + " _b = in_proj_bias\n", + " _start = embed_dim\n", + " _end = None\n", + " _w = in_proj_weight[_start:, :]\n", + " if _b is not None:\n", + " _b = _b[_start:]\n", + " k, v = linear(key, _w, _b).chunk(2, dim=-1)\n", + "\n", + " else:\n", + " # This is inline in_proj function with in_proj_weight and in_proj_bias\n", + " _b = in_proj_bias\n", + " _start = 0\n", + " _end = embed_dim\n", + " _w = in_proj_weight[_start:_end, :]\n", + " if _b is not None:\n", + " _b = _b[_start:_end]\n", + " q = linear(query, _w, _b)\n", + "\n", + " # This is inline in_proj function with in_proj_weight and in_proj_bias\n", + " _b = in_proj_bias\n", + " _start = embed_dim\n", + " _end = embed_dim * 2\n", + " _w = in_proj_weight[_start:_end, :]\n", + " if _b is not None:\n", + " _b = _b[_start:_end]\n", + " k = linear(key, _w, _b)\n", + "\n", + " # This is inline in_proj function with in_proj_weight and in_proj_bias\n", + " _b = in_proj_bias\n", + " _start = embed_dim * 2\n", + " _end = None\n", + " _w = in_proj_weight[_start:, :]\n", + " if _b is not None:\n", + " _b = _b[_start:]\n", + " v = linear(value, _w, _b)\n", + " else:\n", + " q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)\n", + " len1, len2 = q_proj_weight_non_opt.size()\n", + " assert len1 == embed_dim and len2 == query.size(-1)\n", + "\n", + " k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)\n", + " len1, len2 = k_proj_weight_non_opt.size()\n", + " assert len1 == embed_dim and len2 == key.size(-1)\n", + "\n", + " v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)\n", + " len1, len2 = v_proj_weight_non_opt.size()\n", + " assert len1 == embed_dim and len2 == value.size(-1)\n", + "\n", + " if in_proj_bias is not None:\n", + " q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])\n", + " k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])\n", + " v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])\n", + " else:\n", + " q = linear(query, q_proj_weight_non_opt, in_proj_bias)\n", + " k = linear(key, k_proj_weight_non_opt, in_proj_bias)\n", + " v = linear(value, v_proj_weight_non_opt, in_proj_bias)\n", + " q = q * scaling\n", + "\n", + " if attn_mask is not None:\n", + " assert (\n", + " attn_mask.dtype == torch.float32\n", + " or attn_mask.dtype == torch.float64\n", + " or attn_mask.dtype == torch.float16\n", + " or attn_mask.dtype == torch.uint8\n", + " or attn_mask.dtype == torch.bool\n", + " ), \"Only float, byte, and bool types are supported for attn_mask, not {}\".format(attn_mask.dtype)\n", + " if attn_mask.dtype == torch.uint8:\n", + " warnings.warn(\"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.\")\n", + " attn_mask = attn_mask.to(torch.bool)\n", + "\n", + " if attn_mask.dim() == 2:\n", + " attn_mask = attn_mask.unsqueeze(0)\n", + " if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:\n", + " raise RuntimeError(\"The size of the 2D attn_mask is not correct.\")\n", + " elif attn_mask.dim() == 3:\n", + " if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:\n", + " raise RuntimeError(\"The size of the 3D attn_mask is not correct.\")\n", + " else:\n", + " raise RuntimeError(\"attn_mask's dimension {} is not supported\".format(attn_mask.dim()))\n", + " # attn_mask's dim is 3 now.\n", + "\n", + " # convert ByteTensor key_padding_mask to bool\n", + " if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:\n", + " warnings.warn(\n", + " \"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.\"\n", + " )\n", + " key_padding_mask = key_padding_mask.to(torch.bool)\n", + "\n", + " if bias_k is not None and bias_v is not None:\n", + " if static_k is None and static_v is None:\n", + " k = torch.cat([k, bias_k.repeat(1, bsz, 1)])\n", + " v = torch.cat([v, bias_v.repeat(1, bsz, 1)])\n", + " if attn_mask is not None:\n", + " attn_mask = pad(attn_mask, (0, 1))\n", + " if key_padding_mask is not None:\n", + " key_padding_mask = pad(key_padding_mask, (0, 1))\n", + " else:\n", + " assert static_k is None, \"bias cannot be added to static key.\"\n", + " assert static_v is None, \"bias cannot be added to static value.\"\n", + " else:\n", + " assert bias_k is None\n", + " assert bias_v is None\n", + "\n", + " q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)\n", + " if k is not None:\n", + " k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n", + " if v is not None:\n", + " v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)\n", + "\n", + " if static_k is not None:\n", + " assert static_k.size(0) == bsz * num_heads\n", + " assert static_k.size(2) == head_dim\n", + " k = static_k\n", + "\n", + " if static_v is not None:\n", + " assert static_v.size(0) == bsz * num_heads\n", + " assert static_v.size(2) == head_dim\n", + " v = static_v\n", + "\n", + " src_len = k.size(1)\n", + "\n", + " if key_padding_mask is not None:\n", + " assert key_padding_mask.size(0) == bsz\n", + " assert key_padding_mask.size(1) == src_len\n", + "\n", + " if add_zero_attn:\n", + " src_len += 1\n", + " k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)\n", + " v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)\n", + " if attn_mask is not None:\n", + " attn_mask = pad(attn_mask, (0, 1))\n", + " if key_padding_mask is not None:\n", + " key_padding_mask = pad(key_padding_mask, (0, 1))\n", + "\n", + " attn_output_weights = torch.bmm(q, k.transpose(1, 2))\n", + " assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]\n", + "\n", + " if attn_mask is not None:\n", + " if attn_mask.dtype == torch.bool:\n", + " attn_output_weights.masked_fill_(attn_mask, float(\"-inf\"))\n", + " else:\n", + " attn_output_weights += attn_mask\n", + "\n", + " if key_padding_mask is not None:\n", + " attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)\n", + " attn_output_weights = attn_output_weights.masked_fill(\n", + " key_padding_mask.unsqueeze(1).unsqueeze(2),\n", + " float(\"-inf\"),\n", + " )\n", + " attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)\n", + " \n", + " attn_output_weights_raw = attn_output_weights\n", + " attn_output_weights = softmax(attn_output_weights, dim=-1)\n", + " attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)\n", + "\n", + " attn_output = torch.bmm(attn_output_weights, v)\n", + " assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]\n", + " attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)\n", + " attn_output = linear(attn_output, out_proj_weight, out_proj_bias)\n", + " \n", + " if need_weights:\n", + " if need_raw:\n", + " \n", + " attn_output_weights_raw = attn_output_weights_raw.view(bsz, num_heads, tgt_len, src_len)\n", + " return attn_output,attn_output_weights_raw\n", + " \n", + " #attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)\n", + " #return attn_output, attn_output_weights.sum(dim=1) / num_heads, attn_output_weights_raw, attn_output_weights_raw.sum(dim=1) / num_heads\n", + " else:\n", + " # average attention weights over heads\n", + " attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)\n", + " return attn_output, attn_output_weights.sum(dim=1) / num_heads\n", + " else:\n", + " return attn_output, None\n" + ] + }, + { + "cell_type": "code", + "execution_count": 418, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import Tensor\n", + "from torch.nn.modules.linear import _LinearWithBias\n", + "from torch.nn.init import xavier_uniform_\n", + "from torch.nn.init import constant_\n", + "from torch.nn.init import xavier_normal_\n", + "from torch.nn.parameter import Parameter\n", + "from torch.nn import Module\n", + "\n", + "class MultiheadAttention(Module):\n", + " r\"\"\"Allows the model to jointly attend to information\n", + " from different representation subspaces.\n", + " See reference: Attention Is All You Need\n", + "\n", + " .. math::\n", + " \\text{MultiHead}(Q, K, V) = \\text{Concat}(head_1,\\dots,head_h)W^O\n", + " \\text{where} head_i = \\text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\n", + "\n", + " Args:\n", + " embed_dim: total dimension of the model.\n", + " num_heads: parallel attention heads.\n", + " dropout: a Dropout layer on attn_output_weights. Default: 0.0.\n", + " bias: add bias as module parameter. Default: True.\n", + " add_bias_kv: add bias to the key and value sequences at dim=0.\n", + " add_zero_attn: add a new batch of zeros to the key and\n", + " value sequences at dim=1.\n", + " kdim: total number of features in key. Default: None.\n", + " vdim: total number of features in value. Default: None.\n", + "\n", + " Note: if kdim and vdim are None, they will be set to embed_dim such that\n", + " query, key, and value have the same number of features.\n", + "\n", + " Examples::\n", + "\n", + " >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)\n", + " >>> attn_output, attn_output_weights = multihead_attn(query, key, value)\n", + " \"\"\"\n", + " bias_k: Optional[torch.Tensor]\n", + " bias_v: Optional[torch.Tensor]\n", + "\n", + " def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):\n", + " super(MultiheadAttention, self).__init__()\n", + " self.embed_dim = embed_dim\n", + " self.kdim = kdim if kdim is not None else embed_dim\n", + " self.vdim = vdim if vdim is not None else embed_dim\n", + " self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim\n", + "\n", + " self.num_heads = num_heads\n", + " self.dropout = dropout\n", + " self.head_dim = embed_dim // num_heads\n", + " assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n", + "\n", + " if self._qkv_same_embed_dim is False:\n", + " self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))\n", + " self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))\n", + " self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))\n", + " self.register_parameter('in_proj_weight', None)\n", + " else:\n", + " self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))\n", + " self.register_parameter('q_proj_weight', None)\n", + " self.register_parameter('k_proj_weight', None)\n", + " self.register_parameter('v_proj_weight', None)\n", + "\n", + " if bias:\n", + " self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))\n", + " else:\n", + " self.register_parameter('in_proj_bias', None)\n", + " self.out_proj = _LinearWithBias(embed_dim, embed_dim)\n", + "\n", + " if add_bias_kv:\n", + " self.bias_k = Parameter(torch.empty(1, 1, embed_dim))\n", + " self.bias_v = Parameter(torch.empty(1, 1, embed_dim))\n", + " else:\n", + " self.bias_k = self.bias_v = None\n", + "\n", + " self.add_zero_attn = add_zero_attn\n", + "\n", + " self._reset_parameters()\n", + "\n", + " def _reset_parameters(self):\n", + " if self._qkv_same_embed_dim:\n", + " xavier_uniform_(self.in_proj_weight)\n", + " else:\n", + " xavier_uniform_(self.q_proj_weight)\n", + " xavier_uniform_(self.k_proj_weight)\n", + " xavier_uniform_(self.v_proj_weight)\n", + "\n", + " if self.in_proj_bias is not None:\n", + " constant_(self.in_proj_bias, 0.)\n", + " constant_(self.out_proj.bias, 0.)\n", + " if self.bias_k is not None:\n", + " xavier_normal_(self.bias_k)\n", + " if self.bias_v is not None:\n", + " xavier_normal_(self.bias_v)\n", + "\n", + " def __setstate__(self, state):\n", + " # Support loading old MultiheadAttention checkpoints generated by v1.1.0\n", + " if '_qkv_same_embed_dim' not in state:\n", + " state['_qkv_same_embed_dim'] = True\n", + "\n", + " super(MultiheadAttention, self).__setstate__(state)\n", + "\n", + " def forward(self, query, key, value, key_padding_mask=None,\n", + " need_weights=True, need_raw=True, attn_mask=None):\n", + " # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]\n", + " r\"\"\"\n", + " Args:\n", + " query, key, value: map a query and a set of key-value pairs to an output.\n", + " See \"Attention Is All You Need\" for more details.\n", + " key_padding_mask: if provided, specified padding elements in the key will\n", + " be ignored by the attention. When given a binary mask and a value is True,\n", + " the corresponding value on the attention layer will be ignored. When given\n", + " a byte mask and a value is non-zero, the corresponding value on the attention\n", + " layer will be ignored\n", + " need_weights: output attn_output_weights.\n", + " attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all\n", + " the batches while a 3D mask allows to specify a different mask for the entries of each batch.\n", + "\n", + " Shape:\n", + " - Inputs:\n", + " - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is\n", + " the embedding dimension.\n", + " - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is\n", + " the embedding dimension.\n", + " - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is\n", + " the embedding dimension.\n", + " - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.\n", + " If a ByteTensor is provided, the non-zero positions will be ignored while the position\n", + " with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the\n", + " value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.\n", + " - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.\n", + " 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,\n", + " S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked\n", + " positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend\n", + " while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``\n", + " is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor\n", + " is provided, it will be added to the attention weight.\n", + "\n", + " - Outputs:\n", + " - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,\n", + " E is the embedding dimension.\n", + " - attn_output_weights: :math:`(N, L, S)` where N is the batch size,\n", + " L is the target sequence length, S is the source sequence length.\n", + " \"\"\"\n", + " if not self._qkv_same_embed_dim:\n", + " return multi_head_attention_forward(\n", + " query, key, value, self.embed_dim, self.num_heads,\n", + " self.in_proj_weight, self.in_proj_bias,\n", + " self.bias_k, self.bias_v, self.add_zero_attn,\n", + " self.dropout, self.out_proj.weight, self.out_proj.bias,\n", + " training=self.training,\n", + " key_padding_mask=key_padding_mask, need_weights=need_weights, need_raw=need_raw,\n", + " attn_mask=attn_mask, use_separate_proj_weight=True,\n", + " q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,\n", + " v_proj_weight=self.v_proj_weight)\n", + " else:\n", + " return multi_head_attention_forward(\n", + " query, key, value, self.embed_dim, self.num_heads,\n", + " self.in_proj_weight, self.in_proj_bias,\n", + " self.bias_k, self.bias_v, self.add_zero_attn,\n", + " self.dropout, self.out_proj.weight, self.out_proj.bias,\n", + " training=self.training,\n", + " key_padding_mask=key_padding_mask, need_weights=need_weights, need_raw=need_raw,\n", + " attn_mask=attn_mask)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-104-6bb47b25d46a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'" + ] + } + ], + "source": [ + "import math\n", + "\n", + "import torch\n", + "from torch import nn\n", + "\n", + "############\n", + "# Omic Model\n", + "############\n", + "def init_max_weights(module):\n", + " for m in module.modules():\n", + " if type(m) == nn.Linear:\n", + " stdv = 1. / math.sqrt(m.weight.size(1))\n", + " m.weight.data.normal_(0, stdv)\n", + " m.bias.data.zero_()\n", + "\n", + "def SNN_Block(dim1, dim2, dropout=0.25):\n", + " return nn.Sequential(\n", + " nn.Linear(dim1, dim2),\n", + " nn.ELU(),\n", + " nn.AlphaDropout(p=dropout, inplace=False))\n", + "\n", + "class MaxNet(nn.Module):\n", + " def __init__(self, input_dim: int, meta_dim: int=0, model_size_omic: str='small', n_classes: int=4):\n", + " super(MaxNet, self).__init__()\n", + " self.meta_dim = meta_dim\n", + " self.n_classes = n_classes\n", + " self.size_dict_omic = {'small': [256, 256, 256, 256], 'big': [1024, 1024, 1024, 256]}\n", + " \n", + " ### Constructing Genomic SNN\n", + " hidden = self.size_dict_omic[model_size_omic]\n", + " fc_omic = [SNN_Block(dim1=input_dim, dim2=hidden[0])]\n", + " for i, _ in enumerate(hidden[1:]):\n", + " fc_omic.append(SNN_Block(dim1=hidden[i], dim2=hidden[i+1], dropout=0.25))\n", + " self.fc_omic = nn.Sequential(*fc_omic)\n", + " self.classifier = nn.Linear(hidden[-1]+self.meta_dim, n_classes)\n", + " init_max_weights(self)\n", + "\n", + " def forward(self, **kwargs):\n", + " x = kwargs['x_omic']\n", + " meta = kwargs['meta']\n", + " features = self.fc_omic(x)\n", + "\n", + " if self.meta_dim: \n", + " axis_dim = 1 if len(meta.shape) > 1 else 0\n", + " features = torch.cat((features, meta), axis_dim)\n", + "\n", + " logits = self.classifier(features).unsqueeze(0)\n", + " Y_hat = torch.topk(logits, 1, dim=1)[1]\n", + " hazards = torch.sigmoid(logits)\n", + " S = torch.cumprod(1 - hazards, dim=1)\n", + " return hazards, S, Y_hat, None, None\n", + "\n", + " def relocate(self):\n", + " device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + " if torch.cuda.device_count() > 1:\n", + " device_ids = list(range(torch.cuda.device_count()))\n", + " self.fc_omic = nn.DataParallel(self.fc_omic, device_ids=device_ids).to('cuda:0')\n", + " else:\n", + " self.fc_omic = self.fc_omic.to(device)\n", + "\n", + "\n", + " self.classifier = self.classifier.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>CXCL14_rnaseq</th>\n", + " <th>FGF1_rnaseq</th>\n", + " <th>IFNA8_cnv</th>\n", + " <th>ADM_rnaseq</th>\n", + " <th>LTBP2_rnaseq</th>\n", + " <th>CCL28_rnaseq</th>\n", + " <th>IFNA7_rnaseq</th>\n", + " <th>GH2_rnaseq</th>\n", + " <th>AIMP1_rnaseq</th>\n", + " <th>DEFB1_rnaseq</th>\n", + " <th>...</th>\n", + " <th>NPPB_rnaseq</th>\n", + " <th>CCL27_rnaseq</th>\n", + " <th>FASLG_rnaseq</th>\n", + " <th>FGF20_cnv</th>\n", + " <th>FAM3C_rnaseq</th>\n", + " <th>IL18_rnaseq</th>\n", + " <th>GDF10_rnaseq</th>\n", + " <th>MYDGF_rnaseq</th>\n", + " <th>IL10_rnaseq</th>\n", + " <th>IFNW1_rnaseq</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>-0.1170</td>\n", + " <td>-0.2221</td>\n", + " <td>1</td>\n", + " <td>-0.5126</td>\n", + " <td>-0.3289</td>\n", + " <td>-0.7331</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.1693</td>\n", + " <td>0.5942</td>\n", + " <td>-0.4707</td>\n", + " <td>...</td>\n", + " <td>-0.2276</td>\n", + " <td>1.2033</td>\n", + " <td>0.9826</td>\n", + " <td>-1</td>\n", + " <td>-0.6161</td>\n", + " <td>-0.5643</td>\n", + " <td>-0.2165</td>\n", + " <td>-0.2836</td>\n", + " <td>0.9991</td>\n", + " <td>-0.3899</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>-0.2330</td>\n", + " <td>-0.4343</td>\n", + " <td>-1</td>\n", + " <td>-0.2381</td>\n", + " <td>-0.4799</td>\n", + " <td>-0.0520</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.1693</td>\n", + " <td>1.1854</td>\n", + " <td>-0.4820</td>\n", + " <td>...</td>\n", + " <td>-0.2276</td>\n", + " <td>-0.2946</td>\n", + " <td>-0.5443</td>\n", + " <td>-1</td>\n", + " <td>-0.3499</td>\n", + " <td>-0.7958</td>\n", + " <td>-0.3140</td>\n", + " <td>-0.3359</td>\n", + " <td>-0.4865</td>\n", + " <td>-0.3899</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>-0.1384</td>\n", + " <td>-0.1597</td>\n", + " <td>-1</td>\n", + " <td>-0.1521</td>\n", + " <td>-0.3348</td>\n", + " <td>-0.5310</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.1693</td>\n", + " <td>0.3889</td>\n", + " <td>-0.3607</td>\n", + " <td>...</td>\n", + " <td>3.4177</td>\n", + " <td>-0.2946</td>\n", + " <td>-0.5320</td>\n", + " <td>0</td>\n", + " <td>0.4581</td>\n", + " <td>-0.6179</td>\n", + " <td>-0.2107</td>\n", + " <td>0.2751</td>\n", + " <td>-0.5108</td>\n", + " <td>1.0629</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>-0.1624</td>\n", + " <td>-0.3463</td>\n", + " <td>-1</td>\n", + " <td>0.0272</td>\n", + " <td>-0.7623</td>\n", + " <td>0.8196</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.1693</td>\n", + " <td>-0.0416</td>\n", + " <td>0.1661</td>\n", + " <td>...</td>\n", + " <td>-0.2276</td>\n", + " <td>-0.1020</td>\n", + " <td>-0.4682</td>\n", + " <td>-1</td>\n", + " <td>-0.4391</td>\n", + " <td>-0.7275</td>\n", + " <td>-0.2876</td>\n", + " <td>-0.4696</td>\n", + " <td>-0.6248</td>\n", + " <td>-0.3899</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>-0.2346</td>\n", + " <td>-0.4090</td>\n", + " <td>-1</td>\n", + " <td>-0.2078</td>\n", + " <td>0.5702</td>\n", + " <td>-0.4219</td>\n", + " <td>-0.1244</td>\n", + " <td>0.5257</td>\n", + " <td>-0.9790</td>\n", + " <td>0.3938</td>\n", + " <td>...</td>\n", + " <td>-0.2276</td>\n", + " <td>-0.1035</td>\n", + " <td>-0.4688</td>\n", + " <td>-1</td>\n", + " <td>1.2596</td>\n", + " <td>-0.5807</td>\n", + " <td>0.4108</td>\n", + " <td>0.1801</td>\n", + " <td>-0.6086</td>\n", + " <td>-0.3899</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>368</th>\n", + " <td>-0.2417</td>\n", + " <td>10.1423</td>\n", + " <td>-1</td>\n", + " <td>-0.5456</td>\n", + " <td>0.8742</td>\n", + " <td>-0.1822</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.1693</td>\n", + " <td>-1.2395</td>\n", + " <td>-0.5125</td>\n", + " <td>...</td>\n", + " <td>-0.2276</td>\n", + " <td>-0.2946</td>\n", + " <td>0.0777</td>\n", + " <td>0</td>\n", + " <td>-0.8242</td>\n", + " <td>-0.6727</td>\n", + " <td>0.1938</td>\n", + " <td>0.9210</td>\n", + " <td>0.4479</td>\n", + " <td>-0.3899</td>\n", + " </tr>\n", + " <tr>\n", + " <th>369</th>\n", + " <td>-0.2412</td>\n", + " <td>1.3253</td>\n", + " <td>1</td>\n", + " <td>-0.5680</td>\n", + " <td>1.0719</td>\n", + " <td>-0.1707</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.1693</td>\n", + " <td>-1.6694</td>\n", + " <td>-0.4528</td>\n", + " <td>...</td>\n", + " <td>0.5679</td>\n", + " <td>-0.2661</td>\n", + " <td>1.0215</td>\n", + " <td>-2</td>\n", + " <td>-0.5327</td>\n", + " <td>0.3335</td>\n", + " <td>-0.1730</td>\n", + " <td>0.0147</td>\n", + " <td>0.6012</td>\n", + " <td>2.2526</td>\n", + " </tr>\n", + " <tr>\n", + " <th>370</th>\n", + " <td>-0.2396</td>\n", + " <td>0.0435</td>\n", + " <td>0</td>\n", + " <td>-0.3610</td>\n", + " <td>3.1965</td>\n", + " <td>1.3670</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.1693</td>\n", + " <td>0.4439</td>\n", + " <td>-0.5099</td>\n", + " <td>...</td>\n", + " <td>-0.2276</td>\n", + " <td>-0.2289</td>\n", + " <td>0.0521</td>\n", + " <td>-1</td>\n", + " <td>1.0317</td>\n", + " <td>-0.1473</td>\n", + " <td>-0.1517</td>\n", + " <td>0.9384</td>\n", + " <td>-0.3165</td>\n", + " <td>0.6239</td>\n", + " </tr>\n", + " <tr>\n", + " <th>371</th>\n", + " <td>-0.2393</td>\n", + " <td>-0.4475</td>\n", + " <td>0</td>\n", + " <td>0.4772</td>\n", + " <td>2.9612</td>\n", + " <td>-0.7799</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.1693</td>\n", + " <td>0.5778</td>\n", + " <td>1.7607</td>\n", + " <td>...</td>\n", + " <td>-0.2276</td>\n", + " <td>9.4098</td>\n", + " <td>-0.5443</td>\n", + " <td>0</td>\n", + " <td>0.2992</td>\n", + " <td>-0.5451</td>\n", + " <td>-0.2456</td>\n", + " <td>0.8898</td>\n", + " <td>-0.5781</td>\n", + " <td>-0.3899</td>\n", + " </tr>\n", + " <tr>\n", + " <th>372</th>\n", + " <td>-0.1936</td>\n", + " <td>-0.2281</td>\n", + " <td>0</td>\n", + " <td>-0.4124</td>\n", + " <td>-0.1873</td>\n", + " <td>-0.1200</td>\n", + " <td>-0.1244</td>\n", + " <td>-0.0326</td>\n", + " <td>-0.8786</td>\n", + " <td>-0.3912</td>\n", + " <td>...</td>\n", + " <td>-0.2276</td>\n", + " <td>-0.2570</td>\n", + " <td>-0.3810</td>\n", + " <td>-1</td>\n", + " <td>-0.6399</td>\n", + " <td>-0.9128</td>\n", + " <td>0.3367</td>\n", + " <td>-0.4686</td>\n", + " <td>0.8995</td>\n", + " <td>1.3522</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>373 rows × 347 columns</p>\n", + "</div>" + ], + "text/plain": [ + " CXCL14_rnaseq FGF1_rnaseq IFNA8_cnv ADM_rnaseq LTBP2_rnaseq \\\n", + "0 -0.1170 -0.2221 1 -0.5126 -0.3289 \n", + "1 -0.2330 -0.4343 -1 -0.2381 -0.4799 \n", + "2 -0.1384 -0.1597 -1 -0.1521 -0.3348 \n", + "3 -0.1624 -0.3463 -1 0.0272 -0.7623 \n", + "4 -0.2346 -0.4090 -1 -0.2078 0.5702 \n", + ".. ... ... ... ... ... \n", + "368 -0.2417 10.1423 -1 -0.5456 0.8742 \n", + "369 -0.2412 1.3253 1 -0.5680 1.0719 \n", + "370 -0.2396 0.0435 0 -0.3610 3.1965 \n", + "371 -0.2393 -0.4475 0 0.4772 2.9612 \n", + "372 -0.1936 -0.2281 0 -0.4124 -0.1873 \n", + "\n", + " CCL28_rnaseq IFNA7_rnaseq GH2_rnaseq AIMP1_rnaseq DEFB1_rnaseq ... \\\n", + "0 -0.7331 -0.1244 -0.1693 0.5942 -0.4707 ... \n", + "1 -0.0520 -0.1244 -0.1693 1.1854 -0.4820 ... \n", + "2 -0.5310 -0.1244 -0.1693 0.3889 -0.3607 ... \n", + "3 0.8196 -0.1244 -0.1693 -0.0416 0.1661 ... \n", + "4 -0.4219 -0.1244 0.5257 -0.9790 0.3938 ... \n", + ".. ... ... ... ... ... ... \n", + "368 -0.1822 -0.1244 -0.1693 -1.2395 -0.5125 ... \n", + "369 -0.1707 -0.1244 -0.1693 -1.6694 -0.4528 ... \n", + "370 1.3670 -0.1244 -0.1693 0.4439 -0.5099 ... \n", + "371 -0.7799 -0.1244 -0.1693 0.5778 1.7607 ... \n", + "372 -0.1200 -0.1244 -0.0326 -0.8786 -0.3912 ... \n", + "\n", + " NPPB_rnaseq CCL27_rnaseq FASLG_rnaseq FGF20_cnv FAM3C_rnaseq \\\n", + "0 -0.2276 1.2033 0.9826 -1 -0.6161 \n", + "1 -0.2276 -0.2946 -0.5443 -1 -0.3499 \n", + "2 3.4177 -0.2946 -0.5320 0 0.4581 \n", + "3 -0.2276 -0.1020 -0.4682 -1 -0.4391 \n", + "4 -0.2276 -0.1035 -0.4688 -1 1.2596 \n", + ".. ... ... ... ... ... \n", + "368 -0.2276 -0.2946 0.0777 0 -0.8242 \n", + "369 0.5679 -0.2661 1.0215 -2 -0.5327 \n", + "370 -0.2276 -0.2289 0.0521 -1 1.0317 \n", + "371 -0.2276 9.4098 -0.5443 0 0.2992 \n", + "372 -0.2276 -0.2570 -0.3810 -1 -0.6399 \n", + "\n", + " IL18_rnaseq GDF10_rnaseq MYDGF_rnaseq IL10_rnaseq IFNW1_rnaseq \n", + "0 -0.5643 -0.2165 -0.2836 0.9991 -0.3899 \n", + "1 -0.7958 -0.3140 -0.3359 -0.4865 -0.3899 \n", + "2 -0.6179 -0.2107 0.2751 -0.5108 1.0629 \n", + "3 -0.7275 -0.2876 -0.4696 -0.6248 -0.3899 \n", + "4 -0.5807 0.4108 0.1801 -0.6086 -0.3899 \n", + ".. ... ... ... ... ... \n", + "368 -0.6727 0.1938 0.9210 0.4479 -0.3899 \n", + "369 0.3335 -0.1730 0.0147 0.6012 2.2526 \n", + "370 -0.1473 -0.1517 0.9384 -0.3165 0.6239 \n", + "371 -0.5451 -0.2456 0.8898 -0.5781 -0.3899 \n", + "372 -0.9128 0.3367 -0.4686 0.8995 1.3522 \n", + "\n", + "[373 rows x 347 columns]" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "genomic_features[series_intersecdef series_intersection(s1, s2):\n", + " return pd.Series(list(set(s1) & set(s2)))\n", + "tion(sig, genomic_features.columns)]" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "def series_intersection(s1, s2):\n", + " return pd.Series(list(set(s1) & set(s2)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>NDUFS5_cnv</th>\n", + " <th>MACF1_cnv</th>\n", + " <th>RNA5SP44_cnv</th>\n", + " <th>KIAA0754_cnv</th>\n", + " <th>BMP8A_cnv</th>\n", + " <th>PABPC4_cnv</th>\n", + " <th>SNORA55_cnv</th>\n", + " <th>HEYL_cnv</th>\n", + " <th>HPCAL4_cnv</th>\n", + " <th>NT5C1A_cnv</th>\n", + " <th>...</th>\n", + " <th>ZWINT_rnaseq</th>\n", + " <th>ZXDA_rnaseq</th>\n", + " <th>ZXDB_rnaseq</th>\n", + " <th>ZXDC_rnaseq</th>\n", + " <th>ZYG11A_rnaseq</th>\n", + " <th>ZYG11B_rnaseq</th>\n", + " <th>ZYX_rnaseq</th>\n", + " <th>ZZEF1_rnaseq</th>\n", + " <th>ZZZ3_rnaseq</th>\n", + " <th>TPTEP1_rnaseq</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>-1</td>\n", + " <td>...</td>\n", + " <td>-0.8388</td>\n", + " <td>4.1375</td>\n", + " <td>3.9664</td>\n", + " <td>1.8437</td>\n", + " <td>-0.3959</td>\n", + " <td>-0.2561</td>\n", + " <td>-0.2866</td>\n", + " <td>1.8770</td>\n", + " <td>-0.3179</td>\n", + " <td>-0.3633</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>-0.1083</td>\n", + " <td>0.3393</td>\n", + " <td>0.2769</td>\n", + " <td>1.7320</td>\n", + " <td>-0.0975</td>\n", + " <td>2.6955</td>\n", + " <td>-0.6741</td>\n", + " <td>1.0323</td>\n", + " <td>1.2766</td>\n", + " <td>-0.3982</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>-0.4155</td>\n", + " <td>1.6846</td>\n", + " <td>0.7711</td>\n", + " <td>-0.3061</td>\n", + " <td>-0.5016</td>\n", + " <td>2.8548</td>\n", + " <td>-0.6171</td>\n", + " <td>-0.8608</td>\n", + " <td>-0.0486</td>\n", + " <td>-0.3962</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>-0.8143</td>\n", + " <td>0.8344</td>\n", + " <td>1.5075</td>\n", + " <td>3.6068</td>\n", + " <td>-0.5004</td>\n", + " <td>-0.0747</td>\n", + " <td>-0.2185</td>\n", + " <td>-0.4379</td>\n", + " <td>1.6913</td>\n", + " <td>1.7748</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>0.0983</td>\n", + " <td>-0.7908</td>\n", + " <td>-0.0053</td>\n", + " <td>-0.0643</td>\n", + " <td>-0.3706</td>\n", + " <td>0.3870</td>\n", + " <td>-0.5589</td>\n", + " <td>-0.5979</td>\n", + " <td>0.0047</td>\n", + " <td>-0.3548</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>368</th>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>...</td>\n", + " <td>-0.0291</td>\n", + " <td>-0.1058</td>\n", + " <td>-0.6721</td>\n", + " <td>0.2802</td>\n", + " <td>1.9504</td>\n", + " <td>-0.8784</td>\n", + " <td>0.9506</td>\n", + " <td>0.0607</td>\n", + " <td>1.1883</td>\n", + " <td>-0.3521</td>\n", + " </tr>\n", + " <tr>\n", + " <th>369</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>0.0497</td>\n", + " <td>0.3673</td>\n", + " <td>-0.2208</td>\n", + " <td>0.3034</td>\n", + " <td>3.2580</td>\n", + " <td>-0.2089</td>\n", + " <td>1.6053</td>\n", + " <td>-0.8746</td>\n", + " <td>-0.4491</td>\n", + " <td>-0.3450</td>\n", + " </tr>\n", + " <tr>\n", + " <th>370</th>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>...</td>\n", + " <td>0.3822</td>\n", + " <td>-0.7003</td>\n", + " <td>-0.7661</td>\n", + " <td>-1.7035</td>\n", + " <td>-0.5423</td>\n", + " <td>-0.3488</td>\n", + " <td>1.3713</td>\n", + " <td>-0.4365</td>\n", + " <td>2.3456</td>\n", + " <td>-0.3866</td>\n", + " </tr>\n", + " <tr>\n", + " <th>371</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>-0.6853</td>\n", + " <td>-1.0240</td>\n", + " <td>-1.2890</td>\n", + " <td>-1.5666</td>\n", + " <td>-0.1270</td>\n", + " <td>-1.4662</td>\n", + " <td>0.3981</td>\n", + " <td>-0.5976</td>\n", + " <td>-1.3822</td>\n", + " <td>-0.4157</td>\n", + " </tr>\n", + " <tr>\n", + " <th>372</th>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>0.0517</td>\n", + " <td>-0.3570</td>\n", + " <td>-0.4843</td>\n", + " <td>-0.3792</td>\n", + " <td>-0.1964</td>\n", + " <td>0.4200</td>\n", + " <td>3.2547</td>\n", + " <td>-0.1232</td>\n", + " <td>3.4519</td>\n", + " <td>-0.1962</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>373 rows × 20395 columns</p>\n", + "</div>" + ], + "text/plain": [ + " NDUFS5_cnv MACF1_cnv RNA5SP44_cnv KIAA0754_cnv BMP8A_cnv PABPC4_cnv \\\n", + "0 -1 -1 -1 -1 -1 -1 \n", + "1 2 2 2 2 2 2 \n", + "2 0 0 0 0 0 0 \n", + "3 0 0 0 0 0 0 \n", + "4 0 0 0 0 0 0 \n", + ".. ... ... ... ... ... ... \n", + "368 2 2 2 2 2 2 \n", + "369 0 0 0 0 0 0 \n", + "370 1 1 1 1 1 1 \n", + "371 0 0 0 0 0 0 \n", + "372 0 0 0 0 0 0 \n", + "\n", + " SNORA55_cnv HEYL_cnv HPCAL4_cnv NT5C1A_cnv ... ZWINT_rnaseq \\\n", + "0 -1 -1 -1 -1 ... -0.8388 \n", + "1 2 2 2 2 ... -0.1083 \n", + "2 0 0 0 0 ... -0.4155 \n", + "3 0 0 0 0 ... -0.8143 \n", + "4 0 0 0 0 ... 0.0983 \n", + ".. ... ... ... ... ... ... \n", + "368 2 2 2 2 ... -0.0291 \n", + "369 0 0 0 0 ... 0.0497 \n", + "370 1 1 1 1 ... 0.3822 \n", + "371 0 0 0 0 ... -0.6853 \n", + "372 0 0 0 0 ... 0.0517 \n", + "\n", + " ZXDA_rnaseq ZXDB_rnaseq ZXDC_rnaseq ZYG11A_rnaseq ZYG11B_rnaseq \\\n", + "0 4.1375 3.9664 1.8437 -0.3959 -0.2561 \n", + "1 0.3393 0.2769 1.7320 -0.0975 2.6955 \n", + "2 1.6846 0.7711 -0.3061 -0.5016 2.8548 \n", + "3 0.8344 1.5075 3.6068 -0.5004 -0.0747 \n", + "4 -0.7908 -0.0053 -0.0643 -0.3706 0.3870 \n", + ".. ... ... ... ... ... \n", + "368 -0.1058 -0.6721 0.2802 1.9504 -0.8784 \n", + "369 0.3673 -0.2208 0.3034 3.2580 -0.2089 \n", + "370 -0.7003 -0.7661 -1.7035 -0.5423 -0.3488 \n", + "371 -1.0240 -1.2890 -1.5666 -0.1270 -1.4662 \n", + "372 -0.3570 -0.4843 -0.3792 -0.1964 0.4200 \n", + "\n", + " ZYX_rnaseq ZZEF1_rnaseq ZZZ3_rnaseq TPTEP1_rnaseq \n", + "0 -0.2866 1.8770 -0.3179 -0.3633 \n", + "1 -0.6741 1.0323 1.2766 -0.3982 \n", + "2 -0.6171 -0.8608 -0.0486 -0.3962 \n", + "3 -0.2185 -0.4379 1.6913 1.7748 \n", + "4 -0.5589 -0.5979 0.0047 -0.3548 \n", + ".. ... ... ... ... \n", + "368 0.9506 0.0607 1.1883 -0.3521 \n", + "369 1.6053 -0.8746 -0.4491 -0.3450 \n", + "370 1.3713 -0.4365 2.3456 -0.3866 \n", + "371 0.3981 -0.5976 -1.3822 -0.4157 \n", + "372 3.2547 -0.1232 3.4519 -0.1962 \n", + "\n", + "[373 rows x 20395 columns]" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "genomic_features" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "if 'case_id' not in slide_data:\n", + " slide_data.index = slide_data.index.str[:12]\n", + " slide_data['case_id'] = slide_data.index\n", + " slide_data = slide_data.reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "new_cols = list(slide_data.columns[-2:]) + list(slide_data.columns[:-2])\n", + "slide_data = slide_data[new_cols]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>ZZZ3_rnaseq</th>\n", + " <th>TPTEP1_rnaseq</th>\n", + " <th>slide_id</th>\n", + " <th>site</th>\n", + " <th>is_female</th>\n", + " <th>oncotree_code</th>\n", + " <th>age</th>\n", + " <th>survival_months</th>\n", + " <th>censorship</th>\n", + " <th>train</th>\n", + " <th>...</th>\n", + " <th>ZW10_rnaseq</th>\n", + " <th>ZWILCH_rnaseq</th>\n", + " <th>ZWINT_rnaseq</th>\n", + " <th>ZXDA_rnaseq</th>\n", + " <th>ZXDB_rnaseq</th>\n", + " <th>ZXDC_rnaseq</th>\n", + " <th>ZYG11A_rnaseq</th>\n", + " <th>ZYG11B_rnaseq</th>\n", + " <th>ZYX_rnaseq</th>\n", + " <th>ZZEF1_rnaseq</th>\n", + " </tr>\n", + " <tr>\n", + " <th>case_id</th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>TCGA-2F-A9KO</th>\n", + " <td>-0.3179</td>\n", + " <td>-0.3633</td>\n", + " <td>TCGA-2F-A9KO-01Z-00-DX1.195576CF-B739-4BD9-B15...</td>\n", + " <td>2F</td>\n", + " <td>0</td>\n", + " <td>BLCA</td>\n", + " <td>63</td>\n", + " <td>24.11</td>\n", + " <td>0</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>-0.7172</td>\n", + " <td>0.7409</td>\n", + " <td>-0.8388</td>\n", + " <td>4.1375</td>\n", + " <td>3.9664</td>\n", + " <td>1.8437</td>\n", + " <td>-0.3959</td>\n", + " <td>-0.2561</td>\n", + " <td>-0.2866</td>\n", + " <td>1.8770</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-2F-A9KP</th>\n", + " <td>1.2766</td>\n", + " <td>-0.3982</td>\n", + " <td>TCGA-2F-A9KP-01Z-00-DX1.3CDF534E-958F-4467-AA7...</td>\n", + " <td>2F</td>\n", + " <td>0</td>\n", + " <td>BLCA</td>\n", + " <td>66</td>\n", + " <td>11.96</td>\n", + " <td>0</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>0.6373</td>\n", + " <td>0.8559</td>\n", + " <td>-0.1083</td>\n", + " <td>0.3393</td>\n", + " <td>0.2769</td>\n", + " <td>1.7320</td>\n", + " <td>-0.0975</td>\n", + " <td>2.6955</td>\n", + " <td>-0.6741</td>\n", + " <td>1.0323</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-2F-A9KP</th>\n", + " <td>1.2766</td>\n", + " <td>-0.3982</td>\n", + " <td>TCGA-2F-A9KP-01Z-00-DX2.718C82A3-252B-498E-BFB...</td>\n", + " <td>2F</td>\n", + " <td>0</td>\n", + " <td>BLCA</td>\n", + " <td>66</td>\n", + " <td>11.96</td>\n", + " <td>0</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>0.6373</td>\n", + " <td>0.8559</td>\n", + " <td>-0.1083</td>\n", + " <td>0.3393</td>\n", + " <td>0.2769</td>\n", + " <td>1.7320</td>\n", + " <td>-0.0975</td>\n", + " <td>2.6955</td>\n", + " <td>-0.6741</td>\n", + " <td>1.0323</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-2F-A9KQ</th>\n", + " <td>-0.0486</td>\n", + " <td>-0.3962</td>\n", + " <td>TCGA-2F-A9KQ-01Z-00-DX1.1C8CB2DD-5CC6-4E99-A0F...</td>\n", + " <td>2F</td>\n", + " <td>0</td>\n", + " <td>BLCA</td>\n", + " <td>69</td>\n", + " <td>94.81</td>\n", + " <td>1</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>-0.5676</td>\n", + " <td>-0.0621</td>\n", + " <td>-0.4155</td>\n", + " <td>1.6846</td>\n", + " <td>0.7711</td>\n", + " <td>-0.3061</td>\n", + " <td>-0.5016</td>\n", + " <td>2.8548</td>\n", + " <td>-0.6171</td>\n", + " <td>-0.8608</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-2F-A9KR</th>\n", + " <td>1.6913</td>\n", + " <td>1.7748</td>\n", + " <td>TCGA-2F-A9KR-01Z-00-DX1.D6A4BD2D-18F3-4FA6-827...</td>\n", + " <td>2F</td>\n", + " <td>1</td>\n", + " <td>BLCA</td>\n", + " <td>59</td>\n", + " <td>104.57</td>\n", + " <td>0</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>-1.3825</td>\n", + " <td>0.3550</td>\n", + " <td>-0.8143</td>\n", + " <td>0.8344</td>\n", + " <td>1.5075</td>\n", + " <td>3.6068</td>\n", + " <td>-0.5004</td>\n", + " <td>-0.0747</td>\n", + " <td>-0.2185</td>\n", + " <td>-0.4379</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-ZF-AA54</th>\n", + " <td>1.1883</td>\n", + " <td>-0.3521</td>\n", + " <td>TCGA-ZF-AA54-01Z-00-DX1.9118BB51-333A-4257-A79...</td>\n", + " <td>ZF</td>\n", + " <td>0</td>\n", + " <td>BLCA</td>\n", + " <td>71</td>\n", + " <td>19.38</td>\n", + " <td>0</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>-0.0898</td>\n", + " <td>2.1092</td>\n", + " <td>-0.0291</td>\n", + " <td>-0.1058</td>\n", + " <td>-0.6721</td>\n", + " <td>0.2802</td>\n", + " <td>1.9504</td>\n", + " <td>-0.8784</td>\n", + " <td>0.9506</td>\n", + " <td>0.0607</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-ZF-AA58</th>\n", + " <td>-0.4491</td>\n", + " <td>-0.3450</td>\n", + " <td>TCGA-ZF-AA58-01Z-00-DX1.85C3611E-11FA-4AAE-B88...</td>\n", + " <td>ZF</td>\n", + " <td>1</td>\n", + " <td>BLCA</td>\n", + " <td>61</td>\n", + " <td>54.17</td>\n", + " <td>1</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>-0.2075</td>\n", + " <td>-0.0617</td>\n", + " <td>0.0497</td>\n", + " <td>0.3673</td>\n", + " <td>-0.2208</td>\n", + " <td>0.3034</td>\n", + " <td>3.2580</td>\n", + " <td>-0.2089</td>\n", + " <td>1.6053</td>\n", + " <td>-0.8746</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-ZF-AA5H</th>\n", + " <td>2.3456</td>\n", + " <td>-0.3866</td>\n", + " <td>TCGA-ZF-AA5H-01Z-00-DX1.2B5DF00E-E0FD-4C58-A82...</td>\n", + " <td>ZF</td>\n", + " <td>1</td>\n", + " <td>BLCA</td>\n", + " <td>60</td>\n", + " <td>29.47</td>\n", + " <td>1</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>-1.4118</td>\n", + " <td>-0.1236</td>\n", + " <td>0.3822</td>\n", + " <td>-0.7003</td>\n", + " <td>-0.7661</td>\n", + " <td>-1.7035</td>\n", + " <td>-0.5423</td>\n", + " <td>-0.3488</td>\n", + " <td>1.3713</td>\n", + " <td>-0.4365</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-ZF-AA5N</th>\n", + " <td>-1.3822</td>\n", + " <td>-0.4157</td>\n", + " <td>TCGA-ZF-AA5N-01Z-00-DX1.A207E3EE-CC7D-4267-A77...</td>\n", + " <td>ZF</td>\n", + " <td>1</td>\n", + " <td>BLCA</td>\n", + " <td>62</td>\n", + " <td>5.52</td>\n", + " <td>0</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>-0.1733</td>\n", + " <td>-0.2397</td>\n", + " <td>-0.6853</td>\n", + " <td>-1.0240</td>\n", + " <td>-1.2890</td>\n", + " <td>-1.5666</td>\n", + " <td>-0.1270</td>\n", + " <td>-1.4662</td>\n", + " <td>0.3981</td>\n", + " <td>-0.5976</td>\n", + " </tr>\n", + " <tr>\n", + " <th>TCGA-ZF-AA5P</th>\n", + " <td>3.4519</td>\n", + " <td>-0.1962</td>\n", + " <td>TCGA-ZF-AA5P-01Z-00-DX1.B91697A2-A186-4E67-A81...</td>\n", + " <td>ZF</td>\n", + " <td>0</td>\n", + " <td>BLCA</td>\n", + " <td>65</td>\n", + " <td>12.22</td>\n", + " <td>1</td>\n", + " <td>1.0</td>\n", + " <td>...</td>\n", + " <td>-1.1056</td>\n", + " <td>-0.6634</td>\n", + " <td>0.0517</td>\n", + " <td>-0.3570</td>\n", + " <td>-0.4843</td>\n", + " <td>-0.3792</td>\n", + " <td>-0.1964</td>\n", + " <td>0.4200</td>\n", + " <td>3.2547</td>\n", + " <td>-0.1232</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>437 rows × 20403 columns</p>\n", + "</div>" + ], + "text/plain": [ + " ZZZ3_rnaseq TPTEP1_rnaseq \\\n", + "case_id \n", + "TCGA-2F-A9KO -0.3179 -0.3633 \n", + "TCGA-2F-A9KP 1.2766 -0.3982 \n", + "TCGA-2F-A9KP 1.2766 -0.3982 \n", + "TCGA-2F-A9KQ -0.0486 -0.3962 \n", + "TCGA-2F-A9KR 1.6913 1.7748 \n", + "... ... ... \n", + "TCGA-ZF-AA54 1.1883 -0.3521 \n", + "TCGA-ZF-AA58 -0.4491 -0.3450 \n", + "TCGA-ZF-AA5H 2.3456 -0.3866 \n", + "TCGA-ZF-AA5N -1.3822 -0.4157 \n", + "TCGA-ZF-AA5P 3.4519 -0.1962 \n", + "\n", + " slide_id site \\\n", + "case_id \n", + "TCGA-2F-A9KO TCGA-2F-A9KO-01Z-00-DX1.195576CF-B739-4BD9-B15... 2F \n", + "TCGA-2F-A9KP TCGA-2F-A9KP-01Z-00-DX1.3CDF534E-958F-4467-AA7... 2F \n", + "TCGA-2F-A9KP TCGA-2F-A9KP-01Z-00-DX2.718C82A3-252B-498E-BFB... 2F \n", + "TCGA-2F-A9KQ TCGA-2F-A9KQ-01Z-00-DX1.1C8CB2DD-5CC6-4E99-A0F... 2F \n", + "TCGA-2F-A9KR TCGA-2F-A9KR-01Z-00-DX1.D6A4BD2D-18F3-4FA6-827... 2F \n", + "... ... ... \n", + "TCGA-ZF-AA54 TCGA-ZF-AA54-01Z-00-DX1.9118BB51-333A-4257-A79... ZF \n", + "TCGA-ZF-AA58 TCGA-ZF-AA58-01Z-00-DX1.85C3611E-11FA-4AAE-B88... ZF \n", + "TCGA-ZF-AA5H TCGA-ZF-AA5H-01Z-00-DX1.2B5DF00E-E0FD-4C58-A82... ZF \n", + "TCGA-ZF-AA5N TCGA-ZF-AA5N-01Z-00-DX1.A207E3EE-CC7D-4267-A77... ZF \n", + "TCGA-ZF-AA5P TCGA-ZF-AA5P-01Z-00-DX1.B91697A2-A186-4E67-A81... ZF \n", + "\n", + " is_female oncotree_code age survival_months censorship \\\n", + "case_id \n", + "TCGA-2F-A9KO 0 BLCA 63 24.11 0 \n", + "TCGA-2F-A9KP 0 BLCA 66 11.96 0 \n", + "TCGA-2F-A9KP 0 BLCA 66 11.96 0 \n", + "TCGA-2F-A9KQ 0 BLCA 69 94.81 1 \n", + "TCGA-2F-A9KR 1 BLCA 59 104.57 0 \n", + "... ... ... ... ... ... \n", + "TCGA-ZF-AA54 0 BLCA 71 19.38 0 \n", + "TCGA-ZF-AA58 1 BLCA 61 54.17 1 \n", + "TCGA-ZF-AA5H 1 BLCA 60 29.47 1 \n", + "TCGA-ZF-AA5N 1 BLCA 62 5.52 0 \n", + "TCGA-ZF-AA5P 0 BLCA 65 12.22 1 \n", + "\n", + " train ... ZW10_rnaseq ZWILCH_rnaseq ZWINT_rnaseq \\\n", + "case_id ... \n", + "TCGA-2F-A9KO 1.0 ... -0.7172 0.7409 -0.8388 \n", + "TCGA-2F-A9KP 1.0 ... 0.6373 0.8559 -0.1083 \n", + "TCGA-2F-A9KP 1.0 ... 0.6373 0.8559 -0.1083 \n", + "TCGA-2F-A9KQ 1.0 ... -0.5676 -0.0621 -0.4155 \n", + "TCGA-2F-A9KR 1.0 ... -1.3825 0.3550 -0.8143 \n", + "... ... ... ... ... ... \n", + "TCGA-ZF-AA54 1.0 ... -0.0898 2.1092 -0.0291 \n", + "TCGA-ZF-AA58 1.0 ... -0.2075 -0.0617 0.0497 \n", + "TCGA-ZF-AA5H 1.0 ... -1.4118 -0.1236 0.3822 \n", + "TCGA-ZF-AA5N 1.0 ... -0.1733 -0.2397 -0.6853 \n", + "TCGA-ZF-AA5P 1.0 ... -1.1056 -0.6634 0.0517 \n", + "\n", + " ZXDA_rnaseq ZXDB_rnaseq ZXDC_rnaseq ZYG11A_rnaseq \\\n", + "case_id \n", + "TCGA-2F-A9KO 4.1375 3.9664 1.8437 -0.3959 \n", + "TCGA-2F-A9KP 0.3393 0.2769 1.7320 -0.0975 \n", + "TCGA-2F-A9KP 0.3393 0.2769 1.7320 -0.0975 \n", + "TCGA-2F-A9KQ 1.6846 0.7711 -0.3061 -0.5016 \n", + "TCGA-2F-A9KR 0.8344 1.5075 3.6068 -0.5004 \n", + "... ... ... ... ... \n", + "TCGA-ZF-AA54 -0.1058 -0.6721 0.2802 1.9504 \n", + "TCGA-ZF-AA58 0.3673 -0.2208 0.3034 3.2580 \n", + "TCGA-ZF-AA5H -0.7003 -0.7661 -1.7035 -0.5423 \n", + "TCGA-ZF-AA5N -1.0240 -1.2890 -1.5666 -0.1270 \n", + "TCGA-ZF-AA5P -0.3570 -0.4843 -0.3792 -0.1964 \n", + "\n", + " ZYG11B_rnaseq ZYX_rnaseq ZZEF1_rnaseq \n", + "case_id \n", + "TCGA-2F-A9KO -0.2561 -0.2866 1.8770 \n", + "TCGA-2F-A9KP 2.6955 -0.6741 1.0323 \n", + "TCGA-2F-A9KP 2.6955 -0.6741 1.0323 \n", + "TCGA-2F-A9KQ 2.8548 -0.6171 -0.8608 \n", + "TCGA-2F-A9KR -0.0747 -0.2185 -0.4379 \n", + "... ... ... ... \n", + "TCGA-ZF-AA54 -0.8784 0.9506 0.0607 \n", + "TCGA-ZF-AA58 -0.2089 1.6053 -0.8746 \n", + "TCGA-ZF-AA5H -0.3488 1.3713 -0.4365 \n", + "TCGA-ZF-AA5N -1.4662 0.3981 -0.5976 \n", + "TCGA-ZF-AA5P 0.4200 3.2547 -0.1232 \n", + "\n", + "[437 rows x 20403 columns]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "slide_data" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}