[9e8054]: / paper / NAR / 07_channel_effects / MNIST / 00-MNIST.ipynb

Download this file

217 lines (216 with data), 8.9 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e374cee0-2db8-4df5-a1f1-7be63928112b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shenwanxiang/anaconda3/envs/molmap/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/home/shenwanxiang/anaconda3/envs/molmap/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/home/shenwanxiang/anaconda3/envs/molmap/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/home/shenwanxiang/anaconda3/envs/molmap/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/home/shenwanxiang/anaconda3/envs/molmap/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/home/shenwanxiang/anaconda3/envs/molmap/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAABICAYAAADI6S+jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAACDElEQVR4nO3aP2pUURjG4e8ElWhs/BPstBIUFBQHtyK4gNmQnQtwCZItCFEbVyFRiKCVcGxsHNRhYI7vzPF5ulxu8X4EfsVlWu+9APj3DtIDAP5XAgwQIsAAIQIMECLAACECDBByYd0LrbVlVS2rqo4uHjy5d+3y8FEp368fpScM9eXKvP+785r3tqqq82+X0hOGOvy8NkV77evHd2e99+PV522T3wEvbl3tp88ebnXYLvn0/Gl6wlAnjx+lJwzzuu6nJwx18v52esJQd1/dSE8Y6s2Lw7e998Xqc58gAEIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggRIABQgQYIESAAUIEGCBEgAFCBBggpPXe//5Ca8uqWv7880FVfRg9KuhmVZ2lRwwy821V7tt3s993p/d+vPpwbYB/ebm10977YquzdsjM9818W5X79t3s9/2JTxAAIQIMELJpgF8OWbE7Zr5v5tuq3LfvZr/vtzb6BgzA9vgEARAiwAAhAgwQIsAAIQIMEPID0IJRG96Z/ToAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x72 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.utils import shuffle\n",
    "import matplotlib.pyplot as plt\n",
    "from aggmap import AggMap, AggMapNet\n",
    "\n",
    "from sklearn.model_selection import KFold, StratifiedKFold\n",
    "\n",
    "\n",
    "import seaborn as sns\n",
    "channel_list = [1, 3, 5, 7, 9, 11]\n",
    "color = sns.color_palette(\"rainbow_r\", len(channel_list)) #PiYG\n",
    "sns.palplot(color)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "72eb532d-0cd5-45a7-b453-ed108fcef78d",
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() #"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4e87802f-df13-4957-87fd-b0e562b2d9b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, w, h = x_train.shape\n",
    "orignal_cols = ['p-%s' % str((i+1)).zfill(len(str(w*h))) for i in range(w*h)]\n",
    "x_train_df = pd.DataFrame(x_train.reshape(x_train.shape[0], w*h), columns = orignal_cols)\n",
    "x_test_df = pd.DataFrame(x_test.reshape(x_test.shape[0], w*h), columns = orignal_cols)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ff55cb0-d6a7-4490-b8e9-3f2a7a5b82ae",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ba9efed8-11f9-42dd-9819-a171551a5c6b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2021-10-13 10:28:17,088 - \u001b[32mINFO\u001b[0m - [bidd-aggmap]\u001b[0m - Calculating distance ...\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|##########| 784/784 [00:02<00:00, 385.87it/s]\n"
     ]
    }
   ],
   "source": [
    "mp = AggMap(x_train_df, metric='correlation', by_scipy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4873b2cb-bc44-44f4-b853-b03af2229ea1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2021-10-13 22:29:43,082 - \u001b[32mINFO\u001b[0m - [bidd-aggmap]\u001b[0m - applying hierarchical clustering to obtain group information ...\u001b[0m\n",
      "2021-10-13 22:29:49,900 - \u001b[32mINFO\u001b[0m - [bidd-aggmap]\u001b[0m - Applying grid feature map(assignment), this may take several minutes(1~30 min)\u001b[0m\n",
      "2021-10-13 22:29:50,543 - \u001b[32mINFO\u001b[0m - [bidd-aggmap]\u001b[0m - Finished\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|##########| 60000/60000 [00:48<00:00, 1234.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " input train and test X shape is (48000, 28, 28, 41), (12000, 28, 28, 41) \n"
     ]
    }
   ],
   "source": [
    "n_splits = 5 #5-fold reapeat 5 times\n",
    "for c in channel_list: \n",
    "    run_all = []\n",
    "    for repeat_seed in [8]: #3 repeats random seeds\n",
    "\n",
    "        outer = StratifiedKFold(n_splits = n_splits, shuffle = True, random_state = repeat_seed)\n",
    "        outer_idx = list(outer.split(x_train_df, pd.Series(y_train)))\n",
    "\n",
    "        mp = mp.fit(cluster_channels = c, verbose = 0)\n",
    "        \n",
    "        X = mp.batch_transform(x_train_df.values)\n",
    "        Y = pd.get_dummies(pd.Series(y_train)).values\n",
    "\n",
    "        for i, idx in enumerate(outer_idx):\n",
    "            \n",
    "            train_idx, valid_idx = idx\n",
    "            fold_num = \"fold_%s\" % str(i+1).zfill(2) \n",
    "            \n",
    "            validY = Y[valid_idx]\n",
    "            validX = X[valid_idx]\n",
    "\n",
    "            trainX = X[train_idx]\n",
    "            trainY = Y[train_idx]\n",
    "\n",
    "            print(\"\\n input train and test X shape is %s, %s \" % (trainX.shape,  validX.shape))\n",
    "            clf = AggMapNet.MultiClassEstimator(epochs = 100,  batch_size = 64, verbose=0,\n",
    "                                                conv1_kernel_size = 3, dense_layers = [128, 64], \n",
    "                                                lr = 0.0001,  gpuid=1)\n",
    "            clf.fit(trainX, trainY, X_valid = validX, y_valid = validY)\n",
    "\n",
    "            history = clf.history\n",
    "            history['fold'] = fold_num\n",
    "            history['c'] = c\n",
    "            history['repeat_seed'] = repeat_seed\n",
    "            \n",
    "            run_all.append(history)\n",
    "        \n",
    "    pd.DataFrame(run_all).to_csv('C=%s.csv' % c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "000875fa-214c-426f-a6e6-e6a80cd86930",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}