[a2d0eb]: / notebooks / exploration.ipynb

Download this file

948 lines (947 with data), 115.3 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "23f2705f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Path to dataset files: /home/kesslermatics/.cache/kagglehub/datasets/raghadalharbi/breast-cancer-gene-expression-profiles-metabric/versions/1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_24481/2461765115.py:9: DtypeWarning: Columns (678,688,690,692) have mixed types. Specify dtype option on import or set low_memory=False.\n",
      "  df = pd.read_csv(os.path.join(path, \"METABRIC_RNA_Mutation.csv\"))\n"
     ]
    }
   ],
   "source": [
    "import kagglehub\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "# Download latest version\n",
    "path = kagglehub.dataset_download(\"raghadalharbi/breast-cancer-gene-expression-profiles-metabric\")\n",
    "\n",
    "print(\"Path to dataset files:\", path)\n",
    "df = pd.read_csv(os.path.join(path, \"METABRIC_RNA_Mutation.csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "05cff3e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>patient_id</th>\n",
       "      <th>age_at_diagnosis</th>\n",
       "      <th>chemotherapy</th>\n",
       "      <th>cohort</th>\n",
       "      <th>neoplasm_histologic_grade</th>\n",
       "      <th>hormone_therapy</th>\n",
       "      <th>lymph_nodes_examined_positive</th>\n",
       "      <th>mutation_count</th>\n",
       "      <th>nottingham_prognostic_index</th>\n",
       "      <th>overall_survival_months</th>\n",
       "      <th>...</th>\n",
       "      <th>srd5a1</th>\n",
       "      <th>srd5a2</th>\n",
       "      <th>srd5a3</th>\n",
       "      <th>st7</th>\n",
       "      <th>star</th>\n",
       "      <th>tnk2</th>\n",
       "      <th>tulp4</th>\n",
       "      <th>ugt2b15</th>\n",
       "      <th>ugt2b17</th>\n",
       "      <th>ugt2b7</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1832.000000</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1859.000000</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>1.904000e+03</td>\n",
       "      <td>1.904000e+03</td>\n",
       "      <td>1.904000e+03</td>\n",
       "      <td>1.904000e+03</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1.904000e+03</td>\n",
       "      <td>1.904000e+03</td>\n",
       "      <td>1.904000e+03</td>\n",
       "      <td>1904.000000</td>\n",
       "      <td>1.904000e+03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>3921.982143</td>\n",
       "      <td>61.087054</td>\n",
       "      <td>0.207983</td>\n",
       "      <td>2.643908</td>\n",
       "      <td>2.415939</td>\n",
       "      <td>0.616597</td>\n",
       "      <td>2.002101</td>\n",
       "      <td>5.697687</td>\n",
       "      <td>4.033019</td>\n",
       "      <td>125.121324</td>\n",
       "      <td>...</td>\n",
       "      <td>4.726891e-07</td>\n",
       "      <td>-3.676471e-07</td>\n",
       "      <td>-9.453782e-07</td>\n",
       "      <td>-1.050420e-07</td>\n",
       "      <td>-0.000002</td>\n",
       "      <td>3.676471e-07</td>\n",
       "      <td>4.726891e-07</td>\n",
       "      <td>7.878151e-07</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>3.731842e-18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>2358.478332</td>\n",
       "      <td>12.978711</td>\n",
       "      <td>0.405971</td>\n",
       "      <td>1.228615</td>\n",
       "      <td>0.650612</td>\n",
       "      <td>0.486343</td>\n",
       "      <td>4.079993</td>\n",
       "      <td>4.058778</td>\n",
       "      <td>1.144492</td>\n",
       "      <td>76.334148</td>\n",
       "      <td>...</td>\n",
       "      <td>1.000263e+00</td>\n",
       "      <td>1.000262e+00</td>\n",
       "      <td>1.000262e+00</td>\n",
       "      <td>1.000263e+00</td>\n",
       "      <td>1.000262</td>\n",
       "      <td>1.000264e+00</td>\n",
       "      <td>1.000262e+00</td>\n",
       "      <td>1.000263e+00</td>\n",
       "      <td>1.000262</td>\n",
       "      <td>1.000262e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>21.930000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>-2.120800e+00</td>\n",
       "      <td>-3.364800e+00</td>\n",
       "      <td>-2.719400e+00</td>\n",
       "      <td>-4.982700e+00</td>\n",
       "      <td>-2.981700</td>\n",
       "      <td>-3.833300e+00</td>\n",
       "      <td>-3.609300e+00</td>\n",
       "      <td>-1.166900e+00</td>\n",
       "      <td>-2.112600</td>\n",
       "      <td>-1.051600e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>896.500000</td>\n",
       "      <td>51.375000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.046000</td>\n",
       "      <td>60.825000</td>\n",
       "      <td>...</td>\n",
       "      <td>-6.188500e-01</td>\n",
       "      <td>-6.104750e-01</td>\n",
       "      <td>-6.741750e-01</td>\n",
       "      <td>-6.136750e-01</td>\n",
       "      <td>-0.632900</td>\n",
       "      <td>-6.664750e-01</td>\n",
       "      <td>-7.102000e-01</td>\n",
       "      <td>-5.058250e-01</td>\n",
       "      <td>-0.476200</td>\n",
       "      <td>-7.260000e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>4730.500000</td>\n",
       "      <td>61.770000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>4.042000</td>\n",
       "      <td>115.616667</td>\n",
       "      <td>...</td>\n",
       "      <td>-2.456500e-01</td>\n",
       "      <td>-4.690000e-02</td>\n",
       "      <td>-1.422500e-01</td>\n",
       "      <td>-5.175000e-02</td>\n",
       "      <td>-0.026650</td>\n",
       "      <td>7.000000e-04</td>\n",
       "      <td>-2.980000e-02</td>\n",
       "      <td>-2.885500e-01</td>\n",
       "      <td>-0.133400</td>\n",
       "      <td>-4.248000e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>5536.250000</td>\n",
       "      <td>70.592500</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>7.000000</td>\n",
       "      <td>5.040250</td>\n",
       "      <td>184.716667</td>\n",
       "      <td>...</td>\n",
       "      <td>3.306000e-01</td>\n",
       "      <td>5.144500e-01</td>\n",
       "      <td>5.146000e-01</td>\n",
       "      <td>5.787750e-01</td>\n",
       "      <td>0.590350</td>\n",
       "      <td>6.429000e-01</td>\n",
       "      <td>5.957250e-01</td>\n",
       "      <td>6.022500e-02</td>\n",
       "      <td>0.270375</td>\n",
       "      <td>4.284000e-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>7299.000000</td>\n",
       "      <td>96.290000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>45.000000</td>\n",
       "      <td>80.000000</td>\n",
       "      <td>6.360000</td>\n",
       "      <td>355.200000</td>\n",
       "      <td>...</td>\n",
       "      <td>6.534900e+00</td>\n",
       "      <td>1.027030e+01</td>\n",
       "      <td>6.329000e+00</td>\n",
       "      <td>4.571300e+00</td>\n",
       "      <td>12.742300</td>\n",
       "      <td>3.938800e+00</td>\n",
       "      <td>3.833400e+00</td>\n",
       "      <td>1.088490e+01</td>\n",
       "      <td>12.643900</td>\n",
       "      <td>3.284400e+00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>8 rows × 503 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        patient_id  age_at_diagnosis  chemotherapy       cohort  \\\n",
       "count  1904.000000       1904.000000   1904.000000  1904.000000   \n",
       "mean   3921.982143         61.087054      0.207983     2.643908   \n",
       "std    2358.478332         12.978711      0.405971     1.228615   \n",
       "min       0.000000         21.930000      0.000000     1.000000   \n",
       "25%     896.500000         51.375000      0.000000     1.000000   \n",
       "50%    4730.500000         61.770000      0.000000     3.000000   \n",
       "75%    5536.250000         70.592500      0.000000     3.000000   \n",
       "max    7299.000000         96.290000      1.000000     5.000000   \n",
       "\n",
       "       neoplasm_histologic_grade  hormone_therapy  \\\n",
       "count                1832.000000      1904.000000   \n",
       "mean                    2.415939         0.616597   \n",
       "std                     0.650612         0.486343   \n",
       "min                     1.000000         0.000000   \n",
       "25%                     2.000000         0.000000   \n",
       "50%                     3.000000         1.000000   \n",
       "75%                     3.000000         1.000000   \n",
       "max                     3.000000         1.000000   \n",
       "\n",
       "       lymph_nodes_examined_positive  mutation_count  \\\n",
       "count                    1904.000000     1859.000000   \n",
       "mean                        2.002101        5.697687   \n",
       "std                         4.079993        4.058778   \n",
       "min                         0.000000        1.000000   \n",
       "25%                         0.000000        3.000000   \n",
       "50%                         0.000000        5.000000   \n",
       "75%                         2.000000        7.000000   \n",
       "max                        45.000000       80.000000   \n",
       "\n",
       "       nottingham_prognostic_index  overall_survival_months  ...  \\\n",
       "count                  1904.000000              1904.000000  ...   \n",
       "mean                      4.033019               125.121324  ...   \n",
       "std                       1.144492                76.334148  ...   \n",
       "min                       1.000000                 0.000000  ...   \n",
       "25%                       3.046000                60.825000  ...   \n",
       "50%                       4.042000               115.616667  ...   \n",
       "75%                       5.040250               184.716667  ...   \n",
       "max                       6.360000               355.200000  ...   \n",
       "\n",
       "             srd5a1        srd5a2        srd5a3           st7         star  \\\n",
       "count  1.904000e+03  1.904000e+03  1.904000e+03  1.904000e+03  1904.000000   \n",
       "mean   4.726891e-07 -3.676471e-07 -9.453782e-07 -1.050420e-07    -0.000002   \n",
       "std    1.000263e+00  1.000262e+00  1.000262e+00  1.000263e+00     1.000262   \n",
       "min   -2.120800e+00 -3.364800e+00 -2.719400e+00 -4.982700e+00    -2.981700   \n",
       "25%   -6.188500e-01 -6.104750e-01 -6.741750e-01 -6.136750e-01    -0.632900   \n",
       "50%   -2.456500e-01 -4.690000e-02 -1.422500e-01 -5.175000e-02    -0.026650   \n",
       "75%    3.306000e-01  5.144500e-01  5.146000e-01  5.787750e-01     0.590350   \n",
       "max    6.534900e+00  1.027030e+01  6.329000e+00  4.571300e+00    12.742300   \n",
       "\n",
       "               tnk2         tulp4       ugt2b15      ugt2b17        ugt2b7  \n",
       "count  1.904000e+03  1.904000e+03  1.904000e+03  1904.000000  1.904000e+03  \n",
       "mean   3.676471e-07  4.726891e-07  7.878151e-07     0.000000  3.731842e-18  \n",
       "std    1.000264e+00  1.000262e+00  1.000263e+00     1.000262  1.000262e+00  \n",
       "min   -3.833300e+00 -3.609300e+00 -1.166900e+00    -2.112600 -1.051600e+00  \n",
       "25%   -6.664750e-01 -7.102000e-01 -5.058250e-01    -0.476200 -7.260000e-01  \n",
       "50%    7.000000e-04 -2.980000e-02 -2.885500e-01    -0.133400 -4.248000e-01  \n",
       "75%    6.429000e-01  5.957250e-01  6.022500e-02     0.270375  4.284000e-01  \n",
       "max    3.938800e+00  3.833400e+00  1.088490e+01    12.643900  3.284400e+00  \n",
       "\n",
       "[8 rows x 503 columns]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3630c8da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>patient_id</th>\n",
       "      <th>age_at_diagnosis</th>\n",
       "      <th>type_of_breast_surgery</th>\n",
       "      <th>cancer_type</th>\n",
       "      <th>cancer_type_detailed</th>\n",
       "      <th>cellularity</th>\n",
       "      <th>chemotherapy</th>\n",
       "      <th>pam50_+_claudin-low_subtype</th>\n",
       "      <th>cohort</th>\n",
       "      <th>er_status_measured_by_ihc</th>\n",
       "      <th>...</th>\n",
       "      <th>mtap_mut</th>\n",
       "      <th>ppp2cb_mut</th>\n",
       "      <th>smarcd1_mut</th>\n",
       "      <th>nras_mut</th>\n",
       "      <th>ndfip1_mut</th>\n",
       "      <th>hras_mut</th>\n",
       "      <th>prps2_mut</th>\n",
       "      <th>smarcb1_mut</th>\n",
       "      <th>stmn2_mut</th>\n",
       "      <th>siah1_mut</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>75.65</td>\n",
       "      <td>MASTECTOMY</td>\n",
       "      <td>Breast Cancer</td>\n",
       "      <td>Breast Invasive Ductal Carcinoma</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>claudin-low</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Positve</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>43.19</td>\n",
       "      <td>BREAST CONSERVING</td>\n",
       "      <td>Breast Cancer</td>\n",
       "      <td>Breast Invasive Ductal Carcinoma</td>\n",
       "      <td>High</td>\n",
       "      <td>0</td>\n",
       "      <td>LumA</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Positve</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5</td>\n",
       "      <td>48.87</td>\n",
       "      <td>MASTECTOMY</td>\n",
       "      <td>Breast Cancer</td>\n",
       "      <td>Breast Invasive Ductal Carcinoma</td>\n",
       "      <td>High</td>\n",
       "      <td>1</td>\n",
       "      <td>LumB</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Positve</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>6</td>\n",
       "      <td>47.68</td>\n",
       "      <td>MASTECTOMY</td>\n",
       "      <td>Breast Cancer</td>\n",
       "      <td>Breast Mixed Ductal and Lobular Carcinoma</td>\n",
       "      <td>Moderate</td>\n",
       "      <td>1</td>\n",
       "      <td>LumB</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Positve</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>8</td>\n",
       "      <td>76.97</td>\n",
       "      <td>MASTECTOMY</td>\n",
       "      <td>Breast Cancer</td>\n",
       "      <td>Breast Mixed Ductal and Lobular Carcinoma</td>\n",
       "      <td>High</td>\n",
       "      <td>1</td>\n",
       "      <td>LumB</td>\n",
       "      <td>1.0</td>\n",
       "      <td>Positve</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 693 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   patient_id  age_at_diagnosis type_of_breast_surgery    cancer_type  \\\n",
       "0           0             75.65             MASTECTOMY  Breast Cancer   \n",
       "1           2             43.19      BREAST CONSERVING  Breast Cancer   \n",
       "2           5             48.87             MASTECTOMY  Breast Cancer   \n",
       "3           6             47.68             MASTECTOMY  Breast Cancer   \n",
       "4           8             76.97             MASTECTOMY  Breast Cancer   \n",
       "\n",
       "                        cancer_type_detailed cellularity  chemotherapy  \\\n",
       "0           Breast Invasive Ductal Carcinoma         NaN             0   \n",
       "1           Breast Invasive Ductal Carcinoma        High             0   \n",
       "2           Breast Invasive Ductal Carcinoma        High             1   \n",
       "3  Breast Mixed Ductal and Lobular Carcinoma    Moderate             1   \n",
       "4  Breast Mixed Ductal and Lobular Carcinoma        High             1   \n",
       "\n",
       "  pam50_+_claudin-low_subtype  cohort er_status_measured_by_ihc  ... mtap_mut  \\\n",
       "0                 claudin-low     1.0                   Positve  ...        0   \n",
       "1                        LumA     1.0                   Positve  ...        0   \n",
       "2                        LumB     1.0                   Positve  ...        0   \n",
       "3                        LumB     1.0                   Positve  ...        0   \n",
       "4                        LumB     1.0                   Positve  ...        0   \n",
       "\n",
       "   ppp2cb_mut smarcd1_mut nras_mut ndfip1_mut  hras_mut prps2_mut smarcb1_mut  \\\n",
       "0           0           0        0          0         0         0           0   \n",
       "1           0           0        0          0         0         0           0   \n",
       "2           0           0        0          0         0         0           0   \n",
       "3           0           0        0          0         0         0           0   \n",
       "4           0           0        0          0         0         0           0   \n",
       "\n",
       "  stmn2_mut  siah1_mut  \n",
       "0         0          0  \n",
       "1         0          0  \n",
       "2         0          0  \n",
       "3         0          0  \n",
       "4         0          0  \n",
       "\n",
       "[5 rows x 693 columns]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "32de0ca9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 1200x1000 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "\n",
    "# Create output folder\n",
    "os.makedirs(\"outputs/figures\", exist_ok=True)\n",
    "\n",
    "# 1. Chemotherapy distribution\n",
    "sns.countplot(x=\"chemotherapy\", data=df)\n",
    "plt.title(\"Chemotherapy Decision Distribution\")\n",
    "plt.savefig(\"../outputs/figures/chemotherapy_distribution.png\")\n",
    "plt.clf()\n",
    "\n",
    "# 2. Age at diagnosis distribution\n",
    "sns.histplot(df[\"age_at_diagnosis\"], kde=True, bins=30)\n",
    "plt.title(\"Age at Diagnosis\")\n",
    "plt.xlabel(\"Age\")\n",
    "plt.ylabel(\"Count\")\n",
    "plt.savefig(\"../outputs/figures/age_distribution.png\")\n",
    "plt.clf()\n",
    "\n",
    "# 3. Tumor size distribution\n",
    "if \"tumor_size\" in df.columns:\n",
    "    sns.histplot(df[\"tumor_size\"].dropna(), kde=True, bins=30)\n",
    "    plt.title(\"Tumor Size\")\n",
    "    plt.xlabel(\"Size (mm)\")\n",
    "    plt.savefig(\"../outputs/figures/tumor_size_distribution.png\")\n",
    "    plt.clf()\n",
    "\n",
    "# 4. Chemotherapy by PAM50 subtype\n",
    "if \"pam50_+_claudin-low_subtype\" in df.columns:\n",
    "    sns.countplot(x=\"pam50_+_claudin-low_subtype\", hue=\"chemotherapy\", data=df)\n",
    "    plt.title(\"Chemotherapy by PAM50 Subtype\")\n",
    "    plt.xticks(rotation=45)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"../outputs/figures/chemotherapy_by_pam50_subtype.png\")\n",
    "    plt.clf()\n",
    "\n",
    "# 5. Chemotherapy vs. hormone therapy\n",
    "if \"hormone_therapy\" in df.columns:\n",
    "    sns.countplot(x=\"chemotherapy\", hue=\"hormone_therapy\", data=df)\n",
    "    plt.title(\"Chemotherapy vs. Hormone Therapy\")\n",
    "    plt.savefig(\"../outputs/figures/chemotherapy_vs_hormone_therapy.png\")\n",
    "    plt.clf()\n",
    "\n",
    "# 6. Correlation heatmap of numeric features\n",
    "numeric_df = df.select_dtypes(include=[\"int64\", \"float64\"]).drop(columns=[\"patient_id\"], errors=\"ignore\")\n",
    "plt.figure(figsize=(12, 10))\n",
    "sns.heatmap(numeric_df.corr(), cmap=\"coolwarm\", linewidths=0.5)\n",
    "plt.title(\"Correlation Heatmap of Numerical Features\")\n",
    "plt.savefig(\"../outputs/figures/numerical_correlation_heatmap.png\")\n",
    "plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a71904ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/200 | Train Loss: 0.6784 | Val Loss: 0.6489\n",
      "Epoch 2/200 | Train Loss: 0.6003 | Val Loss: 0.5370\n",
      "Epoch 3/200 | Train Loss: 0.4942 | Val Loss: 0.4278\n",
      "Epoch 4/200 | Train Loss: 0.4268 | Val Loss: 0.3687\n",
      "Epoch 5/200 | Train Loss: 0.4064 | Val Loss: 0.3369\n",
      "Epoch 6/200 | Train Loss: 0.3759 | Val Loss: 0.3160\n",
      "Epoch 7/200 | Train Loss: 0.3660 | Val Loss: 0.2986\n",
      "Epoch 8/200 | Train Loss: 0.3521 | Val Loss: 0.2831\n",
      "Epoch 9/200 | Train Loss: 0.3457 | Val Loss: 0.2691\n",
      "Epoch 10/200 | Train Loss: 0.3315 | Val Loss: 0.2599\n",
      "Epoch 11/200 | Train Loss: 0.3281 | Val Loss: 0.2540\n",
      "Epoch 12/200 | Train Loss: 0.3422 | Val Loss: 0.2497\n",
      "Epoch 13/200 | Train Loss: 0.3301 | Val Loss: 0.2468\n",
      "Epoch 14/200 | Train Loss: 0.3064 | Val Loss: 0.2449\n",
      "Epoch 15/200 | Train Loss: 0.2887 | Val Loss: 0.2413\n",
      "Epoch 16/200 | Train Loss: 0.3206 | Val Loss: 0.2394\n",
      "Epoch 17/200 | Train Loss: 0.3045 | Val Loss: 0.2392\n",
      "Epoch 18/200 | Train Loss: 0.2961 | Val Loss: 0.2377\n",
      "Epoch 19/200 | Train Loss: 0.2955 | Val Loss: 0.2373\n",
      "Epoch 20/200 | Train Loss: 0.3007 | Val Loss: 0.2364\n",
      "Epoch 21/200 | Train Loss: 0.2977 | Val Loss: 0.2355\n",
      "Epoch 22/200 | Train Loss: 0.3060 | Val Loss: 0.2354\n",
      "Epoch 23/200 | Train Loss: 0.2944 | Val Loss: 0.2336\n",
      "Epoch 24/200 | Train Loss: 0.3073 | Val Loss: 0.2336\n",
      "Epoch 25/200 | Train Loss: 0.2963 | Val Loss: 0.2346\n",
      "Epoch 26/200 | Train Loss: 0.2977 | Val Loss: 0.2349\n",
      "Epoch 27/200 | Train Loss: 0.2821 | Val Loss: 0.2342\n",
      "Epoch 28/200 | Train Loss: 0.2795 | Val Loss: 0.2332\n",
      "Epoch 29/200 | Train Loss: 0.2866 | Val Loss: 0.2344\n",
      "Epoch 30/200 | Train Loss: 0.2887 | Val Loss: 0.2339\n",
      "Epoch 31/200 | Train Loss: 0.2858 | Val Loss: 0.2329\n",
      "Epoch 32/200 | Train Loss: 0.2819 | Val Loss: 0.2324\n",
      "Epoch 33/200 | Train Loss: 0.2700 | Val Loss: 0.2329\n",
      "Epoch 34/200 | Train Loss: 0.2979 | Val Loss: 0.2337\n",
      "Epoch 35/200 | Train Loss: 0.2724 | Val Loss: 0.2311\n",
      "Epoch 36/200 | Train Loss: 0.2624 | Val Loss: 0.2323\n",
      "Epoch 37/200 | Train Loss: 0.2906 | Val Loss: 0.2329\n",
      "Epoch 38/200 | Train Loss: 0.3033 | Val Loss: 0.2327\n",
      "Epoch 39/200 | Train Loss: 0.2641 | Val Loss: 0.2330\n",
      "Epoch 40/200 | Train Loss: 0.2754 | Val Loss: 0.2324\n",
      "Early stopping at epoch 40\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 800x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "selected_features = [\n",
    "    \"age_at_diagnosis\",\n",
    "    \"tumor_size\",\n",
    "    \"tumor_stage\",\n",
    "    \"lymph_nodes_examined_positive\",\n",
    "    \"nottingham_prognostic_index\",\n",
    "    \"cellularity\",\n",
    "    \"neoplasm_histologic_grade\",\n",
    "    \"inferred_menopausal_state\",\n",
    "    \"er_status_measured_by_ihc\",\n",
    "    \"pr_status\",\n",
    "    \"her2_status\"\n",
    "]\n",
    "\n",
    "df = df.dropna(subset=selected_features + [\"chemotherapy\"])\n",
    "X = df[selected_features].copy()\n",
    "y = df[\"chemotherapy\"].astype(int)\n",
    "\n",
    "# Typen korrigieren\n",
    "categorical_cols = X.select_dtypes(include=[\"object\", \"bool\", \"category\"]).columns.tolist()\n",
    "X[categorical_cols] = X[categorical_cols].astype(str)\n",
    "numeric_cols = X.select_dtypes(include=[\"int64\", \"float64\"]).columns.tolist()\n",
    "\n",
    "# Preprocessing Pipeline\n",
    "preprocessor = ColumnTransformer([\n",
    "    (\"num\", StandardScaler(), numeric_cols),\n",
    "    (\"cat\", OneHotEncoder(drop=\"first\", handle_unknown=\"ignore\"), categorical_cols)\n",
    "])\n",
    "X_processed = preprocessor.fit_transform(X)\n",
    "input_dim = X_processed.shape[1]\n",
    "\n",
    "X_tensor = torch.tensor(X_processed, dtype=torch.float32)\n",
    "y_tensor = torch.tensor(y.values, dtype=torch.float32).unsqueeze(1)\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X_tensor, y_tensor, test_size=0.2, stratify=y, random_state=42)\n",
    "train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)\n",
    "\n",
    "class ChemotherapyModel(nn.Module):\n",
    "    def __init__(self, in_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(in_dim, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(32, 16),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(16, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "model = ChemotherapyModel(input_dim)\n",
    "loss_fn = nn.BCELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "best_val_loss = float(\"inf\")\n",
    "patience = 5\n",
    "counter = 0\n",
    "best_model_state = None  # to store best model weights\n",
    "\n",
    "train_losses, val_losses = [], []\n",
    "for epoch in range(200):\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    for xb, yb in train_loader:\n",
    "        preds = model(xb)\n",
    "        loss = loss_fn(preds, yb)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.item()\n",
    "    train_losses.append(epoch_loss / len(train_loader))\n",
    "\n",
    "    # Validation loss\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        val_preds = model(X_test)\n",
    "        val_loss = loss_fn(val_preds, y_test)\n",
    "        val_losses.append(val_loss.item())\n",
    "\n",
    "    print(f\"Epoch {epoch+1}/200 | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_losses[-1]:.4f}\")\n",
    "\n",
    "    # Early stopping check\n",
    "    if val_loss.item() < best_val_loss:\n",
    "        best_val_loss = val_loss.item()\n",
    "        best_model_state = model.state_dict()\n",
    "        counter = 0\n",
    "    else:\n",
    "        counter += 1\n",
    "        if counter >= patience:\n",
    "            print(f\"Early stopping at epoch {epoch+1}\")\n",
    "            break\n",
    "\n",
    "os.makedirs(\"../outputs/figures\", exist_ok=True)\n",
    "plt.figure(figsize=(8,5))\n",
    "plt.plot(train_losses, label=\"Train Loss\")\n",
    "plt.plot(val_losses, label=\"Validation Loss\")\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"BCE Loss\")\n",
    "plt.title(\"Learning Curve (Chemotherapy Prediction)\")\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../outputs/figures/nn_learning_curve_chemotherapy.png\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f358fffc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy : 0.9035\n",
      "Precision: 0.8000\n",
      "Recall   : 0.7586\n",
      "F1 Score : 0.7788\n",
      "ROC AUC  : 0.9550\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import (\n",
    "    accuracy_score, precision_score, recall_score,\n",
    "    f1_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay\n",
    ")\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Model in eval mode\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    y_pred_prob = model(X_test).numpy()\n",
    "    y_pred = (y_pred_prob >= 0.5).astype(int)\n",
    "    y_true = y_test.numpy()\n",
    "\n",
    "# Evaluation metrics\n",
    "accuracy = accuracy_score(y_true, y_pred)\n",
    "precision = precision_score(y_true, y_pred, zero_division=0)\n",
    "recall = recall_score(y_true, y_pred, zero_division=0)\n",
    "f1 = f1_score(y_true, y_pred, zero_division=0)\n",
    "auc = roc_auc_score(y_true, y_pred_prob)\n",
    "\n",
    "# Confusion matrix\n",
    "cm = confusion_matrix(y_true, y_pred)\n",
    "disp = ConfusionMatrixDisplay(confusion_matrix=cm)\n",
    "disp.plot(cmap=\"Blues\")\n",
    "plt.title(\"Confusion Matrix - Chemotherapy Prediction\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../outputs/figures/confusion_matrix_chemotherapy.png\")\n",
    "plt.show()\n",
    "\n",
    "# Print results\n",
    "print(f\"Accuracy : {accuracy:.4f}\")\n",
    "print(f\"Precision: {precision:.4f}\")\n",
    "print(f\"Recall   : {recall:.4f}\")\n",
    "print(f\"F1 Score : {f1:.4f}\")\n",
    "print(f\"ROC AUC  : {auc:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c2b952a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}