[72fbc7]: / SA_MIL_testing.ipynb

Download this file

594 lines (594 with data), 30.5 kB

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h2430ByQ7FUA",
        "outputId": "b92d16bd-df4e-46ed-cd43-e718f6b971c2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Num GPUs Available:  0\n"
          ]
        }
      ],
      "source": [
        "####################\n",
        "###  LIBRARIES  ####\n",
        "####################\n",
        "\n",
        "import numpy as np\n",
        "import warnings\n",
        "import pandas as pd\n",
        "import os\n",
        "import matplotlib.pyplot as plt\n",
        "import cv2\n",
        "\n",
        "# Remove TensorFlow warnings\n",
        "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
        "\n",
        "# Import TensorFlow and Keras for neural network operations\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras.callbacks import EarlyStopping\n",
        "from tensorflow.keras.losses import Loss\n",
        "from tensorflow.python.framework.ops import disable_eager_execution\n",
        "disable_eager_execution()\n",
        "\n",
        "# Set the default float type for TensorFlow to \"float32\"\n",
        "tf.keras.backend.set_floatx(\"float32\")\n",
        "\n",
        "# Print the number of available GPUs\n",
        "print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "### DATA LOADING ###\n",
        "####################\n",
        "\n",
        "print('Starting preprocessing of bags')\n",
        "\n",
        "# Directory holding the test slices; the data generator reads them from here as .npy arrays\n",
        "test_images_dir = './test/'\n",
        "\n",
        "# Only .npy files can be loaded by the data generator, so anything else found in the\n",
        "# directory (hidden files, checkpoint folders, ...) must not vouch for an instance\n",
        "test_files = [f for f in os.listdir(test_images_dir) if f.endswith('.npy')]\n",
        "\n",
        "# Table describing the test bags (uses the instance_name, bag_name and bag_label columns)\n",
        "test_bags = pd.read_csv(\"./tables/testing_example.csv\")\n",
        "\n",
        "# Keep only the instances whose .npy file exists; instance names carry a .dcm extension\n",
        "test_files_dcm = [os.path.splitext(f)[0] + '.dcm' for f in test_files]\n",
        "test_bags = test_bags[test_bags.instance_name.isin(test_files_dcm)]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mtZQI9Hs7Juq",
        "outputId": "d2d52a49-853d-40b9-b9ae-ebe18ba7cf8f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Starting preprocessing of bags\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "##########################\n",
        "### BAGS PREPROCESSING ###\n",
        "##########################\n",
        "\n",
        "# Number of instances every bag is padded to (third axis of the model input)\n",
        "bag_size = 57\n",
        "\n",
        "# Pad every bag with placeholder 'all_zeros' instances until it holds bag_size rows.\n",
        "# DataFrame.append was removed in pandas 2.0 and copied the whole frame on each call,\n",
        "# so the padding rows are collected in a list and concatenated a single time.\n",
        "padding_rows = []\n",
        "for idx in test_bags.bag_name.unique():\n",
        "    bags = test_bags[test_bags.bag_name==idx]\n",
        "    num_add = bag_size - len(bags.instance_name)\n",
        "    if num_add < 0:\n",
        "        # The network input has exactly bag_size slots; fail here with a clear message\n",
        "        # instead of a shape mismatch inside the data generator.\n",
        "        raise ValueError('bag {!r} has {} instances, more than bag_size={}'.format(idx, len(bags), bag_size))\n",
        "\n",
        "    # The placeholder row copies the bag-level fields of the first instance\n",
        "    aux = bags.iloc[0].copy()\n",
        "    if 'instance_label' in aux.index:\n",
        "        aux['instance_label'] = 0\n",
        "    aux['instance_name'] = 'all_zeros'\n",
        "    padding_rows.extend([aux] * num_add)\n",
        "\n",
        "if padding_rows:\n",
        "    added_test_bags = pd.DataFrame(padding_rows)\n",
        "    test_bags = pd.concat([test_bags, added_test_bags])\n",
        "else:\n",
        "    added_test_bags = pd.DataFrame()\n",
        "\n",
        "# bag_name -> ordered list of instance names (real instances first, padding last)\n",
        "test_bags_dic = {name: list(group.instance_name) for name, group in test_bags.groupby('bag_name', sort=False)}"
      ],
      "metadata": {
        "id": "5Hhr38Dx7JxO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###  DATALOADER  ###\n",
        "####################\n",
        "# Shape of one bag fed to the network: (height, width, instances per bag)\n",
        "dim=(512,512,bag_size)\n",
        "\n",
        "class DataGeneratorMIL(keras.utils.Sequence):\n",
        "    \"\"\"Keras Sequence that serves whole bags of image instances, batch by batch.\n",
        "\n",
        "    For every bag ID the listed instances are loaded from .npy files, resized to\n",
        "    (dim[0], dim[1]) and stacked along the third axis, so one sample has shape\n",
        "    (*dim, n_channels).\n",
        "\n",
        "    Args:\n",
        "        list_IDs: bag identifiers, used as keys of test_bags_dic (or of\n",
        "            train_bags_dic when is_train is in effect).\n",
        "        labels: optional per-bag labels aligned with list_IDs.\n",
        "        batch_size: number of bags per batch.\n",
        "        dim: (height, width, instances per bag).\n",
        "        n_channels: number of channels of every instance image.\n",
        "        n_classes: stored, not used by this class.\n",
        "        shuffle: only controls whether self.indexes is shuffled; __getitem__ does\n",
        "            not consult self.indexes, so batches always follow the order of list_IDs.\n",
        "        is_train: honoured only when labels are given; selects the training lookup\n",
        "            tables and makes __getitem__ return (X, y) instead of X.\n",
        "\n",
        "    NOTE(review): train_bags_dic and train_files_loc are not defined in this notebook,\n",
        "    so only is_train=False can be used here.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, list_IDs, labels=None, batch_size=256, dim=(512,512,512), n_channels=3,\n",
        "                 n_classes=2, shuffle=True, is_train=True):\n",
        "        \"\"\"Store the configuration and build the initial index array.\"\"\"\n",
        "        self.dim = dim\n",
        "        self.batch_size = batch_size\n",
        "        self.labels = labels\n",
        "        # Training mode needs labels; without them the generator only yields inputs\n",
        "        self.is_train = (labels is not None) and is_train\n",
        "        self.list_IDs = list_IDs\n",
        "        self.n_channels = n_channels\n",
        "        self.n_classes = n_classes\n",
        "        self.shuffle = shuffle\n",
        "        self.on_epoch_end()\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"Number of full batches per epoch (a trailing partial batch is dropped).\"\"\"\n",
        "        return len(self.list_IDs) // self.batch_size\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        \"\"\"Return batch number `index`: X alone, or (X, y) when is_train is in effect.\"\"\"\n",
        "        first = index * self.batch_size\n",
        "        last = first + self.batch_size\n",
        "\n",
        "        # __data_generation allocates a fresh array for every call, so it is returned\n",
        "        # as-is rather than duplicated with np.array (a full extra copy of the batch).\n",
        "        X = self.__data_generation(self.list_IDs[first:last])\n",
        "        if self.is_train:\n",
        "            return X, np.array(self.labels[first:last], dtype='float64')\n",
        "        return X\n",
        "\n",
        "    def on_epoch_end(self):\n",
        "        \"\"\"Rebuild self.indexes, shuffled when shuffle is set.\"\"\"\n",
        "        self.indexes = np.arange(len(self.list_IDs))\n",
        "        if self.shuffle:\n",
        "            np.random.shuffle(self.indexes)\n",
        "\n",
        "    def __data_generation(self, list_IDs_temp):\n",
        "        \"\"\"Load the given bags into an array of shape (batch_size, *dim, n_channels).\"\"\"\n",
        "        X = np.empty((self.batch_size, *self.dim, self.n_channels))\n",
        "\n",
        "        for i, ID in enumerate(list_IDs_temp):\n",
        "            # Training and testing use different bag -> instances lookup tables\n",
        "            ids = train_bags_dic[ID] if self.is_train else test_bags_dic[ID]\n",
        "            imgs = []\n",
        "            for idx in ids:\n",
        "                if idx == 'all_zeros':\n",
        "                    # Padding instance: a blank image keeps the bag at its fixed size\n",
        "                    imgs.append(np.zeros((self.dim[0], self.dim[1], self.n_channels)))\n",
        "                    continue\n",
        "                _dir = train_files_loc[idx] if self.is_train else test_images_dir\n",
        "                # Instance names end with a 4-character extension that is swapped for .npy\n",
        "                img = np.load(_dir + idx[:-4] + '.npy')\n",
        "                # cv2.resize expects the target size as (width, height)\n",
        "                imgs.append(cv2.resize(img, (self.dim[1], self.dim[0])))\n",
        "            # (instances, height, width, channels) -> (height, width, instances, channels)\n",
        "            X[i,] = np.transpose(imgs, [1,2,0,3])\n",
        "\n",
        "        return X"
      ],
      "metadata": {
        "id": "pWlaSxnt7Jz3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "########################\n",
        "### Test Generator ###\n",
        "########################\n",
        "batch_size = 1\n",
        "\n",
        "# Preparing the test dataset: one row per bag, indexed (and sorted) by bag_name\n",
        "bags2 = test_bags.groupby('bag_name').max()\n",
        "\n",
        "# Predictions are compared positionally against bags2.bag_label, so the generator must\n",
        "# serve the bags in exactly this order: shuffling is switched off explicitly instead of\n",
        "# relying on the default, and the batch size comes from the variable defined above.\n",
        "test_dataset = DataGeneratorMIL(np.array(bags2.index), bags2.bag_label, batch_size=batch_size, dim=dim,\n",
        "                                shuffle=False, is_train=False)"
      ],
      "metadata": {
        "id": "w_XGFBbI7J2e"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###    MODEL     ###\n",
        "####################\n",
        "\n",
        "# MILAttentionLayer\n",
        "class MILAttentionLayer(layers.Layer):\n",
        "    \"\"\"Attention-based deep MIL layer.\n",
        "\n",
        "    Scores every instance embedding and returns, for each bag, attention weights\n",
        "    normalised with a softmax over the dim[2] instances of the bag.\n",
        "\n",
        "    Args:\n",
        "        weight_params_dim: width of the hidden attention projection.\n",
        "        kernel_initializer: initializer identifier or instance for the weights V, w and U.\n",
        "        kernel_regularizer: regularizer identifier or instance applied to V, w and U.\n",
        "        use_gated: when True, adds the sigmoid gating branch (weight matrix U).\n",
        "        **kwargs: forwarded to keras.layers.Layer (e.g. name).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        weight_params_dim,\n",
        "        kernel_initializer=\"glorot_uniform\",\n",
        "        kernel_regularizer=None,\n",
        "        use_gated=False,\n",
        "        **kwargs,\n",
        "    ):\n",
        "        super().__init__(**kwargs)\n",
        "\n",
        "        self.weight_params_dim = weight_params_dim\n",
        "        self.use_gated = use_gated\n",
        "\n",
        "        self.kernel_initializer = keras.initializers.get(kernel_initializer)\n",
        "        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)\n",
        "\n",
        "        # Keras warns that an unseeded initializer instance returns identical values each\n",
        "        # time it is called, so sharing one instance would start V and U (same shape) from\n",
        "        # the same matrix. Resolving the identifier once per weight yields independent\n",
        "        # instances whenever a string or config identifier is passed.\n",
        "        self.v_init = self.kernel_initializer\n",
        "        self.w_init = keras.initializers.get(kernel_initializer)\n",
        "        self.u_init = keras.initializers.get(kernel_initializer)\n",
        "\n",
        "        self.v_regularizer = self.kernel_regularizer\n",
        "        self.w_regularizer = self.kernel_regularizer\n",
        "        self.u_regularizer = self.kernel_regularizer\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        \"\"\"Create V (features x hidden), w (hidden x 1) and, if gated, U (features x hidden).\"\"\"\n",
        "        # Inputs are 2-D: one row per instance, input_shape[1] features per row\n",
        "        input_dim = input_shape[1]\n",
        "\n",
        "        self.v_weight_params = self.add_weight(\n",
        "            shape=(input_dim, self.weight_params_dim),\n",
        "            initializer=self.v_init,\n",
        "            name=\"v\",\n",
        "            regularizer=self.v_regularizer,\n",
        "            trainable=True,\n",
        "        )\n",
        "\n",
        "        self.w_weight_params = self.add_weight(\n",
        "            shape=(self.weight_params_dim, 1),\n",
        "            initializer=self.w_init,\n",
        "            name=\"w\",\n",
        "            regularizer=self.w_regularizer,\n",
        "            trainable=True,\n",
        "        )\n",
        "\n",
        "        if self.use_gated:\n",
        "            self.u_weight_params = self.add_weight(\n",
        "                shape=(input_dim, self.weight_params_dim),\n",
        "                initializer=self.u_init,\n",
        "                name=\"u\",\n",
        "                regularizer=self.u_regularizer,\n",
        "                trainable=True,\n",
        "            )\n",
        "        else:\n",
        "            self.u_weight_params = None\n",
        "\n",
        "        self.input_built = True\n",
        "\n",
        "    def call(self, inputs):\n",
        "        \"\"\"Return the attention weights, shape (bags, dim[2]).\"\"\"\n",
        "        scores = self.compute_attention_scores(inputs)\n",
        "        # Regroup the per-instance scores by bag; this relies on the notebook-level `dim`\n",
        "        scores = tf.reshape(scores, shape=(-1, dim[2]))\n",
        "        return tf.math.softmax(scores, axis=1)\n",
        "\n",
        "    def compute_attention_scores(self, instance):\n",
        "        \"\"\"Compute the unnormalised score w^T (tanh(h V) [* sigmoid(h U)]) for every row h.\"\"\"\n",
        "        original_instance = instance\n",
        "        instance = tf.math.tanh(tf.tensordot(instance, self.v_weight_params, axes=1))\n",
        "\n",
        "        if self.use_gated:\n",
        "            # Gating branch: element-wise product with a sigmoid projection of the input\n",
        "            instance = instance * tf.math.sigmoid(\n",
        "                tf.tensordot(original_instance, self.u_weight_params, axes=1)\n",
        "            )\n",
        "\n",
        "        return tf.tensordot(instance, self.w_weight_params, axes=1)\n",
        "\n",
        "\n",
        "# Model\n",
        "num_data = batch_size\n",
        "D = bag_size\n",
        "\n",
        "# Convolutional layers of the instance encoder: one 5x5 layer followed by five 3x3 layers\n",
        "Conv1 = layers.Conv2D(16, (5, 5), data_format=\"channels_last\", activation='relu', kernel_initializer='glorot_uniform', padding='same')\n",
        "Conv2, Conv3, Conv4, Conv5, Conv6 = [\n",
        "    layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
        "    for _ in range(5)\n",
        "]\n",
        "\n",
        "def VGG(inp):\n",
        "    \"\"\"Encode every instance of every bag into a flat feature vector.\n",
        "\n",
        "    The input of shape (batch, dim[0], dim[1], instances, 3) is rearranged into a stack of\n",
        "    images of shape (batch * instances, dim[0], dim[1], 3). Each image then goes through six\n",
        "    Conv -> BatchNormalization -> MaxPool stages, with Dropout(0.3) after stages 2, 5 and 6.\n",
        "    \"\"\"\n",
        "    # Bring the instance axis next to the batch axis, then merge the two\n",
        "    x = tf.transpose(inp, perm=(0,3,1,2,4))\n",
        "    x = tf.reshape(x, shape=(-1, dim[0], dim[1], 3))\n",
        "\n",
        "    # (convolution, whether the stage ends with dropout)\n",
        "    stages = (\n",
        "        (Conv1, False),\n",
        "        (Conv2, True),\n",
        "        (Conv3, False),\n",
        "        (Conv4, False),\n",
        "        (Conv5, True),\n",
        "        (Conv6, True),\n",
        "    )\n",
        "    for conv, with_dropout in stages:\n",
        "        x = conv(x)\n",
        "        x = layers.BatchNormalization()(x)\n",
        "        x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
        "        if with_dropout:\n",
        "            x = layers.Dropout(0.3)(x)\n",
        "\n",
        "    return layers.Flatten()(x)\n",
        "\n",
        "def build_model():\n",
        "    \"\"\"Assemble the attention-MIL classifier.\n",
        "\n",
        "    Instances are encoded by VGG, weighted by the gated attention layer named \"alpha\",\n",
        "    summed into one vector per bag and mapped to a single sigmoid output.\n",
        "    \"\"\"\n",
        "    bag_input = keras.Input(shape=(*dim, 3))\n",
        "\n",
        "    # One feature row per instance: (bags * dim[2], features)\n",
        "    embeddings = VGG(bag_input)\n",
        "    attention_layer = MILAttentionLayer(\n",
        "        weight_params_dim=64,\n",
        "        kernel_regularizer=keras.regularizers.l2(0.01),\n",
        "        use_gated=True,\n",
        "        name=\"alpha\",\n",
        "    )\n",
        "    attention = attention_layer(embeddings)\n",
        "\n",
        "    # Regroup the rows by bag: (bags, dim[2], features)\n",
        "    per_bag = tf.reshape(embeddings, shape=(-1, dim[2], embeddings.shape[1]))\n",
        "    # (bags, 1, dim[2]) @ (bags, dim[2], features) -> attention-weighted sum of the instances\n",
        "    attention = tf.expand_dims(attention, axis=1)\n",
        "    pooled = tf.linalg.matmul(attention, per_bag)\n",
        "    pooled = tf.squeeze(pooled, axis=1)\n",
        "\n",
        "    pooled = layers.Dropout(0.25)(pooled)\n",
        "    hidden = layers.Dense(128)(pooled)\n",
        "    prediction = layers.Dense(1, activation='sigmoid')(hidden)\n",
        "    return keras.Model(inputs=bag_input, outputs=prediction)\n",
        "\n",
        "model = build_model()\n",
        "# summary() prints the table itself and returns None, so it must not be wrapped in print()\n",
        "model.summary()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "T-m_DhYo7J48",
        "outputId": "65b2c285-9aae-45b5-e17f-ecc75d17808b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:tensorflow:From /usr/local/lib/python3.10/dist-packages/keras/layers/normalization/batch_normalization.py:581: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "Colocations handled automatically by placer.\n",
            "/usr/local/lib/python3.10/dist-packages/keras/initializers/initializers.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model: \"model\"\n",
            "__________________________________________________________________________________________________\n",
            " Layer (type)                   Output Shape         Param #     Connected to                     \n",
            "==================================================================================================\n",
            " input_1 (InputLayer)           [(None, 512, 512, 5  0           []                               \n",
            "                                7, 3)]                                                            \n",
            "                                                                                                  \n",
            " tf_op_layer_transpose (TensorF  [(None, 57, 512, 51  0          ['input_1[0][0]']                \n",
            " lowOpLayer)                    2, 3)]                                                            \n",
            "                                                                                                  \n",
            " tf_op_layer_Reshape (TensorFlo  [(None, 512, 512, 3  0          ['tf_op_layer_transpose[0][0]']  \n",
            " wOpLayer)                      )]                                                                \n",
            "                                                                                                  \n",
            " conv2d (Conv2D)                (None, 512, 512, 16  1216        ['tf_op_layer_Reshape[0][0]']    \n",
            "                                )                                                                 \n",
            "                                                                                                  \n",
            " batch_normalization (BatchNorm  (None, 512, 512, 16  64         ['conv2d[0][0]']                 \n",
            " alization)                     )                                                                 \n",
            "                                                                                                  \n",
            " max_pooling2d (MaxPooling2D)   (None, 256, 256, 16  0           ['batch_normalization[0][0]']    \n",
            "                                )                                                                 \n",
            "                                                                                                  \n",
            " conv2d_1 (Conv2D)              (None, 254, 254, 32  4640        ['max_pooling2d[0][0]']          \n",
            "                                )                                                                 \n",
            "                                                                                                  \n",
            " batch_normalization_1 (BatchNo  (None, 254, 254, 32  128        ['conv2d_1[0][0]']               \n",
            " rmalization)                   )                                                                 \n",
            "                                                                                                  \n",
            " max_pooling2d_1 (MaxPooling2D)  (None, 127, 127, 32  0          ['batch_normalization_1[0][0]']  \n",
            "                                )                                                                 \n",
            "                                                                                                  \n",
            " dropout (Dropout)              (None, 127, 127, 32  0           ['max_pooling2d_1[0][0]']        \n",
            "                                )                                                                 \n",
            "                                                                                                  \n",
            " conv2d_2 (Conv2D)              (None, 125, 125, 32  9248        ['dropout[0][0]']                \n",
            "                                )                                                                 \n",
            "                                                                                                  \n",
            " batch_normalization_2 (BatchNo  (None, 125, 125, 32  128        ['conv2d_2[0][0]']               \n",
            " rmalization)                   )                                                                 \n",
            "                                                                                                  \n",
            " max_pooling2d_2 (MaxPooling2D)  (None, 62, 62, 32)  0           ['batch_normalization_2[0][0]']  \n",
            "                                                                                                  \n",
            " conv2d_3 (Conv2D)              (None, 60, 60, 32)   9248        ['max_pooling2d_2[0][0]']        \n",
            "                                                                                                  \n",
            " batch_normalization_3 (BatchNo  (None, 60, 60, 32)  128         ['conv2d_3[0][0]']               \n",
            " rmalization)                                                                                     \n",
            "                                                                                                  \n",
            " max_pooling2d_3 (MaxPooling2D)  (None, 30, 30, 32)  0           ['batch_normalization_3[0][0]']  \n",
            "                                                                                                  \n",
            " conv2d_4 (Conv2D)              (None, 28, 28, 32)   9248        ['max_pooling2d_3[0][0]']        \n",
            "                                                                                                  \n",
            " batch_normalization_4 (BatchNo  (None, 28, 28, 32)  128         ['conv2d_4[0][0]']               \n",
            " rmalization)                                                                                     \n",
            "                                                                                                  \n",
            " max_pooling2d_4 (MaxPooling2D)  (None, 14, 14, 32)  0           ['batch_normalization_4[0][0]']  \n",
            "                                                                                                  \n",
            " dropout_1 (Dropout)            (None, 14, 14, 32)   0           ['max_pooling2d_4[0][0]']        \n",
            "                                                                                                  \n",
            " conv2d_5 (Conv2D)              (None, 12, 12, 32)   9248        ['dropout_1[0][0]']              \n",
            "                                                                                                  \n",
            " batch_normalization_5 (BatchNo  (None, 12, 12, 32)  128         ['conv2d_5[0][0]']               \n",
            " rmalization)                                                                                     \n",
            "                                                                                                  \n",
            " max_pooling2d_5 (MaxPooling2D)  (None, 6, 6, 32)    0           ['batch_normalization_5[0][0]']  \n",
            "                                                                                                  \n",
            " dropout_2 (Dropout)            (None, 6, 6, 32)     0           ['max_pooling2d_5[0][0]']        \n",
            "                                                                                                  \n",
            " flatten (Flatten)              (None, 1152)         0           ['dropout_2[0][0]']              \n",
            "                                                                                                  \n",
            " alpha (MILAttentionLayer)      (None, 57)           147520      ['flatten[0][0]']                \n",
            "                                                                                                  \n",
            " tf_op_layer_ExpandDims (Tensor  [(None, 1, 57)]     0           ['alpha[0][0]']                  \n",
            " FlowOpLayer)                                                                                     \n",
            "                                                                                                  \n",
            " tf_op_layer_Reshape_1 (TensorF  [(None, 57, 1152)]  0           ['flatten[0][0]']                \n",
            " lowOpLayer)                                                                                      \n",
            "                                                                                                  \n",
            " tf_op_layer_MatMul (TensorFlow  [(None, 1, 1152)]   0           ['tf_op_layer_ExpandDims[0][0]', \n",
            " OpLayer)                                                         'tf_op_layer_Reshape_1[0][0]']  \n",
            "                                                                                                  \n",
            " tf_op_layer_Squeeze (TensorFlo  [(None, 1152)]      0           ['tf_op_layer_MatMul[0][0]']     \n",
            " wOpLayer)                                                                                        \n",
            "                                                                                                  \n",
            " dropout_3 (Dropout)            (None, 1152)         0           ['tf_op_layer_Squeeze[0][0]']    \n",
            "                                                                                                  \n",
            " dense (Dense)                  (None, 128)          147584      ['dropout_3[0][0]']              \n",
            "                                                                                                  \n",
            " dense_1 (Dense)                (None, 1)            129         ['dense[0][0]']                  \n",
            "                                                                                                  \n",
            "==================================================================================================\n",
            "Total params: 338,785\n",
            "Trainable params: 338,433\n",
            "Non-trainable params: 352\n",
            "__________________________________________________________________________________________________\n",
            "None\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###    Evaluate     ###\n",
        "####################\n",
        "from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score\n",
        "import time\n",
        "\n",
        "# Ground-truth bag labels; bags2 has one row per bag, in the order used to build test_dataset\n",
        "target = bags2.bag_label\n",
        "\n",
        "# NOTE(review): checkpoints are read from \"./\" here but from \"./models/\" in the\n",
        "# slice-prediction cell - confirm which location is the intended one.\n",
        "for i in range(0, 5):\n",
        "\n",
        "    # Perform prediction on the test dataset using the saved checkpoint\n",
        "    checkpoint_path = \"./att_{}.ckpt\".format(i)\n",
        "    model.load_weights(checkpoint_path)\n",
        "\n",
        "    # perf_counter measures elapsed time; process_time only counts the CPU time of this\n",
        "    # process, which is not the time the prediction actually took\n",
        "    start = time.perf_counter()\n",
        "    preds = model.predict(test_dataset)\n",
        "    print('time:', time.perf_counter() - start)\n",
        "\n",
        "    # print('AUC:', roc_auc_score(target[:], preds[:, 0]))\n",
        "    # Threshold the predicted probability of each bag at 0.5\n",
        "    preds_value = (preds[:, 0] > 0.5) * 1\n",
        "    print('Accuracy:', accuracy_score(target, preds_value))\n",
        "    print('Precision:', precision_score(target, preds_value))\n",
        "    print('Recall:', recall_score(target, preds_value))\n",
        "    print('F1 score:', f1_score(target, preds_value))"
      ],
      "metadata": {
        "id": "_W8tkSYT7J9U"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "####################\n",
        "###    Slice Labels Predictions    ###\n",
        "####################\n",
        "# Two-output view of the classifier: attention weights of the \"alpha\" layer plus the bag\n",
        "# prediction. It shares its layers with `model`, so it is built once and automatically\n",
        "# uses whichever checkpoint is loaded into `model` below.\n",
        "weights_layer = model.get_layer(\"alpha\")\n",
        "feature_model = keras.Model(model.inputs, [weights_layer.output, model.output])\n",
        "\n",
        "for fold in range(0, 5):\n",
        "\n",
        "    # Perform prediction on the test dataset using the saved checkpoint\n",
        "    checkpoint_path = \"./models/att_{}.ckpt\".format(fold)\n",
        "    model.load_weights(checkpoint_path)\n",
        "\n",
        "    weights, pred = feature_model.predict(test_dataset)\n",
        "\n",
        "    # One output row per (bag, instance): the bag-level probability is repeated for every\n",
        "    # instance and paired with the attention weight of that instance\n",
        "    rows = []\n",
        "    for bag_pos, ID in enumerate(np.array(bags2.index)):\n",
        "        for inst_pos, idx in enumerate(test_bags_dic[ID]):\n",
        "            rows.append((idx, ID, pred[bag_pos][0], weights[bag_pos, inst_pos]))\n",
        "\n",
        "    df_rest = pd.DataFrame(rows,\n",
        "            columns =['instance_name', 'bag_name', 'bag_cnn_probability', 'cnn_prediction'])\n",
        "\n",
        "    # One file per checkpoint; writing every fold to the same path kept only the last one\n",
        "    df_rest.to_csv('test_visual_{}.csv'.format(fold))\n",
        "\n",
        "# Original output name, kept for existing consumers: as before it holds the last checkpoint\n",
        "df_rest.to_csv('test_visual.csv')\n"
      ],
      "metadata": {
        "id": "8aVbdSyIq2mu"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "bqwNkhzoxxvU"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}