Diff of /SA_MIL_training.ipynb [000000] .. [05c4fa]


--- a
+++ b/SA_MIL_training.ipynb
@@ -0,0 +1,514 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "metadata": {
+        "id": "FauGw-yKqt0k",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "e03c49ff-a6fa-44a6-9df9-c9f392824821"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Num GPUs Available:  0\n"
+          ]
+        }
+      ],
+      "source": [
+        "####################\n",
+        "###  LIBRARIES  ####\n",
+        "####################\n",
+        "\n",
+        "import numpy as np\n",
+        "import warnings\n",
+        "import pandas as pd\n",
+        "import os\n",
+        "import matplotlib.pyplot as plt\n",
+        "import cv2\n",
+        "import networkx as nx\n",
+        "\n",
+        "# Remove TensorFlow warnings\n",
+        "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
+        "\n",
+        "# Import TensorFlow and Keras for neural network operations\n",
+        "import tensorflow as tf\n",
+        "from tensorflow import keras\n",
+        "from tensorflow.keras import layers\n",
+        "from tensorflow.keras.callbacks import EarlyStopping\n",
+        "from tensorflow.keras.losses import Loss\n",
+        "from tensorflow.python.framework.ops import disable_eager_execution\n",
+        "disable_eager_execution()\n",
+        "\n",
+        "# Set the default float type for TensorFlow to \"float32\"\n",
+        "tf.keras.backend.set_floatx(\"float32\")\n",
+        "\n",
+        "# Print the number of available GPUs\n",
+        "print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "####################\n",
+        "### DATA LOADING ###\n",
+        "####################\n",
+        "\n",
+        "print('Starting preprocessing of bags')\n",
+        "\n",
+        "# Define directories for image files, have training images from three folders.\n",
+        "train_images_dir1 = './Data/train1/'\n",
+        "train_images_dir2 = './Data/train2/'\n",
+        "train_images_dir3 = './Data/train3/'\n",
+        "\n",
+        "# Get lists of files in the directories\n",
+        "train_files1 = set(os.listdir(train_images_dir1))\n",
+        "train_files2 = set(os.listdir(train_images_dir2))\n",
+        "train_files3 = set(os.listdir(train_images_dir3))\n",
+        "\n",
+        "# Read bag data from CSV files\n",
+        "train_bags = pd.read_csv(\"./tables/Training_examples.csv\")\n",
+        "\n",
+        "# Create a mapping of train files to their respective directories\n",
+        "dirs_ = [train_images_dir1, train_images_dir2, train_images_dir3]\n",
+        "train_files_loc = {\n",
+        "    k: dirs_[\n",
+        "        (k[:-4]+'.npy' in train_files1) * 1 +\n",
+        "        (k[:-4]+'.npy' in train_files2) * 2 +\n",
+        "        (k[:-4]+'.npy' in train_files3) * 3 - 1\n",
+        "    ]\n",
+        "    for k in train_bags.instance_name\n",
+        "}\n",
+        "\n",
+        "# Create lists of DCM files for train files in each directory\n",
+        "train_files1_dcm = [k[:-4] + '.dcm' for k in train_files1]\n",
+        "train_files2_dcm = [k[:-4] + '.dcm' for k in train_files2]\n",
+        "train_files3_dcm = [k[:-4] + '.dcm' for k in train_files3]\n",
+        "\n",
+        "# Filter train bags based on DCM file existence\n",
+        "train_bags = train_bags[\n",
+        "    train_bags.instance_name.isin(train_files1_dcm) |\n",
+        "    train_bags.instance_name.isin(train_files2_dcm) |\n",
+        "    train_bags.instance_name.isin(train_files3_dcm)\n",
+        "]"
+      ],
+      "metadata": {
+        "id": "sOoozeNfqxrz"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "##########################\n",
+        "### BAGS PREPROCESSING ###\n",
+        "##########################\n",
+        "\n",
+        "# Set the desired bag size\n",
+        "bag_size = 57\n",
+        "\n",
+        "# Create additional train bags to reach the desired bag size\n",
+        "added_train_bags = pd.DataFrame()\n",
+        "for idx in train_bags.bag_name.unique():\n",
+        "    bags = train_bags[train_bags.bag_name==idx].copy()\n",
+        "    num_add = bag_size - len(bags.instance_name)\n",
+        "\n",
+        "    aux = bags.iloc[0].copy()\n",
+        "    aux.instance_label = 0\n",
+        "    aux.instance_name = 'all_zeros'\n",
+        "    for i in range(num_add):\n",
+        "        added_train_bags = added_train_bags.append(aux)\n",
+        "\n",
+        "train_bags = train_bags.append(added_train_bags)\n",
+        "\n",
+        "# Convert bags data to dictionaries for optimization\n",
+        "train_bags_dic = {k: list(train_bags[train_bags.bag_name==k].instance_name) for k in train_bags.bag_name.unique()}"
+      ],
+      "metadata": {
+        "id": "1amWUWZBqx0x"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "####################\n",
+        "###  DATALOADER  ###\n",
+        "####################\n",
+        "\n",
+        "dim=(512,512,bag_size)\n",
+        "class DataGeneratorMIL(keras.utils.Sequence):\n",
+        "    'Generates data for Keras'\n",
+        "\n",
+        "    def __init__(self, list_IDs, labels=None, batch_size=256, dim=(512,512,512), n_channels=3,\n",
+        "                 n_classes=2, shuffle=True, is_train=True):\n",
+        "        'Initialization'\n",
+        "        self.dim = dim\n",
+        "        self.batch_size = batch_size\n",
+        "        self.labels = labels\n",
+        "        self.is_train = (labels is not None) and is_train\n",
+        "        self.list_IDs = list_IDs\n",
+        "        self.n_channels = n_channels\n",
+        "        self.n_classes = n_classes\n",
+        "        self.shuffle = shuffle\n",
+        "        self.on_epoch_end()\n",
+        "\n",
+        "    def __len__(self):\n",
+        "        'Denotes the number of batches per epoch'\n",
+        "        return int(np.floor(len(self.list_IDs) / self.batch_size))\n",
+        "\n",
+        "    def __getitem__(self, index):\n",
+        "        'Generate one batch of data'\n",
+        "        # Generate indexes of the batch\n",
+        "        list_IDs_temp = self.list_IDs[index*self.batch_size:(index+1)*self.batch_size]\n",
+        "\n",
+        "        X = self.__data_generation(list_IDs_temp)\n",
+        "        # Generate data\n",
+        "        if self.is_train:\n",
+        "            y = self.labels[index*self.batch_size:(index+1)*self.batch_size]\n",
+        "            return np.array(X), np.array(y, dtype='float64')\n",
+        "        else:\n",
+        "            return np.array(X)\n",
+        "\n",
+        "    def on_epoch_end(self):\n",
+        "        'Updates indexes after each epoch'\n",
+        "        self.indexes = np.arange(len(self.list_IDs))\n",
+        "        if self.shuffle == True:\n",
+        "            np.random.shuffle(self.indexes)\n",
+        "\n",
+        "    def __data_generation(self, list_IDs_temp):\n",
+        "        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)\n",
+        "        # Initialization\n",
+        "        X = np.empty((self.batch_size, *self.dim, self.n_channels))\n",
+        "\n",
+        "        # Generate data\n",
+        "        for i, ID in enumerate(list_IDs_temp):\n",
+        "            # Store sample\n",
+        "            if self.is_train:\n",
+        "                ids = train_bags_dic[ID]\n",
+        "            else:\n",
+        "                ids = test_bags_dic[ID]\n",
+        "            imgs = []\n",
+        "            for idx in ids:\n",
+        "                if idx == 'all_zeros':\n",
+        "                    img = np.zeros((self.dim[0], self.dim[1], self.n_channels))\n",
+        "                    imgs.append(img)\n",
+        "                    continue\n",
+        "                if self.is_train:\n",
+        "                    _dir = train_files_loc[idx]\n",
+        "                    img = np.load(_dir + idx[:-4] + '.npy')\n",
+        "                    img = cv2.resize(img, (self.dim[1], self.dim[0]))\n",
+        "                    imgs.append(img)\n",
+        "                else:\n",
+        "                    img = np.load(test_images_dir + idx[:-4] + '.npy')\n",
+        "                    img = cv2.resize(img, (self.dim[1], self.dim[0]))\n",
+        "                    imgs.append(img)\n",
+        "            X[i,] = np.transpose(imgs, [1,2,0,3])\n",
+        "\n",
+        "        return X"
+      ],
+      "metadata": {
+        "id": "CkWPfMEjqx1z"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "########################\n",
+        "### TRAIN/TEST SPLIT ###\n",
+        "########################\n",
+        "\n",
+        "from sklearn.model_selection import train_test_split\n",
+        "\n",
+        "N = len(train_bags.bag_name.unique())\n",
+        "bags = train_bags.groupby('bag_name').max()\n",
+        "\n",
+        "X_train, X_val, y_train, y_val = train_test_split(np.array(bags.index)[:], bags.bag_label[:],\n",
+        "                                                 test_size=0.20, random_state=0,\n",
+        "                                                 stratify=bags.bag_label[:])\n",
+        "\n",
+        "batch_size = 4\n",
+        "\n",
+        "# Creating the train dataset using DataGeneratorMIL\n",
+        "train_dataset = DataGeneratorMIL(X_train, y_train, batch_size=batch_size, dim=dim)\n",
+        "\n",
+        "# Creating the validation dataset using DataGeneratorMIL\n",
+        "val_dataset = DataGeneratorMIL(X_val, y_val, batch_size=batch_size, dim=dim, is_augment=False)"
+      ],
+      "metadata": {
+        "id": "MUdsNAz-qx4P"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "########################\n",
+        "### SA-Loss ###\n",
+        "########################\n",
+        "class smoothMIL(tf.keras.losses.Loss):\n",
+        "    def __init__(self, att_weights, alpha, S_k):\n",
+        "        super(smoothMIL, self).__init__()\n",
+        "        self.att_weights = att_weights\n",
+        "        self.alpha = alpha\n",
+        "        self.S_k = S_k\n",
+        "\n",
+        "    def compute_Laplacian(self, bag_size):\n",
+        "        G = nx.Graph()\n",
+        "        for e in range(bag_size - 1):\n",
+        "            G.add_edge(e + 1, e + 2)\n",
+        "        degree_matrix = np.diag(list(dict(G.degree()).values())) + np.eye(bag_size)\n",
+        "        adjacency_matrix = nx.adjacency_matrix(G).toarray()\n",
+        "        L = degree_matrix - adjacency_matrix\n",
+        "        return L\n",
+        "\n",
+        "    def call(self, y_true, y_pred):\n",
+        "        bce = tf.keras.losses.BinaryCrossentropy()\n",
+        "\n",
+        "        loss1 = bce(y_true, y_pred)\n",
+        "        L = self.compute_Laplacian(bag_size)\n",
+        "\n",
+        "        if self.S_k == 1:\n",
+        "            VV = tf.linalg.matmul(self.att_weights, L)\n",
+        "            loss2 = tf.linalg.matmul(VV, tf.transpose(self.att_weights, (0, 2, 1)))\n",
+        "        elif self.S_k == 2:\n",
+        "            VV = tf.linalg.matmul(self.att_weights, L)\n",
+        "            VV = tf.linalg.matmul(VV, L)\n",
+        "            loss2 = tf.linalg.matmul(VV, tf.transpose(self.att_weights, (0, 2, 1)))\n",
+        "\n",
+        "        loss2 = tf.math.reduce_mean(loss2)\n",
+        "        loss_combined = tf.math.add(self.alpha * loss1, (1 - self.alpha) * loss2)\n",
+        "\n",
+        "        return loss_combined"
+      ],
+      "metadata": {
+        "id": "HbzoGw3lrFlJ"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "####################\n",
+        "###    MODEL     ###\n",
+        "####################\n",
+        "\n",
+        "# MILAttentionLayer\n",
+        "class MILAttentionLayer(layers.Layer):\n",
+        "    \"\"\"Implementation of the attention-based Deep MIL layer.\"\"\"\n",
+        "\n",
+        "    def __init__(\n",
+        "        self,\n",
+        "        weight_params_dim,\n",
+        "        kernel_initializer=\"glorot_uniform\",\n",
+        "        kernel_regularizer=None,\n",
+        "        use_gated=False,\n",
+        "        **kwargs,\n",
+        "    ):\n",
+        "        super().__init__(**kwargs)\n",
+        "\n",
+        "        self.weight_params_dim = weight_params_dim\n",
+        "        self.use_gated = use_gated\n",
+        "\n",
+        "        self.kernel_initializer = keras.initializers.get(kernel_initializer)\n",
+        "        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)\n",
+        "\n",
+        "        self.v_init = self.kernel_initializer\n",
+        "        self.w_init = self.kernel_initializer\n",
+        "        self.u_init = self.kernel_initializer\n",
+        "\n",
+        "        self.v_regularizer = self.kernel_regularizer\n",
+        "        self.w_regularizer = self.kernel_regularizer\n",
+        "        self.u_regularizer = self.kernel_regularizer\n",
+        "\n",
+        "    def build(self, input_shape):\n",
+        "        input_dim = input_shape[1]\n",
+        "\n",
+        "        self.v_weight_params = self.add_weight(\n",
+        "            shape=(input_dim, self.weight_params_dim),\n",
+        "            initializer=self.v_init,\n",
+        "            name=\"v\",\n",
+        "            regularizer=self.v_regularizer,\n",
+        "            trainable=True,\n",
+        "        )\n",
+        "\n",
+        "        self.w_weight_params = self.add_weight(\n",
+        "            shape=(self.weight_params_dim, 1),\n",
+        "            initializer=self.w_init,\n",
+        "            name=\"w\",\n",
+        "            regularizer=self.w_regularizer,\n",
+        "            trainable=True,\n",
+        "        )\n",
+        "\n",
+        "        if self.use_gated:\n",
+        "            self.u_weight_params = self.add_weight(\n",
+        "                shape=(input_dim, self.weight_params_dim),\n",
+        "                initializer=self.u_init,\n",
+        "                name=\"u\",\n",
+        "                regularizer=self.u_regularizer,\n",
+        "                trainable=True,\n",
+        "            )\n",
+        "        else:\n",
+        "            self.u_weight_params = None\n",
+        "\n",
+        "        self.input_built = True\n",
+        "\n",
+        "    def call(self, inputs):\n",
+        "        instances = self.compute_attention_scores(inputs)\n",
+        "        instances = tf.reshape(instances, shape=(-1, dim[2]))\n",
+        "        alpha = tf.math.softmax(instances, axis=1)\n",
+        "        return alpha\n",
+        "\n",
+        "    def compute_attention_scores(self, instance):\n",
+        "        original_instance = instance\n",
+        "        instance = tf.math.tanh(tf.tensordot(instance, self.v_weight_params, axes=1))\n",
+        "\n",
+        "        if self.use_gated:\n",
+        "            instance = instance * tf.math.sigmoid(\n",
+        "                tf.tensordot(original_instance, self.u_weight_params, axes=1)\n",
+        "            )\n",
+        "\n",
+        "        return tf.tensordot(instance, self.w_weight_params, axes=1)\n",
+        "\n",
+        "\n",
+        "# Model\n",
+        "num_data = batch_size\n",
+        "D = bag_size\n",
+        "\n",
+        "Conv1 = layers.Conv2D(16, (5, 5), data_format=\"channels_last\", activation='relu', kernel_initializer='glorot_uniform', padding='same')\n",
+        "Conv2 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
+        "Conv3 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
+        "Conv4 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
+        "Conv5 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
+        "Conv6 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
+        "\n",
+        "def VGG(inp):\n",
+        "    inp = tf.reshape(tf.transpose(inp, perm=(0,3,1,2,4)), shape=(-1, dim[0], dim[1], 3))\n",
+        "    x = Conv1(inp)\n",
+        "    x = layers.BatchNormalization()(x)\n",
+        "    x = layers.MaxPool2D((2, 2), data_format=\"channels_last\", strides=(2, 2))(x)\n",
+        "    x = Conv2(x)\n",
+        "    x = layers.BatchNormalization()(x)\n",
+        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
+        "    x = layers.Dropout(0.3)(x)\n",
+        "\n",
+        "    x = Conv3(x)\n",
+        "    x = layers.BatchNormalization()(x)\n",
+        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
+        "    x = Conv4(x)\n",
+        "    x = layers.BatchNormalization()(x)\n",
+        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
+        "\n",
+        "    x = Conv5(x)\n",
+        "    x = layers.BatchNormalization()(x)\n",
+        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
+        "    x = layers.Dropout(0.3)(x)\n",
+        "\n",
+        "    x = Conv6(x)\n",
+        "    x = layers.BatchNormalization()(x)\n",
+        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
+        "    x = layers.Dropout(0.3)(x)\n",
+        "\n",
+        "    return layers.Flatten()(x)\n",
+        "\n",
+        "def build_model():\n",
+        "    inp = keras.Input(shape=(*dim, 3))\n",
+        "    H = VGG(inp)\n",
+        "    A = MILAttentionLayer(\n",
+        "        weight_params_dim=64,\n",
+        "        kernel_regularizer=keras.regularizers.l2(0.01),\n",
+        "        use_gated=True,\n",
+        "        name=\"alpha\",\n",
+        "    )(H)\n",
+        "    H = tf.reshape(H, shape=(-1, dim[2], H.shape[1]))\n",
+        "    A = tf.expand_dims(A, axis=1)\n",
+        "    intermediate = tf.linalg.matmul(A, H)\n",
+        "    intermediate = tf.squeeze(intermediate, axis=1)\n",
+        "    intermediate = layers.Dropout(0.25)(intermediate)\n",
+        "    intermediate = layers.Dense(128)(intermediate)\n",
+        "    out = layers.Dense(1, activation='sigmoid')(intermediate)\n",
+        "    return A, keras.Model(inputs=inp, outputs=out)\n",
+        "\n",
+        "A, model = build_model()\n",
+        "\n",
+        "auc = tf.keras.metrics.AUC()\n",
+        "adam = tf.compat.v1.train.AdamOptimizer(learning_rate=5e-5)\n",
+        "model.compile(\n",
+        "    optimizer=adam,\n",
+        "    loss= smoothMIL(A, 0.5, 1),\n",
+        "    metrics=[auc, 'accuracy']\n",
+        ")\n",
+        "earlyStopping = EarlyStopping(monitor='val_loss', patience=8, verbose=1, mode='min')\n",
+        "print(model.summary())"
+      ],
+      "metadata": {
+        "id": "TUE7oM38qx6y"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "####################\n",
+        "###    Train    ###\n",
+        "####################\n",
+        "for i in range(0, 5):\n",
+        "    checkpoint_path = \"./model/att_{}.ckpt\".format(i)\n",
+        "    checkpoint_dir = os.path.dirname(checkpoint_path)\n",
+        "\n",
+        "    cp_callback = keras.callbacks.ModelCheckpoint(\n",
+        "        filepath=checkpoint_path,\n",
+        "        monitor='val_loss',\n",
+        "        save_best_only=True,\n",
+        "        save_weights_only=True,\n",
+        "        verbose=1,\n",
+        "        mode='min'\n",
+        "    )\n",
+        "\n",
+        "\n",
+        "    history = model.fit(\n",
+        "        train_dataset,\n",
+        "        validation_data=val_dataset,\n",
+        "        epochs=200,\n",
+        "        callbacks=[earlyStopping, cp_callback],\n",
+        "    )\n",
+        "\n",
+        "    hist_df = pd.DataFrame(history.history)\n",
+        "    hist_csv_file = './log.csv'\n",
+        "\n",
+        "    with open(hist_csv_file, mode='w') as f:\n",
+        "        hist_df.to_csv(f)"
+      ],
+      "metadata": {
+        "id": "ffrwrgDRqx9f"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file