Diff of /SA_MIL_testing.ipynb [000000] .. [05c4fa]


1
{
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
    "colab": {
6
      "provenance": []
7
    },
8
    "kernelspec": {
9
      "name": "python3",
10
      "display_name": "Python 3"
11
    },
12
    "language_info": {
13
      "name": "python"
14
    }
15
  },
16
  "cells": [
17
    {
18
      "cell_type": "code",
19
      "execution_count": null,
20
      "metadata": {
21
        "colab": {
22
          "base_uri": "https://localhost:8080/"
23
        },
24
        "id": "h2430ByQ7FUA",
25
        "outputId": "b92d16bd-df4e-46ed-cd43-e718f6b971c2"
26
      },
27
      "outputs": [
28
        {
29
          "output_type": "stream",
30
          "name": "stdout",
31
          "text": [
32
            "Num GPUs Available:  0\n"
33
          ]
34
        }
35
      ],
36
      "source": [
37
        "####################\n",
38
        "###  LIBRARIES  ####\n",
39
        "####################\n",
40
        "\n",
41
        "import numpy as np\n",
42
        "import warnings\n",
43
        "import pandas as pd\n",
44
        "import os\n",
45
        "import matplotlib.pyplot as plt\n",
46
        "import cv2\n",
47
        "\n",
48
        "# Suppress TensorFlow logging output\n",
49
        "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
50
        "\n",
51
        "# Import TensorFlow and Keras for neural network operations\n",
52
        "import tensorflow as tf\n",
53
        "from tensorflow import keras\n",
54
        "from tensorflow.keras import layers\n",
55
        "from tensorflow.keras.callbacks import EarlyStopping\n",
56
        "from tensorflow.keras.losses import Loss\n",
57
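        "# Disable eager execution so TensorFlow runs in graph mode\n",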
        "from tensorflow.python.framework.ops import disable_eager_execution\n",
58
        "disable_eager_execution()\n",
59
        "\n",
60
        "# Set the default float type for TensorFlow to \"float32\"\n",
61
        "tf.keras.backend.set_floatx(\"float32\")\n",
62
        "\n",
63
        "# Print the number of available GPUs\n",
64
        "print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"
65
      ]
66
    },
67
    {
68
      "cell_type": "code",
69
      "source": [
70
        "####################\n",
71
        "### DATA LOADING ###\n",
72
        "####################\n",
73
        "\n",
74
        "print('Starting preprocessing of bags')\n",
75
        "\n",
76
        "# Define directories for image files during testing\n",
77
        "test_images_dir = './test/'\n",
78
        "\n",
79
        "# Get lists of files in the directories\n",
80
        "test_files = os.listdir(test_images_dir)\n",
81
        "\n",
82
        "# Read bag data from CSV files\n",
83
        "test_bags = pd.read_csv(\"./tables/testing_example.csv\")\n",
84
        "\n",
85
        "# The CSV lists .dcm instance names; keep only instances whose preprocessed file (same stem) exists in the test directory\n",
86
        "test_files_dcm = [k[:-4] + '.dcm' for k in test_files]\n",
87
        "test_bags = test_bags[test_bags.instance_name.isin(test_files_dcm)]"
88
      ],
89
      "metadata": {
90
        "colab": {
91
          "base_uri": "https://localhost:8080/"
92
        },
93
        "id": "mtZQI9Hs7Juq",
94
        "outputId": "d2d52a49-853d-40b9-b9ae-ebe18ba7cf8f"
95
      },
96
      "execution_count": null,
97
      "outputs": [
98
        {
99
          "output_type": "stream",
100
          "name": "stdout",
101
          "text": [
102
            "Starting preprocessing of bags\n"
103
          ]
104
        }
105
      ]
106
    },
107
    {
108
      "cell_type": "code",
109
      "source": [
110
        "##########################\n",
111
        "### BAGS PREPROCESSING ###\n",
112
        "##########################\n",
113
        "\n",
114
        "# Set the desired bag size\n",
115
        "bag_size = 57\n",
116
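        "# Bags with fewer instances are padded below with placeholder rows so every bag reaches this fixed size\n",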
        "\n",
117
        "# Create additional test bags to reach the desired bag size\n",
118
        "added_test_bags = pd.DataFrame()\n",
119
        "for idx in test_bags.bag_name.unique():\n",
120
        "    bags = test_bags[test_bags.bag_name==idx].copy()\n",
121
        "    num_add = bag_size - len(bags.instance_name)\n",
122
        "\n",
123
        "    aux = bags.iloc[0].copy()\n",
124
        "    aux.instance_label = 0\n",
125
        "    aux.instance_name = 'all_zeros'\n",
126
        "    for i in range(num_add):\n",
127
        "        added_test_bags = pd.concat([added_test_bags, aux.to_frame().T], ignore_index=True)  # DataFrame.append is removed in pandas 2.x\n",
128
        "\n",
129
        "test_bags = pd.concat([test_bags, added_test_bags], ignore_index=True)\n",
130
        "\n",
131
        "# Convert bags data to dictionaries for optimization\n",
132
        "test_bags_dic = {k: list(test_bags[test_bags.bag_name==k].instance_name) for k in test_bags.bag_name.unique()}"
133
      ],
134
      "metadata": {
135
        "id": "5Hhr38Dx7JxO"
136
      },
137
      "execution_count": null,
138
      "outputs": []
139
    },
140
    {
141
      "cell_type": "code",
142
      "source": [
143
        "####################\n",
144
        "###  DATALOADER  ###\n",
145
        "####################\n",
146
        "dim = (512, 512, bag_size)  # (height, width, instances per bag); each instance is loaded with 3 channels\n",
147
        "\n",
148
        "class DataGeneratorMIL(keras.utils.Sequence):\n",
149
        "    'Generates data for Keras'\n",
150
        "\n",
151
        "    def __init__(self, list_IDs, labels=None, batch_size=256, dim=(512,512,512), n_channels=3,\n",
152
        "                 n_classes=2, shuffle=True, is_train=True):\n",
153
        "        'Initialization'\n",
154
        "        self.dim = dim\n",
155
        "        self.batch_size = batch_size\n",
156
        "        self.labels = labels\n",
157
        "        self.is_train = (labels is not None) and is_train\n",
158
        "        self.list_IDs = list_IDs\n",
159
        "        self.n_channels = n_channels\n",
160
        "        self.n_classes = n_classes\n",
161
        "        self.shuffle = shuffle\n",
162
        "        self.on_epoch_end()\n",
163
        "\n",
164
        "    def __len__(self):\n",
165
        "        'Denotes the number of batches per epoch'\n",
166
        "        return int(np.floor(len(self.list_IDs) / self.batch_size))\n",
167
        "\n",
168
        "    def __getitem__(self, index):\n",
169
        "        'Generate one batch of data'\n",
170
        "        # Generate indexes of the batch\n",
171
        "        list_IDs_temp = self.list_IDs[index*self.batch_size:(index+1)*self.batch_size]\n",
172
        "\n",
173
        "        X = self.__data_generation(list_IDs_temp)\n",
174
        "        # Generate data\n",
175
        "        if self.is_train:\n",
176
        "            y = self.labels[index*self.batch_size:(index+1)*self.batch_size]\n",
177
        "            return np.array(X), np.array(y, dtype='float64')\n",
178
        "        else:\n",
179
        "            return np.array(X)\n",
180
        "\n",
181
        "    def on_epoch_end(self):\n",
182
        "        'Updates indexes after each epoch'\n",
183
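        "        # Note: __getitem__ slices self.list_IDs directly, so these indexes (and any shuffling) do not change batch order\n",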
        "        self.indexes = np.arange(len(self.list_IDs))\n",
184
        "        if self.shuffle:\n",
185
        "            np.random.shuffle(self.indexes)\n",
186
        "\n",
187
        "    def __data_generation(self, list_IDs_temp):\n",
188
        "        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)\n",
189
        "        # Initialization\n",
190
        "        X = np.empty((self.batch_size, *self.dim, self.n_channels))\n",
191
        "\n",
192
        "        # Generate data\n",
193
        "        for i, ID in enumerate(list_IDs_temp):\n",
194
        "            # Store sample\n",
195
        "            if self.is_train:\n",
196
        "                ids = train_bags_dic[ID]\n",
197
        "            else:\n",
198
        "                ids = test_bags_dic[ID]\n",
199
        "            imgs = []\n",
200
        "            for idx in ids:\n",
201
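        "                # 'all_zeros' marks padding instances added during bag preprocessing; substitute a blank image\n",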
        "                if idx == 'all_zeros':\n",
202
        "                    img = np.zeros((self.dim[0], self.dim[1], self.n_channels))\n",
203
        "                    imgs.append(img)\n",
204
        "                    continue\n",
205
        "                if self.is_train:\n",
206
        "                    _dir = train_files_loc[idx]\n",
207
        "                    img = np.load(_dir + idx[:-4] + '.npy')\n",
208
        "                    img = cv2.resize(img, (self.dim[1], self.dim[0]))\n",
209
        "                    imgs.append(img)\n",
210
        "                else:\n",
211
        "                    img = np.load(test_images_dir + idx[:-4] + '.npy')\n",
212
        "                    img = cv2.resize(img, (self.dim[1], self.dim[0]))\n",
213
        "                    imgs.append(img)\n",
214
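        "            # Stack the bag: list of (H, W, C) slices -> array of shape (H, W, bag_size, C) to match self.dim\n",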
        "            X[i,] = np.transpose(imgs, [1,2,0,3])\n",
215
        "\n",
216
        "        return X"
217
      ],
218
      "metadata": {
219
        "id": "pWlaSxnt7Jz3"
220
      },
221
      "execution_count": null,
222
      "outputs": []
223
    },
224
    {
225
      "cell_type": "code",
226
      "source": [
227
        "########################\n",
228
        "### Test Generator ###\n",
229
        "########################\n",
230
        "batch_size = 1\n",
231
        "\n",
232
        "# Preparing the test dataset\n",
233
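        "# groupby('bag_name').max() collapses the table to one row per bag while keeping the bag-level label\n",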
        "bags2 = test_bags.groupby('bag_name').max()\n",
234
        "test_dataset = DataGeneratorMIL(np.array(bags2.index), bags2.bag_label, batch_size=1, dim=dim, is_train=False)"
235
      ],
236
      "metadata": {
237
        "id": "w_XGFBbI7J2e"
238
      },
239
      "execution_count": null,
240
      "outputs": []
241
    },
242
    {
243
      "cell_type": "code",
244
      "source": [
245
        "####################\n",
246
        "###    MODEL     ###\n",
247
        "####################\n",
248
        "\n",
249
        "# MILAttentionLayer\n",
250
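        "# Gated attention pooling in the style of attention-based deep MIL (Ilse et al., 2018)\n",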
        "class MILAttentionLayer(layers.Layer):\n",
251
        "    \"\"\"Implementation of the attention-based Deep MIL layer.\"\"\"\n",
252
        "\n",
253
        "    def __init__(\n",
254
        "        self,\n",
255
        "        weight_params_dim,\n",
256
        "        kernel_initializer=\"glorot_uniform\",\n",
257
        "        kernel_regularizer=None,\n",
258
        "        use_gated=False,\n",
259
        "        **kwargs,\n",
260
        "    ):\n",
261
        "        super().__init__(**kwargs)\n",
262
        "\n",
263
        "        self.weight_params_dim = weight_params_dim\n",
264
        "        self.use_gated = use_gated\n",
265
        "\n",
266
        "        self.kernel_initializer = keras.initializers.get(kernel_initializer)\n",
267
        "        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)\n",
268
        "\n",
269
        "        self.v_init = self.kernel_initializer\n",
270
        "        self.w_init = self.kernel_initializer\n",
271
        "        self.u_init = self.kernel_initializer\n",
272
        "\n",
273
        "        self.v_regularizer = self.kernel_regularizer\n",
274
        "        self.w_regularizer = self.kernel_regularizer\n",
275
        "        self.u_regularizer = self.kernel_regularizer\n",
276
        "\n",
277
        "    def build(self, input_shape):\n",
278
        "        input_dim = input_shape[1]\n",
279
        "\n",
280
        "        self.v_weight_params = self.add_weight(\n",
281
        "            shape=(input_dim, self.weight_params_dim),\n",
282
        "            initializer=self.v_init,\n",
283
        "            name=\"v\",\n",
284
        "            regularizer=self.v_regularizer,\n",
285
        "            trainable=True,\n",
286
        "        )\n",
287
        "\n",
288
        "        self.w_weight_params = self.add_weight(\n",
289
        "            shape=(self.weight_params_dim, 1),\n",
290
        "            initializer=self.w_init,\n",
291
        "            name=\"w\",\n",
292
        "            regularizer=self.w_regularizer,\n",
293
        "            trainable=True,\n",
294
        "        )\n",
295
        "\n",
296
        "        if self.use_gated:\n",
297
        "            self.u_weight_params = self.add_weight(\n",
298
        "                shape=(input_dim, self.weight_params_dim),\n",
299
        "                initializer=self.u_init,\n",
300
        "                name=\"u\",\n",
301
        "                regularizer=self.u_regularizer,\n",
302
        "                trainable=True,\n",
303
        "            )\n",
304
        "        else:\n",
305
        "            self.u_weight_params = None\n",
306
        "\n",
307
        "        self.input_built = True\n",
308
        "\n",
309
        "    def call(self, inputs):\n",
310
        "        instances = self.compute_attention_scores(inputs)\n",
311
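        "        # Regroup per-instance scores into bags of size dim[2] and normalize within each bag via softmax\n",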
        "        instances = tf.reshape(instances, shape=(-1, dim[2]))\n",
312
        "        alpha = tf.math.softmax(instances, axis=1)\n",
313
        "        return alpha\n",
314
        "\n",
315
        "    def compute_attention_scores(self, instance):\n",
316
        "        original_instance = instance\n",
317
        "        instance = tf.math.tanh(tf.tensordot(instance, self.v_weight_params, axes=1))\n",
318
        "\n",
319
        "        if self.use_gated:\n",
320
        "            instance = instance * tf.math.sigmoid(\n",
321
        "                tf.tensordot(original_instance, self.u_weight_params, axes=1)\n",
322
        "            )\n",
323
        "\n",
324
        "        return tf.tensordot(instance, self.w_weight_params, axes=1)\n",
325
        "\n",
326
        "\n",
327
        "# Model\n",
328
        "num_data = batch_size\n",
329
        "D = bag_size\n",
330
        "\n",
331
        "Conv1 = layers.Conv2D(16, (5, 5), data_format=\"channels_last\", activation='relu', kernel_initializer='glorot_uniform', padding='same')\n",
332
        "Conv2 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
333
        "Conv3 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
334
        "Conv4 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
335
        "Conv5 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
336
        "Conv6 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
337
        "\n",
338
        "def VGG(inp):\n",
339
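        "    # Flatten (batch, H, W, bag_size, 3) to (batch*bag_size, H, W, 3) so every slice passes through the shared CNN\n",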
        "    inp = tf.reshape(tf.transpose(inp, perm=(0,3,1,2,4)), shape=(-1, dim[0], dim[1], 3))\n",
340
        "    x = Conv1(inp)\n",
341
        "    x = layers.BatchNormalization()(x)\n",
342
        "    x = layers.MaxPool2D((2, 2), data_format=\"channels_last\", strides=(2, 2))(x)\n",
343
        "    x = Conv2(x)\n",
344
        "    x = layers.BatchNormalization()(x)\n",
345
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
346
        "    x = layers.Dropout(0.3)(x)\n",
347
        "\n",
348
        "    x = Conv3(x)\n",
349
        "    x = layers.BatchNormalization()(x)\n",
350
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
351
        "    x = Conv4(x)\n",
352
        "    x = layers.BatchNormalization()(x)\n",
353
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
354
        "\n",
355
        "    x = Conv5(x)\n",
356
        "    x = layers.BatchNormalization()(x)\n",
357
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
358
        "    x = layers.Dropout(0.3)(x)\n",
359
        "\n",
360
        "    x = Conv6(x)\n",
361
        "    x = layers.BatchNormalization()(x)\n",
362
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
363
        "    x = layers.Dropout(0.3)(x)\n",
364
        "\n",
365
        "    return layers.Flatten()(x)\n",
366
        "\n",
367
        "def build_model():\n",
368
        "    inp = keras.Input(shape=(*dim, 3))\n",
369
        "    H = VGG(inp)\n",
370
        "    A = MILAttentionLayer(\n",
371
        "        weight_params_dim=64,\n",
372
        "        kernel_regularizer=keras.regularizers.l2(0.01),\n",
373
        "        use_gated=True,\n",
374
        "        name=\"alpha\",\n",
375
        "    )(H)\n",
376
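        "    # Attention pooling over the bag: reshape H to (batch, bag_size, features), weight by A, and sum to a bag embedding\n",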
        "    H = tf.reshape(H, shape=(-1, dim[2], H.shape[1]))\n",
377
        "    A = tf.expand_dims(A, axis=1)\n",
378
        "    intermediate = tf.linalg.matmul(A, H)\n",
379
        "    intermediate = tf.squeeze(intermediate, axis=1)\n",
380
        "    intermediate = layers.Dropout(0.25)(intermediate)\n",
381
        "    intermediate = layers.Dense(128)(intermediate)\n",
382
        "    out = layers.Dense(1, activation='sigmoid')(intermediate)\n",
383
        "    return keras.Model(inputs=inp, outputs=out)\n",
384
        "\n",
385
        "model = build_model()\n",
386
        "print(model.summary())"
387
      ],
388
      "metadata": {
389
        "colab": {
390
          "base_uri": "https://localhost:8080/"
391
        },
392
        "id": "T-m_DhYo7J48",
393
        "outputId": "65b2c285-9aae-45b5-e17f-ecc75d17808b"
394
      },
395
      "execution_count": null,
396
      "outputs": [
397
        {
398
          "output_type": "stream",
399
          "name": "stderr",
400
          "text": [
401
            "WARNING:tensorflow:From /usr/local/lib/python3.10/dist-packages/keras/layers/normalization/batch_normalization.py:581: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
402
            "Instructions for updating:\n",
403
            "Colocations handled automatically by placer.\n",
404
            "/usr/local/lib/python3.10/dist-packages/keras/initializers/initializers.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n",
405
            "  warnings.warn(\n"
406
          ]
407
        },
408
        {
409
          "output_type": "stream",
410
          "name": "stdout",
411
          "text": [
412
            "Model: \"model\"\n",
413
            "__________________________________________________________________________________________________\n",
414
            " Layer (type)                   Output Shape         Param #     Connected to                     \n",
415
            "==================================================================================================\n",
416
            " input_1 (InputLayer)           [(None, 512, 512, 5  0           []                               \n",
417
            "                                7, 3)]                                                            \n",
418
            "                                                                                                  \n",
419
            " tf_op_layer_transpose (TensorF  [(None, 57, 512, 51  0          ['input_1[0][0]']                \n",
420
            " lowOpLayer)                    2, 3)]                                                            \n",
421
            "                                                                                                  \n",
422
            " tf_op_layer_Reshape (TensorFlo  [(None, 512, 512, 3  0          ['tf_op_layer_transpose[0][0]']  \n",
423
            " wOpLayer)                      )]                                                                \n",
424
            "                                                                                                  \n",
425
            " conv2d (Conv2D)                (None, 512, 512, 16  1216        ['tf_op_layer_Reshape[0][0]']    \n",
426
            "                                )                                                                 \n",
427
            "                                                                                                  \n",
428
            " batch_normalization (BatchNorm  (None, 512, 512, 16  64         ['conv2d[0][0]']                 \n",
429
            " alization)                     )                                                                 \n",
430
            "                                                                                                  \n",
431
            " max_pooling2d (MaxPooling2D)   (None, 256, 256, 16  0           ['batch_normalization[0][0]']    \n",
432
            "                                )                                                                 \n",
433
            "                                                                                                  \n",
434
            " conv2d_1 (Conv2D)              (None, 254, 254, 32  4640        ['max_pooling2d[0][0]']          \n",
435
            "                                )                                                                 \n",
436
            "                                                                                                  \n",
437
            " batch_normalization_1 (BatchNo  (None, 254, 254, 32  128        ['conv2d_1[0][0]']               \n",
438
            " rmalization)                   )                                                                 \n",
439
            "                                                                                                  \n",
440
            " max_pooling2d_1 (MaxPooling2D)  (None, 127, 127, 32  0          ['batch_normalization_1[0][0]']  \n",
441
            "                                )                                                                 \n",
442
            "                                                                                                  \n",
443
            " dropout (Dropout)              (None, 127, 127, 32  0           ['max_pooling2d_1[0][0]']        \n",
444
            "                                )                                                                 \n",
445
            "                                                                                                  \n",
446
            " conv2d_2 (Conv2D)              (None, 125, 125, 32  9248        ['dropout[0][0]']                \n",
447
            "                                )                                                                 \n",
448
            "                                                                                                  \n",
449
            " batch_normalization_2 (BatchNo  (None, 125, 125, 32  128        ['conv2d_2[0][0]']               \n",
450
            " rmalization)                   )                                                                 \n",
451
            "                                                                                                  \n",
452
            " max_pooling2d_2 (MaxPooling2D)  (None, 62, 62, 32)  0           ['batch_normalization_2[0][0]']  \n",
453
            "                                                                                                  \n",
454
            " conv2d_3 (Conv2D)              (None, 60, 60, 32)   9248        ['max_pooling2d_2[0][0]']        \n",
455
            "                                                                                                  \n",
456
            " batch_normalization_3 (BatchNo  (None, 60, 60, 32)  128         ['conv2d_3[0][0]']               \n",
457
            " rmalization)                                                                                     \n",
458
            "                                                                                                  \n",
459
            " max_pooling2d_3 (MaxPooling2D)  (None, 30, 30, 32)  0           ['batch_normalization_3[0][0]']  \n",
460
            "                                                                                                  \n",
461
            " conv2d_4 (Conv2D)              (None, 28, 28, 32)   9248        ['max_pooling2d_3[0][0]']        \n",
462
            "                                                                                                  \n",
463
            " batch_normalization_4 (BatchNo  (None, 28, 28, 32)  128         ['conv2d_4[0][0]']               \n",
464
            " rmalization)                                                                                     \n",
465
            "                                                                                                  \n",
466
            " max_pooling2d_4 (MaxPooling2D)  (None, 14, 14, 32)  0           ['batch_normalization_4[0][0]']  \n",
467
            "                                                                                                  \n",
468
            " dropout_1 (Dropout)            (None, 14, 14, 32)   0           ['max_pooling2d_4[0][0]']        \n",
469
            "                                                                                                  \n",
470
            " conv2d_5 (Conv2D)              (None, 12, 12, 32)   9248        ['dropout_1[0][0]']              \n",
471
            "                                                                                                  \n",
472
            " batch_normalization_5 (BatchNo  (None, 12, 12, 32)  128         ['conv2d_5[0][0]']               \n",
473
            " rmalization)                                                                                     \n",
474
            "                                                                                                  \n",
475
            " max_pooling2d_5 (MaxPooling2D)  (None, 6, 6, 32)    0           ['batch_normalization_5[0][0]']  \n",
476
            "                                                                                                  \n",
477
            " dropout_2 (Dropout)            (None, 6, 6, 32)     0           ['max_pooling2d_5[0][0]']        \n",
478
            "                                                                                                  \n",
479
            " flatten (Flatten)              (None, 1152)         0           ['dropout_2[0][0]']              \n",
480
            "                                                                                                  \n",
481
            " alpha (MILAttentionLayer)      (None, 57)           147520      ['flatten[0][0]']                \n",
482
            "                                                                                                  \n",
483
            " tf_op_layer_ExpandDims (Tensor  [(None, 1, 57)]     0           ['alpha[0][0]']                  \n",
484
            " FlowOpLayer)                                                                                     \n",
485
            "                                                                                                  \n",
486
            " tf_op_layer_Reshape_1 (TensorF  [(None, 57, 1152)]  0           ['flatten[0][0]']                \n",
487
            " lowOpLayer)                                                                                      \n",
488
            "                                                                                                  \n",
489
            " tf_op_layer_MatMul (TensorFlow  [(None, 1, 1152)]   0           ['tf_op_layer_ExpandDims[0][0]', \n",
490
            " OpLayer)                                                         'tf_op_layer_Reshape_1[0][0]']  \n",
491
            "                                                                                                  \n",
492
            " tf_op_layer_Squeeze (TensorFlo  [(None, 1152)]      0           ['tf_op_layer_MatMul[0][0]']     \n",
493
            " wOpLayer)                                                                                        \n",
494
            "                                                                                                  \n",
495
            " dropout_3 (Dropout)            (None, 1152)         0           ['tf_op_layer_Squeeze[0][0]']    \n",
496
            "                                                                                                  \n",
497
            " dense (Dense)                  (None, 128)          147584      ['dropout_3[0][0]']              \n",
498
            "                                                                                                  \n",
499
            " dense_1 (Dense)                (None, 1)            129         ['dense[0][0]']                  \n",
500
            "                                                                                                  \n",
501
            "==================================================================================================\n",
502
            "Total params: 338,785\n",
503
            "Trainable params: 338,433\n",
504
            "Non-trainable params: 352\n",
505
            "__________________________________________________________________________________________________\n",
506
            "None\n"
507
          ]
508
        }
509
      ]
510
    },
511
    {
512
      "cell_type": "code",
513
      "source": [
514
        "####################\n",
515
        "###    Evaluate     ###\n",
516
        "####################\n",
517
        "from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score\n",
518
        "import time\n",
519
        "\n",
520
        "for i in range(0, 5):\n",
521
        "\n",
522
        "    # Perform prediction on the test dataset using the saved checkpoint\n",
523
        "    checkpoint_path = \"./att_{}.ckpt\".format(i)\n",
524
        "    model.load_weights(checkpoint_path)\n",
525
        "\n",
526
        "    start = time.process_time()\n",
527
        "    preds = model.predict(test_dataset)\n",
528
        "    print('prediction CPU time (s):', time.process_time() - start)\n",
529
        "\n",
530
        "    target = bags2.bag_label\n",
531
        "\n",
532
        "    # print('AUC:', roc_auc_score(target[:], preds[:, 0]))\n",
533
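        "    # Binarize the sigmoid outputs at a 0.5 threshold before computing the classification metrics\n",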
        "    preds_value = (preds[:, 0] > 0.5) * 1\n",
534
        "    print('Accuracy:', accuracy_score(target, preds_value))\n",
535
        "    print('Precision:', precision_score(target, preds_value))\n",
536
        "    print('Recall:', recall_score(target, preds_value))\n",
537
        "    print('F1 score:', f1_score(target, preds_value))"
538
      ],
539
      "metadata": {
540
        "id": "_W8tkSYT7J9U"
541
      },
542
      "execution_count": null,
543
      "outputs": []
544
    },
545
    {
546
      "cell_type": "code",
547
      "source": [
548
        "####################\n",
549
        "###    Slice Labels Predictions    ###\n",
550
        "####################\n",
551
        "for i in range(0, 5):\n",
552
        "\n",
553
        "    # Perform prediction on the test dataset using the saved checkpoint\n",
554
        "    checkpoint_path = \"./models/att_{}.ckpt\".format(i)\n",
555
        "    model.load_weights(checkpoint_path)\n",
556
        "\n",
557
        "    weights_layer = model.get_layer(\"alpha\")\n",
558
        "    feature_model = keras.Model(model.inputs, [weights_layer.output, model.output])\n",
559
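        "    # feature_model returns both the per-instance attention weights (alpha) and the bag-level prediction\n",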
        "    weights, pred = feature_model.predict(test_dataset)\n",
560
        "\n",
561
        "    instance_id = []\n",
562
        "    bag_id = []\n",
563
        "    pred_bag_id = []\n",
564
        "    pred_instance_id = []\n",
565
        "    for b, ID in enumerate(np.array(bags2.index)):  # use 'b' so the checkpoint index i is not shadowed\n",
566
        "        ids = test_bags_dic[ID]\n",
567
        "        for ii, idx in enumerate(ids):\n",
568
        "            instance_id.append(idx)\n",
569
        "            bag_id.append(ID)\n",
570
        "            pred_bag_id.append(pred[b][0])\n",
571
        "            pred_instance_id.append(weights[b, ii])\n",
572
        "\n",
573
        "    df_rest = pd.DataFrame(list(zip(instance_id, bag_id, pred_bag_id, pred_instance_id)),\n",
574
        "            columns=['instance_name', 'bag_name', 'bag_cnn_probability', 'cnn_prediction'])\n",
575
        "\n",
576
        "    df_rest.to_csv('test_visual_{}.csv'.format(i))  # one CSV per checkpoint so earlier results are not overwritten\n"
577
      ],
578
      "metadata": {
579
        "id": "8aVbdSyIq2mu"
580
      },
581
      "execution_count": 1,
582
      "outputs": []
583
    },
584
    {
585
      "cell_type": "code",
586
      "source": [],
587
      "metadata": {
588
        "id": "bqwNkhzoxxvU"
589
      },
590
      "execution_count": null,
591
      "outputs": []
592
    }
593
  ]
594
}