Diff of /SA_MIL_training.ipynb [000000] .. [05c4fa]


a b/SA_MIL_training.ipynb
1
{
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
    "colab": {
6
      "provenance": []
7
    },
8
    "kernelspec": {
9
      "name": "python3",
10
      "display_name": "Python 3"
11
    },
12
    "language_info": {
13
      "name": "python"
14
    }
15
  },
16
  "cells": [
17
    {
18
      "cell_type": "code",
19
      "execution_count": 1,
20
      "metadata": {
21
        "id": "FauGw-yKqt0k",
22
        "colab": {
23
          "base_uri": "https://localhost:8080/"
24
        },
25
        "outputId": "e03c49ff-a6fa-44a6-9df9-c9f392824821"
26
      },
27
      "outputs": [
28
        {
29
          "output_type": "stream",
30
          "name": "stdout",
31
          "text": [
32
            "Num GPUs Available:  0\n"
33
          ]
34
        }
35
      ],
36
      "source": [
37
        "####################\n",
38
        "###  LIBRARIES  ####\n",
39
        "####################\n",
40
        "\n",
41
        "import numpy as np\n",
42
        "import warnings\n",
43
        "import pandas as pd\n",
44
        "import os\n",
45
        "import matplotlib.pyplot as plt\n",
46
        "import cv2\n",
47
        "import networkx as nx\n",
48
        "\n",
49
        "# Remove TensorFlow warnings\n",
50
        "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
51
        "\n",
52
        "# Import TensorFlow and Keras for neural network operations\n",
53
        "import tensorflow as tf\n",
54
        "from tensorflow import keras\n",
55
        "from tensorflow.keras import layers\n",
56
        "from tensorflow.keras.callbacks import EarlyStopping\n",
57
        "from tensorflow.keras.losses import Loss\n",
58
        "from tensorflow.python.framework.ops import disable_eager_execution\n",
59
        "disable_eager_execution()\n",
60
        "\n",
61
        "# Set the default float type for TensorFlow to \"float32\"\n",
62
        "tf.keras.backend.set_floatx(\"float32\")\n",
63
        "\n",
64
        "# Print the number of available GPUs\n",
65
        "print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"
66
      ]
67
    },
68
    {
69
      "cell_type": "code",
70
      "source": [
71
        "####################\n",
72
        "### DATA LOADING ###\n",
73
        "####################\n",
74
        "\n",
75
        "print('Starting preprocessing of bags')\n",
76
        "\n",
77
        "# Define directories for image files, have training images from three folders.\n",
78
        "train_images_dir1 = './Data/train1/'\n",
79
        "train_images_dir2 = './Data/train2/'\n",
80
        "train_images_dir3 = './Data/train3/'\n",
81
        "\n",
82
        "# Get lists of files in the directories\n",
83
        "train_files1 = set(os.listdir(train_images_dir1))\n",
84
        "train_files2 = set(os.listdir(train_images_dir2))\n",
85
        "train_files3 = set(os.listdir(train_images_dir3))\n",
86
        "\n",
87
        "# Read bag data from CSV files\n",
88
        "train_bags = pd.read_csv(\"./tables/Training_examples.csv\")\n",
89
        "\n",
90
        "# Create a mapping of train files to their respective directories\n",
91
        "dirs_ = [train_images_dir1, train_images_dir2, train_images_dir3]\n",
92
        "train_files_loc = {\n",
93
        "    k: dirs_[\n",
94
        "        (k[:-4]+'.npy' in train_files1) * 1 +\n",
95
        "        (k[:-4]+'.npy' in train_files2) * 2 +\n",
96
        "        (k[:-4]+'.npy' in train_files3) * 3 - 1\n",
97
        "    ]\n",
98
        "    for k in train_bags.instance_name\n",
99
        "}\n",
100
        "\n",
101
        "# Create lists of DCM files for train files in each directory\n",
102
        "train_files1_dcm = [k[:-4] + '.dcm' for k in train_files1]\n",
103
        "train_files2_dcm = [k[:-4] + '.dcm' for k in train_files2]\n",
104
        "train_files3_dcm = [k[:-4] + '.dcm' for k in train_files3]\n",
105
        "\n",
106
        "# Filter train bags based on DCM file existence\n",
107
        "train_bags = train_bags[\n",
108
        "    train_bags.instance_name.isin(train_files1_dcm) |\n",
109
        "    train_bags.instance_name.isin(train_files2_dcm) |\n",
110
        "    train_bags.instance_name.isin(train_files3_dcm)\n",
111
        "]"
112
      ],
113
      "metadata": {
114
        "id": "sOoozeNfqxrz"
115
      },
116
      "execution_count": null,
117
      "outputs": []
118
    },
119
    {
120
      "cell_type": "code",
121
      "source": [
122
        "##########################\n",
123
        "### BAGS PREPROCESSING ###\n",
124
        "##########################\n",
125
        "\n",
126
        "# Set the desired bag size\n",
127
        "bag_size = 57\n",
128
        "\n",
129
        "# Create additional train bags to reach the desired bag size\n",
130
        "added_train_bags = pd.DataFrame()\n",
131
        "for idx in train_bags.bag_name.unique():\n",
132
        "    bags = train_bags[train_bags.bag_name==idx].copy()\n",
133
        "    num_add = bag_size - len(bags.instance_name)\n",
134
        "\n",
135
        "    aux = bags.iloc[0].copy()\n",
136
        "    aux.instance_label = 0\n",
137
        "    aux.instance_name = 'all_zeros'\n",
138
        "    for i in range(num_add):\n",
139
        "        added_train_bags = added_train_bags.append(aux)\n",
140
        "\n",
141
        "train_bags = train_bags.append(added_train_bags)\n",
142
        "\n",
143
        "# Convert bags data to dictionaries for optimization\n",
144
        "train_bags_dic = {k: list(train_bags[train_bags.bag_name==k].instance_name) for k in train_bags.bag_name.unique()}"
145
      ],
146
      "metadata": {
147
        "id": "1amWUWZBqx0x"
148
      },
149
      "execution_count": null,
150
      "outputs": []
151
    },
152
    {
153
      "cell_type": "code",
154
      "source": [
155
        "####################\n",
156
        "###  DATALOADER  ###\n",
157
        "####################\n",
158
        "\n",
159
        "dim=(512,512,bag_size)\n",
160
        "class DataGeneratorMIL(keras.utils.Sequence):\n",
161
        "    'Generates data for Keras'\n",
162
        "\n",
163
        "    def __init__(self, list_IDs, labels=None, batch_size=256, dim=(512,512,512), n_channels=3,\n",
164
        "                 n_classes=2, shuffle=True, is_train=True):\n",
165
        "        'Initialization'\n",
166
        "        self.dim = dim\n",
167
        "        self.batch_size = batch_size\n",
168
        "        self.labels = labels\n",
169
        "        self.is_train = (labels is not None) and is_train\n",
170
        "        self.list_IDs = list_IDs\n",
171
        "        self.n_channels = n_channels\n",
172
        "        self.n_classes = n_classes\n",
173
        "        self.shuffle = shuffle\n",
174
        "        self.on_epoch_end()\n",
175
        "\n",
176
        "    def __len__(self):\n",
177
        "        'Denotes the number of batches per epoch'\n",
178
        "        return int(np.floor(len(self.list_IDs) / self.batch_size))\n",
179
        "\n",
180
        "    def __getitem__(self, index):\n",
181
        "        'Generate one batch of data'\n",
182
        "        # Generate indexes of the batch\n",
183
        "        list_IDs_temp = self.list_IDs[index*self.batch_size:(index+1)*self.batch_size]\n",
184
        "\n",
185
        "        X = self.__data_generation(list_IDs_temp)\n",
186
        "        # Generate data\n",
187
        "        if self.is_train:\n",
188
        "            y = self.labels[index*self.batch_size:(index+1)*self.batch_size]\n",
189
        "            return np.array(X), np.array(y, dtype='float64')\n",
190
        "        else:\n",
191
        "            return np.array(X)\n",
192
        "\n",
193
        "    def on_epoch_end(self):\n",
194
        "        'Updates indexes after each epoch'\n",
195
        "        self.indexes = np.arange(len(self.list_IDs))\n",
196
        "        if self.shuffle == True:\n",
197
        "            np.random.shuffle(self.indexes)\n",
198
        "\n",
199
        "    def __data_generation(self, list_IDs_temp):\n",
200
        "        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)\n",
201
        "        # Initialization\n",
202
        "        X = np.empty((self.batch_size, *self.dim, self.n_channels))\n",
203
        "\n",
204
        "        # Generate data\n",
205
        "        for i, ID in enumerate(list_IDs_temp):\n",
206
        "            # Store sample\n",
207
        "            if self.is_train:\n",
208
        "                ids = train_bags_dic[ID]\n",
209
        "            else:\n",
210
        "                ids = test_bags_dic[ID]\n",
211
        "            imgs = []\n",
212
        "            for idx in ids:\n",
213
        "                if idx == 'all_zeros':\n",
214
        "                    img = np.zeros((self.dim[0], self.dim[1], self.n_channels))\n",
215
        "                    imgs.append(img)\n",
216
        "                    continue\n",
217
        "                if self.is_train:\n",
218
        "                    _dir = train_files_loc[idx]\n",
219
        "                    img = np.load(_dir + idx[:-4] + '.npy')\n",
220
        "                    img = cv2.resize(img, (self.dim[1], self.dim[0]))\n",
221
        "                    imgs.append(img)\n",
222
        "                else:\n",
223
        "                    img = np.load(test_images_dir + idx[:-4] + '.npy')\n",
224
        "                    img = cv2.resize(img, (self.dim[1], self.dim[0]))\n",
225
        "                    imgs.append(img)\n",
226
        "            X[i,] = np.transpose(imgs, [1,2,0,3])\n",
227
        "\n",
228
        "        return X"
229
      ],
230
      "metadata": {
231
        "id": "CkWPfMEjqx1z"
232
      },
233
      "execution_count": null,
234
      "outputs": []
235
    },
236
    {
237
      "cell_type": "code",
238
      "source": [
239
        "########################\n",
240
        "### TRAIN/TEST SPLIT ###\n",
241
        "########################\n",
242
        "\n",
243
        "from sklearn.model_selection import train_test_split\n",
244
        "\n",
245
        "N = len(train_bags.bag_name.unique())\n",
246
        "bags = train_bags.groupby('bag_name').max()\n",
247
        "\n",
248
        "X_train, X_val, y_train, y_val = train_test_split(np.array(bags.index)[:], bags.bag_label[:],\n",
249
        "                                                 test_size=0.20, random_state=0,\n",
250
        "                                                 stratify=bags.bag_label[:])\n",
251
        "\n",
252
        "batch_size = 4\n",
253
        "\n",
254
        "# Creating the train dataset using DataGeneratorMIL\n",
255
        "train_dataset = DataGeneratorMIL(X_train, y_train, batch_size=batch_size, dim=dim)\n",
256
        "\n",
257
        "# Creating the validation dataset using DataGeneratorMIL\n",
258
        "val_dataset = DataGeneratorMIL(X_val, y_val, batch_size=batch_size, dim=dim, is_augment=False)"
259
      ],
260
      "metadata": {
261
        "id": "MUdsNAz-qx4P"
262
      },
263
      "execution_count": null,
264
      "outputs": []
265
    },
266
    {
267
      "cell_type": "code",
268
      "source": [
269
        "########################\n",
270
        "### SA-Loss ###\n",
271
        "########################\n",
272
        "class smoothMIL(tf.keras.losses.Loss):\n",
273
        "    def __init__(self, att_weights, alpha, S_k):\n",
274
        "        super(smoothMIL, self).__init__()\n",
275
        "        self.att_weights = att_weights\n",
276
        "        self.alpha = alpha\n",
277
        "        self.S_k = S_k\n",
278
        "\n",
279
        "    def compute_Laplacian(self, bag_size):\n",
280
        "        G = nx.Graph()\n",
281
        "        for e in range(bag_size - 1):\n",
282
        "            G.add_edge(e + 1, e + 2)\n",
283
        "        degree_matrix = np.diag(list(dict(G.degree()).values())) + np.eye(bag_size)\n",
284
        "        adjacency_matrix = nx.adjacency_matrix(G).toarray()\n",
285
        "        L = degree_matrix - adjacency_matrix\n",
286
        "        return L\n",
287
        "\n",
288
        "    def call(self, y_true, y_pred):\n",
289
        "        bce = tf.keras.losses.BinaryCrossentropy()\n",
290
        "\n",
291
        "        loss1 = bce(y_true, y_pred)\n",
292
        "        L = self.compute_Laplacian(bag_size)\n",
293
        "\n",
294
        "        if self.S_k == 1:\n",
295
        "            VV = tf.linalg.matmul(self.att_weights, L)\n",
296
        "            loss2 = tf.linalg.matmul(VV, tf.transpose(self.att_weights, (0, 2, 1)))\n",
297
        "        elif self.S_k == 2:\n",
298
        "            VV = tf.linalg.matmul(self.att_weights, L)\n",
299
        "            VV = tf.linalg.matmul(VV, L)\n",
300
        "            loss2 = tf.linalg.matmul(VV, tf.transpose(self.att_weights, (0, 2, 1)))\n",
301
        "\n",
302
        "        loss2 = tf.math.reduce_mean(loss2)\n",
303
        "        loss_combined = tf.math.add(self.alpha * loss1, (1 - self.alpha) * loss2)\n",
304
        "\n",
305
        "        return loss_combined"
306
      ],
307
      "metadata": {
308
        "id": "HbzoGw3lrFlJ"
309
      },
310
      "execution_count": null,
311
      "outputs": []
312
    },
313
    {
314
      "cell_type": "code",
315
      "source": [
316
        "####################\n",
317
        "###    MODEL     ###\n",
318
        "####################\n",
319
        "\n",
320
        "# MILAttentionLayer\n",
321
        "class MILAttentionLayer(layers.Layer):\n",
322
        "    \"\"\"Implementation of the attention-based Deep MIL layer.\"\"\"\n",
323
        "\n",
324
        "    def __init__(\n",
325
        "        self,\n",
326
        "        weight_params_dim,\n",
327
        "        kernel_initializer=\"glorot_uniform\",\n",
328
        "        kernel_regularizer=None,\n",
329
        "        use_gated=False,\n",
330
        "        **kwargs,\n",
331
        "    ):\n",
332
        "        super().__init__(**kwargs)\n",
333
        "\n",
334
        "        self.weight_params_dim = weight_params_dim\n",
335
        "        self.use_gated = use_gated\n",
336
        "\n",
337
        "        self.kernel_initializer = keras.initializers.get(kernel_initializer)\n",
338
        "        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)\n",
339
        "\n",
340
        "        self.v_init = self.kernel_initializer\n",
341
        "        self.w_init = self.kernel_initializer\n",
342
        "        self.u_init = self.kernel_initializer\n",
343
        "\n",
344
        "        self.v_regularizer = self.kernel_regularizer\n",
345
        "        self.w_regularizer = self.kernel_regularizer\n",
346
        "        self.u_regularizer = self.kernel_regularizer\n",
347
        "\n",
348
        "    def build(self, input_shape):\n",
349
        "        input_dim = input_shape[1]\n",
350
        "\n",
351
        "        self.v_weight_params = self.add_weight(\n",
352
        "            shape=(input_dim, self.weight_params_dim),\n",
353
        "            initializer=self.v_init,\n",
354
        "            name=\"v\",\n",
355
        "            regularizer=self.v_regularizer,\n",
356
        "            trainable=True,\n",
357
        "        )\n",
358
        "\n",
359
        "        self.w_weight_params = self.add_weight(\n",
360
        "            shape=(self.weight_params_dim, 1),\n",
361
        "            initializer=self.w_init,\n",
362
        "            name=\"w\",\n",
363
        "            regularizer=self.w_regularizer,\n",
364
        "            trainable=True,\n",
365
        "        )\n",
366
        "\n",
367
        "        if self.use_gated:\n",
368
        "            self.u_weight_params = self.add_weight(\n",
369
        "                shape=(input_dim, self.weight_params_dim),\n",
370
        "                initializer=self.u_init,\n",
371
        "                name=\"u\",\n",
372
        "                regularizer=self.u_regularizer,\n",
373
        "                trainable=True,\n",
374
        "            )\n",
375
        "        else:\n",
376
        "            self.u_weight_params = None\n",
377
        "\n",
378
        "        self.input_built = True\n",
379
        "\n",
380
        "    def call(self, inputs):\n",
381
        "        instances = self.compute_attention_scores(inputs)\n",
382
        "        instances = tf.reshape(instances, shape=(-1, dim[2]))\n",
383
        "        alpha = tf.math.softmax(instances, axis=1)\n",
384
        "        return alpha\n",
385
        "\n",
386
        "    def compute_attention_scores(self, instance):\n",
387
        "        original_instance = instance\n",
388
        "        instance = tf.math.tanh(tf.tensordot(instance, self.v_weight_params, axes=1))\n",
389
        "\n",
390
        "        if self.use_gated:\n",
391
        "            instance = instance * tf.math.sigmoid(\n",
392
        "                tf.tensordot(original_instance, self.u_weight_params, axes=1)\n",
393
        "            )\n",
394
        "\n",
395
        "        return tf.tensordot(instance, self.w_weight_params, axes=1)\n",
396
        "\n",
397
        "\n",
398
        "# Model\n",
399
        "num_data = batch_size\n",
400
        "D = bag_size\n",
401
        "\n",
402
        "Conv1 = layers.Conv2D(16, (5, 5), data_format=\"channels_last\", activation='relu', kernel_initializer='glorot_uniform', padding='same')\n",
403
        "Conv2 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
404
        "Conv3 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
405
        "Conv4 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
406
        "Conv5 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
407
        "Conv6 = layers.Conv2D(32, (3,3),  data_format=\"channels_last\", activation='relu')\n",
408
        "\n",
409
        "def VGG(inp):\n",
410
        "    inp = tf.reshape(tf.transpose(inp, perm=(0,3,1,2,4)), shape=(-1, dim[0], dim[1], 3))\n",
411
        "    x = Conv1(inp)\n",
412
        "    x = layers.BatchNormalization()(x)\n",
413
        "    x = layers.MaxPool2D((2, 2), data_format=\"channels_last\", strides=(2, 2))(x)\n",
414
        "    x = Conv2(x)\n",
415
        "    x = layers.BatchNormalization()(x)\n",
416
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
417
        "    x = layers.Dropout(0.3)(x)\n",
418
        "\n",
419
        "    x = Conv3(x)\n",
420
        "    x = layers.BatchNormalization()(x)\n",
421
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
422
        "    x = Conv4(x)\n",
423
        "    x = layers.BatchNormalization()(x)\n",
424
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
425
        "\n",
426
        "    x = Conv5(x)\n",
427
        "    x = layers.BatchNormalization()(x)\n",
428
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
429
        "    x = layers.Dropout(0.3)(x)\n",
430
        "\n",
431
        "    x = Conv6(x)\n",
432
        "    x = layers.BatchNormalization()(x)\n",
433
        "    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
434
        "    x = layers.Dropout(0.3)(x)\n",
435
        "\n",
436
        "    return layers.Flatten()(x)\n",
437
        "\n",
438
        "def build_model():\n",
439
        "    inp = keras.Input(shape=(*dim, 3))\n",
440
        "    H = VGG(inp)\n",
441
        "    A = MILAttentionLayer(\n",
442
        "        weight_params_dim=64,\n",
443
        "        kernel_regularizer=keras.regularizers.l2(0.01),\n",
444
        "        use_gated=True,\n",
445
        "        name=\"alpha\",\n",
446
        "    )(H)\n",
447
        "    H = tf.reshape(H, shape=(-1, dim[2], H.shape[1]))\n",
448
        "    A = tf.expand_dims(A, axis=1)\n",
449
        "    intermediate = tf.linalg.matmul(A, H)\n",
450
        "    intermediate = tf.squeeze(intermediate, axis=1)\n",
451
        "    intermediate = layers.Dropout(0.25)(intermediate)\n",
452
        "    intermediate = layers.Dense(128)(intermediate)\n",
453
        "    out = layers.Dense(1, activation='sigmoid')(intermediate)\n",
454
        "    return A, keras.Model(inputs=inp, outputs=out)\n",
455
        "\n",
456
        "A, model = build_model()\n",
457
        "\n",
458
        "auc = tf.keras.metrics.AUC()\n",
459
        "adam = tf.compat.v1.train.AdamOptimizer(learning_rate=5e-5)\n",
460
        "model.compile(\n",
461
        "    optimizer=adam,\n",
462
        "    loss= smoothMIL(A, 0.5, 1),\n",
463
        "    metrics=[auc, 'accuracy']\n",
464
        ")\n",
465
        "earlyStopping = EarlyStopping(monitor='val_loss', patience=8, verbose=1, mode='min')\n",
466
        "print(model.summary())"
467
      ],
468
      "metadata": {
469
        "id": "TUE7oM38qx6y"
470
      },
471
      "execution_count": null,
472
      "outputs": []
473
    },
474
    {
475
      "cell_type": "code",
476
      "source": [
477
        "####################\n",
478
        "###    Train    ###\n",
479
        "####################\n",
480
        "for i in range(0, 5):\n",
481
        "    checkpoint_path = \"./model/att_{}.ckpt\".format(i)\n",
482
        "    checkpoint_dir = os.path.dirname(checkpoint_path)\n",
483
        "\n",
484
        "    cp_callback = keras.callbacks.ModelCheckpoint(\n",
485
        "        filepath=checkpoint_path,\n",
486
        "        monitor='val_loss',\n",
487
        "        save_best_only=True,\n",
488
        "        save_weights_only=True,\n",
489
        "        verbose=1,\n",
490
        "        mode='min'\n",
491
        "    )\n",
492
        "\n",
493
        "\n",
494
        "    history = model.fit(\n",
495
        "        train_dataset,\n",
496
        "        validation_data=val_dataset,\n",
497
        "        epochs=200,\n",
498
        "        callbacks=[earlyStopping, cp_callback],\n",
499
        "    )\n",
500
        "\n",
501
        "    hist_df = pd.DataFrame(history.history)\n",
502
        "    hist_csv_file = './log.csv'\n",
503
        "\n",
504
        "    with open(hist_csv_file, mode='w') as f:\n",
505
        "        hist_df.to_csv(f)"
506
      ],
507
      "metadata": {
508
        "id": "ffrwrgDRqx9f"
509
      },
510
      "execution_count": null,
511
      "outputs": []
512
    }
513
  ]
514
}