SA_MIL_training.ipynb

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "FauGw-yKqt0k",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e03c49ff-a6fa-44a6-9df9-c9f392824821"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Num GPUs Available:  0\n"
          ]
        }
      ],
      "source": [
        "####################\n",
        "###  LIBRARIES  ####\n",
        "####################\n",
        "\n",
        "import numpy as np\n",
        "import warnings\n",
        "import pandas as pd\n",
        "import os\n",
        "import matplotlib.pyplot as plt\n",
        "import cv2\n",
        "import networkx as nx\n",
        "\n",
        "# Remove TensorFlow warnings\n",
        "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
        "\n",
        "# Import TensorFlow and Keras for neural network operations\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras.callbacks import EarlyStopping\n",
        "from tensorflow.keras.losses import Loss\n",
        "from tensorflow.python.framework.ops import disable_eager_execution\n",
        "disable_eager_execution()\n",
        "\n",
        "# Set the default float type for TensorFlow to \"float32\"\n",
        "tf.keras.backend.set_floatx(\"float32\")\n",
        "\n",
        "# Print the number of available GPUs\n",
        "print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "### DATA LOADING ###\n",
        "####################\n",
        "\n",
        "print('Starting preprocessing of bags')\n",
        "\n",
        "# Define directories for image files, have training images from three folders.\n",
        "train_images_dir1 = './Data/train1/'\n",
        "train_images_dir2 = './Data/train2/'\n",
        "train_images_dir3 = './Data/train3/'\n",
        "\n",
        "# Get lists of files in the directories\n",
        "train_files1 = set(os.listdir(train_images_dir1))\n",
        "train_files2 = set(os.listdir(train_images_dir2))\n",
        "train_files3 = set(os.listdir(train_images_dir3))\n",
        "\n",
        "# Read bag data from CSV files\n",
        "train_bags = pd.read_csv(\"./tables/Training_examples.csv\")\n",
        "\n",
        "# Create a mapping of train files to their respective directories\n",
        "dirs_ = [train_images_dir1, train_images_dir2, train_images_dir3]\n",
        "train_files_loc = {\n",
        "    k: dirs_[\n",
        "        (k[:-4]+'.npy' in train_files1) * 1 +\n",
        "        (k[:-4]+'.npy' in train_files2) * 2 +\n",
        "        (k[:-4]+'.npy' in train_files3) * 3 - 1\n",
        "    ]\n",
        "    for k in train_bags.instance_name\n",
        "}\n",
        "\n",
        "# Create lists of DCM files for train files in each directory\n",
        "train_files1_dcm = [k[:-4] + '.dcm' for k in train_files1]\n",
        "train_files2_dcm = [k[:-4] + '.dcm' for k in train_files2]\n",
        "train_files3_dcm = [k[:-4] + '.dcm' for k in train_files3]\n",
        "\n",
        "# Filter train bags based on DCM file existence\n",
        "train_bags = train_bags[\n",
        "    train_bags.instance_name.isin(train_files1_dcm) |\n",
        "    train_bags.instance_name.isin(train_files2_dcm) |\n",
        "    train_bags.instance_name.isin(train_files3_dcm)\n",
        "]"
      ],
      "metadata": {
        "id": "sOoozeNfqxrz"
      },
      "execution_count": null,
      "outputs": []
    },
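    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "### FILE MAPPING ###\n",
        "####################\n",
        "\n",
        "# Optional sanity-check sketch (not part of the original pipeline). It assumes the\n",
        "# three ./Data/train*/ folders above exist and verifies that every instance in\n",
        "# train_files_loc was mapped to a directory that actually holds its .npy file.\n",
        "missing = [k for k, d in train_files_loc.items()\n",
        "           if not os.path.exists(os.path.join(d, k[:-4] + '.npy'))]\n",
        "print('Instances without a resolvable .npy file:', len(missing))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },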
    {
      "cell_type": "code",
      "source": [
        "##########################\n",
        "### BAGS PREPROCESSING ###\n",
        "##########################\n",
        "\n",
        "# Set the desired bag size\n",
        "bag_size = 57\n",
        "\n",
        "# Create additional train bags to reach the desired bag size\n",
        "added_train_bags = pd.DataFrame()\n",
        "for idx in train_bags.bag_name.unique():\n",
        "    bags = train_bags[train_bags.bag_name==idx].copy()\n",
        "    num_add = bag_size - len(bags.instance_name)\n",
        "\n",
        "    aux = bags.iloc[0].copy()\n",
        "    aux.instance_label = 0\n",
        "    aux.instance_name = 'all_zeros'\n",
        "    for i in range(num_add):\n",
        "        added_train_bags = added_train_bags.append(aux)\n",
        "\n",
        "train_bags = train_bags.append(added_train_bags)\n",
        "\n",
        "# Convert bags data to dictionaries for optimization\n",
        "train_bags_dic = {k: list(train_bags[train_bags.bag_name==k].instance_name) for k in train_bags.bag_name.unique()}"
      ],
      "metadata": {
        "id": "1amWUWZBqx0x"
      },
      "execution_count": null,
      "outputs": []
    },
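    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "### BAG PADDING  ###\n",
        "####################\n",
        "\n",
        "# Optional sanity-check sketch (not part of the original pipeline): after padding\n",
        "# with 'all_zeros' placeholder rows, every bag should hold exactly bag_size\n",
        "# instances, assuming no source bag contained more than bag_size instances.\n",
        "bag_counts = train_bags.groupby('bag_name').instance_name.count()\n",
        "print('Bags:', len(bag_counts),\n",
        "      '| min/max instances per bag:', bag_counts.min(), '/', bag_counts.max())"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },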
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###  DATALOADER  ###\n",
        "####################\n",
        "\n",
        "dim=(512,512,bag_size)\n",
        "class DataGeneratorMIL(keras.utils.Sequence):\n",
        "    'Generates data for Keras'\n",
        "\n",
        "    def __init__(self, list_IDs, labels=None, batch_size=256, dim=(512,512,512), n_channels=3,\n",
        "                 n_classes=2, shuffle=True, is_train=True):\n",
        "        'Initialization'\n",
        "        self.dim = dim\n",
        "        self.batch_size = batch_size\n",
        "        self.labels = labels\n",
        "        self.is_train = (labels is not None) and is_train\n",
        "        self.list_IDs = list_IDs\n",
        "        self.n_channels = n_channels\n",
        "        self.n_classes = n_classes\n",
        "        self.shuffle = shuffle\n",
        "        self.on_epoch_end()\n",
        "\n",
        "    def __len__(self):\n",
        "        'Denotes the number of batches per epoch'\n",
        "        return int(np.floor(len(self.list_IDs) / self.batch_size))\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        'Generate one batch of data'\n",
        "        # Generate indexes of the batch\n",
        "        list_IDs_temp = self.list_IDs[index*self.batch_size:(index+1)*self.batch_size]\n",
        "\n",
        "        X = self.__data_generation(list_IDs_temp)\n",
        "        # Generate data\n",
        "        if self.is_train:\n",
        "            y = self.labels[index*self.batch_size:(index+1)*self.batch_size]\n",
        "            return np.array(X), np.array(y, dtype='float64')\n",
        "        else:\n",
        "            return np.array(X)\n",
        "\n",
        "    def on_epoch_end(self):\n",
        "        'Updates indexes after each epoch'\n",
        "        self.indexes = np.arange(len(self.list_IDs))\n",
        "        if self.shuffle == True:\n",
        "            np.random.shuffle(self.indexes)\n",
        "\n",
        "    def __data_generation(self, list_IDs_temp):\n",
        "        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)\n",
        "        # Initialization\n",
        "        X = np.empty((self.batch_size, *self.dim, self.n_channels))\n",
        "\n",
        "        # Generate data\n",
        "        for i, ID in enumerate(list_IDs_temp):\n",
        "            # Store sample\n",
        "            if self.is_train:\n",
        "                ids = train_bags_dic[ID]\n",
        "            else:\n",
        "                ids = test_bags_dic[ID]\n",
        "            imgs = []\n",
        "            for idx in ids:\n",
        "                if idx == 'all_zeros':\n",
        "                    img = np.zeros((self.dim[0], self.dim[1], self.n_channels))\n",
        "                    imgs.append(img)\n",
        "                    continue\n",
        "                if self.is_train:\n",
        "                    _dir = train_files_loc[idx]\n",
        "                    img = np.load(_dir + idx[:-4] + '.npy')\n",
        "                    img = cv2.resize(img, (self.dim[1], self.dim[0]))\n",
        "                    imgs.append(img)\n",
        "                else:\n",
        "                    img = np.load(test_images_dir + idx[:-4] + '.npy')\n",
        "                    img = cv2.resize(img, (self.dim[1], self.dim[0]))\n",
        "                    imgs.append(img)\n",
        "            X[i,] = np.transpose(imgs, [1,2,0,3])\n",
        "\n",
        "        return X"
      ],
      "metadata": {
        "id": "CkWPfMEjqx1z"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "########################\n",
        "### TRAIN/TEST SPLIT ###\n",
        "########################\n",
        "\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "N = len(train_bags.bag_name.unique())\n",
        "bags = train_bags.groupby('bag_name').max()\n",
        "\n",
        "X_train, X_val, y_train, y_val = train_test_split(np.array(bags.index)[:], bags.bag_label[:],\n",
        "                                                 test_size=0.20, random_state=0,\n",
        "                                                 stratify=bags.bag_label[:])\n",
        "\n",
        "batch_size = 4\n",
        "\n",
        "# Creating the train dataset using DataGeneratorMIL\n",
        "train_dataset = DataGeneratorMIL(X_train, y_train, batch_size=batch_size, dim=dim)\n",
        "\n",
        "# Creating the validation dataset using DataGeneratorMIL\n",
        "val_dataset = DataGeneratorMIL(X_val, y_val, batch_size=batch_size, dim=dim, is_augment=False)"
      ],
      "metadata": {
        "id": "MUdsNAz-qx4P"
      },
      "execution_count": null,
      "outputs": []
    },
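    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###  BATCH CHECK ###\n",
        "####################\n",
        "\n",
        "# Optional shape-check sketch (assumes the preprocessed .npy volumes are on disk).\n",
        "# One training batch should contain images of shape (batch_size, 512, 512, bag_size, 3)\n",
        "# and bag labels of shape (batch_size,).\n",
        "X_batch, y_batch = train_dataset[0]\n",
        "print('X batch shape:', X_batch.shape)\n",
        "print('y batch shape:', y_batch.shape)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },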
    {
      "cell_type": "code",
      "source": [
        "########################\n",
        "### SA-Loss ###\n",
        "########################\n",
        "class smoothMIL(tf.keras.losses.Loss):\n",
        "    def __init__(self, att_weights, alpha, S_k):\n",
        "        super(smoothMIL, self).__init__()\n",
        "        self.att_weights = att_weights\n",
        "        self.alpha = alpha\n",
        "        self.S_k = S_k\n",
        "\n",
        "    def compute_Laplacian(self, bag_size):\n",
        "        G = nx.Graph()\n",
        "        for e in range(bag_size - 1):\n",
        "            G.add_edge(e + 1, e + 2)\n",
        "        degree_matrix = np.diag(list(dict(G.degree()).values())) + np.eye(bag_size)\n",
        "        adjacency_matrix = nx.adjacency_matrix(G).toarray()\n",
        "        L = degree_matrix - adjacency_matrix\n",
        "        return L\n",
        "\n",
        "    def call(self, y_true, y_pred):\n",
        "        bce = tf.keras.losses.BinaryCrossentropy()\n",
        "\n",
        "        loss1 = bce(y_true, y_pred)\n",
        "        L = self.compute_Laplacian(bag_size)\n",
        "\n",
        "        if self.S_k == 1:\n",
        "            VV = tf.linalg.matmul(self.att_weights, L)\n",
        "            loss2 = tf.linalg.matmul(VV, tf.transpose(self.att_weights, (0, 2, 1)))\n",
        "        elif self.S_k == 2:\n",
        "            VV = tf.linalg.matmul(self.att_weights, L)\n",
        "            VV = tf.linalg.matmul(VV, L)\n",
        "            loss2 = tf.linalg.matmul(VV, tf.transpose(self.att_weights, (0, 2, 1)))\n",
        "\n",
        "        loss2 = tf.math.reduce_mean(loss2)\n",
        "        loss_combined = tf.math.add(self.alpha * loss1, (1 - self.alpha) * loss2)\n",
        "\n",
        "        return loss_combined"
      ],
      "metadata": {
        "id": "HbzoGw3lrFlJ"
      },
      "execution_count": null,
      "outputs": []
    },
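    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "### SA-LAPLACIAN ###\n",
        "####################\n",
        "\n",
        "# Optional illustrative sketch (not part of the original pipeline): inspect the\n",
        "# chain-graph Laplacian used by the smoothness term for a toy bag of 4 instances.\n",
        "# As implemented above, the matrix is (D + I) - A, so the quadratic form a L a^T\n",
        "# grows when attention weights of adjacent instances differ.\n",
        "toy_loss = smoothMIL(att_weights=None, alpha=0.5, S_k=1)\n",
        "print(toy_loss.compute_Laplacian(4))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },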
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###    MODEL     ###\n",
        "####################\n",
        "\n",
        "# MILAttentionLayer\n",
        "class MILAttentionLayer(layers.Layer):\n",
        "    \"\"\"Implementation of the attention-based Deep MIL layer.\"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        weight_params_dim,\n",
        "        kernel_initializer=\"glorot_uniform\",\n",
        "        kernel_regularizer=None,\n",
        "        use_gated=False,\n",
        "        **kwargs,\n",
        "    ):\n",
        "        super().__init__(**kwargs)\n",
        "\n",
        "        self.weight_params_dim = weight_params_dim\n",
        "        self.use_gated = use_gated\n",
        "\n",
        "        self.kernel_initializer = keras.initializers.get(kernel_initializer)\n",
        "        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)\n",
        "\n",
        "        self.v_init = self.kernel_initializer\n",
        "        self.w_init = self.kernel_initializer\n",
        "        self.u_init = self.kernel_initializer\n",
        "\n",
        "        self.v_regularizer = self.kernel_regularizer\n",
        "        self.w_regularizer = self.kernel_regularizer\n",
        "        self.u_regularizer = self.kernel_regularizer\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        input_dim = input_shape[1]\n",
        "\n",
        "        self.v_weight_params = self.add_weight(\n",
        "            shape=(input_dim, self.weight_params_dim),\n",
        "            initializer=self.v_init,\n",
        "            name=\"v\",\n",
        "            regularizer=self.v_regularizer,\n",
        "            trainable=True,\n",
        "        )\n",
        "\n",
        "        self.w_weight_params = self.add_weight(\n",
        "            shape=(self.weight_params_dim, 1),\n",
        "            initializer=self.w_init,\n",
        "            name=\"w\",\n",
        "            regularizer=self.w_regularizer,\n",
        "            trainable=True,\n",
        "        )\n",
        "\n",
        "        if self.use_gated:\n",
        "            self.u_weight_params = self.add_weight(\n",
        "                shape=(input_dim, self.weight_params_dim),\n",
        "                initializer=self.u_init,\n",
        "                name=\"u\",\n",
        "                regularizer=self.u_regularizer,\n",
        "                trainable=True,\n",
        "            )\n",
        "        else:\n",
        "            self.u_weight_params = None\n",
        "\n",
        "        self.input_built = True\n",
        "\n",
        "    def call(self, inputs):\n",
        "        instances = self.compute_attention_scores(inputs)\n",
        "        instances = tf.reshape(instances, shape=(-1, dim[2]))\n",
        "        alpha = tf.math.softmax(instances, axis=1)\n",
        "        return alpha\n",
        "\n",
        "    def compute_attention_scores(self, instance):\n",
        "        original_instance = instance\n",
        "        instance = tf.math.tanh(tf.tensordot(instance, self.v_weight_params, axes=1))\n",
        "\n",
        "        if self.use_gated:\n",
        "            instance = instance * tf.math.sigmoid(\n",
        "                tf.tensordot(original_instance, self.u_weight_params, axes=1)\n",
        "            )\n",
        "\n",
        "        return tf.tensordot(instance, self.w_weight_params, axes=1)\n",
        "\n",
        "\n",
        "# Model\n",
        "num_data = batch_size\n",
        "D = bag_size\n",
        "\n",
        "Conv1 = layers.Conv2D(16, (5, 5), data_format=\"channels_last\", activation='relu', kernel_initializer='glorot_uniform', padding='same')\n",
        "Conv2 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
        "Conv3 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
        "Conv4 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
        "Conv5 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
        "Conv6 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
        "\n",
        "def VGG(inp):\n",
        "    inp = tf.reshape(tf.transpose(inp, perm=(0,3,1,2,4)), shape=(-1, dim[0], dim[1], 3))\n",
        "    x = Conv1(inp)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "    x = layers.MaxPool2D((2, 2), data_format=\"channels_last\", strides=(2, 2))(x)\n",
        "    x = Conv2(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
        "    x = layers.Dropout(0.3)(x)\n",
        "\n",
        "    x = Conv3(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
        "    x = Conv4(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
        "\n",
        "    x = Conv5(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
        "    x = layers.Dropout(0.3)(x)\n",
        "\n",
        "    x = Conv6(x)\n",
        "    x = layers.BatchNormalization()(x)\n",
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
        "    x = layers.Dropout(0.3)(x)\n",
        "\n",
        "    return layers.Flatten()(x)\n",
        "\n",
        "def build_model():\n",
        "    inp = keras.Input(shape=(*dim, 3))\n",
        "    H = VGG(inp)\n",
        "    A = MILAttentionLayer(\n",
        "        weight_params_dim=64,\n",
        "        kernel_regularizer=keras.regularizers.l2(0.01),\n",
        "        use_gated=True,\n",
        "        name=\"alpha\",\n",
        "    )(H)\n",
        "    H = tf.reshape(H, shape=(-1, dim[2], H.shape[1]))\n",
        "    A = tf.expand_dims(A, axis=1)\n",
        "    intermediate = tf.linalg.matmul(A, H)\n",
        "    intermediate = tf.squeeze(intermediate, axis=1)\n",
        "    intermediate = layers.Dropout(0.25)(intermediate)\n",
        "    intermediate = layers.Dense(128)(intermediate)\n",
        "    out = layers.Dense(1, activation='sigmoid')(intermediate)\n",
        "    return A, keras.Model(inputs=inp, outputs=out)\n",
        "\n",
        "A, model = build_model()\n",
        "\n",
        "auc = tf.keras.metrics.AUC()\n",
        "adam = tf.compat.v1.train.AdamOptimizer(learning_rate=5e-5)\n",
        "model.compile(\n",
        "    optimizer=adam,\n",
        "    loss= smoothMIL(A, 0.5, 1),\n",
        "    metrics=[auc, 'accuracy']\n",
        ")\n",
        "earlyStopping = EarlyStopping(monitor='val_loss', patience=8, verbose=1, mode='min')\n",
        "print(model.summary())"
      ],
      "metadata": {
        "id": "TUE7oM38qx6y"
      },
      "execution_count": null,
      "outputs": []
    },
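    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###  ATTENTION   ###\n",
        "####################\n",
        "\n",
        "# Optional sketch for inspecting the per-slice attention weights once the training\n",
        "# cell below has run. It relies on the layer named 'alpha' from build_model() and\n",
        "# on val_dataset being available; it is not required for training.\n",
        "att_model = keras.Model(inputs=model.input,\n",
        "                        outputs=model.get_layer('alpha').output)\n",
        "att_weights = att_model.predict(val_dataset)\n",
        "print('Attention weights shape (bags, bag_size):', att_weights.shape)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },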
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###    Train    ###\n",
        "####################\n",
        "for i in range(0, 5):\n",
        "    checkpoint_path = \"./model/att_{}.ckpt\".format(i)\n",
        "    checkpoint_dir = os.path.dirname(checkpoint_path)\n",
        "\n",
        "    cp_callback = keras.callbacks.ModelCheckpoint(\n",
        "        filepath=checkpoint_path,\n",
        "        monitor='val_loss',\n",
        "        save_best_only=True,\n",
        "        save_weights_only=True,\n",
        "        verbose=1,\n",
        "        mode='min'\n",
        "    )\n",
        "\n",
        "\n",
        "    history = model.fit(\n",
        "        train_dataset,\n",
        "        validation_data=val_dataset,\n",
        "        epochs=200,\n",
        "        callbacks=[earlyStopping, cp_callback],\n",
        "    )\n",
        "\n",
        "    hist_df = pd.DataFrame(history.history)\n",
        "    hist_csv_file = './log.csv'\n",
        "\n",
        "    with open(hist_csv_file, mode='w') as f:\n",
        "        hist_df.to_csv(f)"
      ],
      "metadata": {
        "id": "ffrwrgDRqx9f"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}