468 lines (468 with data), 19.0 kB
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"## Tutorial: training the Non-SA-MIL model"
],
"metadata": {
"id": "SaF3awTD6gxO"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NT_qyexMFEDk"
},
"outputs": [],
"source": [
"####################\n",
"### LIBRARIES ####\n",
"####################\n",
"\n",
"import numpy as np\n",
"import warnings\n",
"import pandas as pd\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"\n",
"# Remove TensorFlow warnings\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
"\n",
"# Import TensorFlow and Keras for neural network operations\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"from tensorflow.keras.losses import Loss\n",
"from tensorflow.python.framework.ops import disable_eager_execution\n",
"disable_eager_execution()\n",
"\n",
"# Set the default float type for TensorFlow to \"float32\"\n",
"tf.keras.backend.set_floatx(\"float32\")\n",
"\n",
"# Print the number of available GPUs\n",
"print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))"
]
},
{
"cell_type": "code",
"source": [
"####################\n",
"### DATA LOADING ###\n",
"####################\n",
"\n",
"print('Starting preprocessing of bags')\n",
"\n",
"# Training images are spread over three folders\n",
"train_images_dir1 = './Data/train1/'\n",
"train_images_dir2 = './Data/train2/'\n",
"train_images_dir3 = './Data/train3/'\n",
"\n",
"# File names found in each folder (sets make the membership tests below cheap)\n",
"train_files1 = set(os.listdir(train_images_dir1))\n",
"train_files2 = set(os.listdir(train_images_dir2))\n",
"train_files3 = set(os.listdir(train_images_dir3))\n",
"\n",
"# Instance table; the `instance_name` column is matched against the files on disk\n",
"train_bags = pd.read_csv(\"./tables/Training_examples.csv\")\n",
"\n",
"# Map every instance to the folder that really contains its `.npy` file.\n",
"# Summing membership flags into a list index (in1*1 + in2*2 + in3*3 - 1) sent files\n",
"# found in no folder to the last folder (index -1) and produced a wrong folder or an\n",
"# IndexError for files present in more than one. Here the first folder holding the\n",
"# file wins and instances without a file get no entry at all.\n",
"# `k[:-4]` drops a 4-character extension such as '.dcm'; the data generator slices\n",
"# names the same way when it loads the arrays, so the convention is kept.\n",
"dirs_ = [train_images_dir1, train_images_dir2, train_images_dir3]\n",
"train_file_sets = [train_files1, train_files2, train_files3]\n",
"train_files_loc = {}\n",
"for k in train_bags.instance_name:\n",
"    npy_name = k[:-4] + '.npy'\n",
"    for folder, folder_files in zip(dirs_, train_file_sets):\n",
"        if npy_name in folder_files:\n",
"            train_files_loc[k] = folder\n",
"            break\n",
"\n",
"# DCM-style names derived from the files present in each folder\n",
"train_files1_dcm = [k[:-4] + '.dcm' for k in train_files1]\n",
"train_files2_dcm = [k[:-4] + '.dcm' for k in train_files2]\n",
"train_files3_dcm = [k[:-4] + '.dcm' for k in train_files3]\n",
"\n",
"# Keep only instances whose DCM name matches a file on disk AND whose `.npy` array\n",
"# was located above, so every remaining row can be loaded.\n",
"has_dcm_name = (\n",
"    train_bags.instance_name.isin(train_files1_dcm) |\n",
"    train_bags.instance_name.isin(train_files2_dcm) |\n",
"    train_bags.instance_name.isin(train_files3_dcm)\n",
")\n",
"has_array = train_bags.instance_name.isin(list(train_files_loc))\n",
"train_bags = train_bags[\n",
"    has_dcm_name & has_array\n",
"]"
],
"metadata": {
"id": "IKYKDKPeFwQb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"##########################\n",
"### BAGS PREPROCESSING ###\n",
"##########################\n",
"\n",
"# Set the desired bag size\n",
"bag_size = 57\n",
"\n",
"# Pad every bag with 'all_zeros' placeholder instances until it holds `bag_size` rows.\n",
"# The padding rows are collected first and concatenated once: `DataFrame.append` no\n",
"# longer exists in pandas 2.x, and appending row by row copies the frame every time.\n",
"padding_frames = []\n",
"for idx in train_bags.bag_name.unique():\n",
"    bags = train_bags[train_bags.bag_name == idx]\n",
"    num_add = bag_size - len(bags.instance_name)\n",
"\n",
"    if num_add < 0:\n",
"        # The data generator allocates exactly `bag_size` slots per bag, so an oversized\n",
"        # bag would only fail later, during training, with a shape error.\n",
"        raise ValueError(\n",
"            'Bag {!r} has {} instances, more than bag_size={}'.format(idx, len(bags), bag_size)\n",
"        )\n",
"    if num_add == 0:\n",
"        continue\n",
"\n",
"    # Template row: the bag's first instance, relabelled as an empty placeholder so the\n",
"    # bag-level columns (bag_name, ...) carry over unchanged.\n",
"    aux = bags.iloc[[0]].copy()\n",
"    aux['instance_label'] = 0\n",
"    aux['instance_name'] = 'all_zeros'\n",
"    padding_frames.append(pd.concat([aux] * num_add))\n",
"\n",
"added_train_bags = pd.concat(padding_frames) if padding_frames else pd.DataFrame()\n",
"train_bags = pd.concat([train_bags, *padding_frames])\n",
"\n",
"# bag_name -> list of instance names, built in a single grouped pass\n",
"train_bags_dic = {\n",
"    name: list(rows.instance_name)\n",
"    for name, rows in train_bags.groupby('bag_name', sort=False)\n",
"}"
],
"metadata": {
"id": "iOAtdE9nG7zt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"####################\n",
"### DATALOADER ###\n",
"####################\n",
"\n",
"dim = (512, 512, bag_size)\n",
"\n",
"\n",
"class DataGeneratorMIL(keras.utils.Sequence):\n",
"    \"\"\"Keras ``Sequence`` that serves batches of fixed-size MIL bags.\n",
"\n",
"    Every item is an array of shape ``(batch_size, *dim, n_channels)``. When labels\n",
"    were given and ``is_train`` is true, the item is the pair ``(X, y)`` with ``y``\n",
"    as float64.\n",
"\n",
"    Bags are resolved through the module-level ``train_bags_dic`` and\n",
"    ``train_files_loc`` in training mode, and through ``test_bags_dic`` and\n",
"    ``test_images_dir`` otherwise.\n",
"    NOTE(review): the test-side names are not defined in the cells shown here --\n",
"    confirm they exist before using ``is_train=False`` or ``labels=None``.\n",
"\n",
"    Args:\n",
"        list_IDs: bag identifiers, indexable by position.\n",
"        labels: one label per entry of ``list_IDs`` (same order), or None.\n",
"        batch_size: number of bags per batch; a trailing partial batch is dropped.\n",
"        dim: ``(height, width, instances_per_bag)``.\n",
"        n_channels: channels per instance image.\n",
"        n_classes: stored on the instance; not used by the generator itself.\n",
"        shuffle: reshuffle the bag order at construction and after every epoch.\n",
"        is_train: together with ``labels`` selects training mode.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, list_IDs, labels=None, batch_size=256, dim=(512,512,512), n_channels=3,\n",
"                 n_classes=2, shuffle=True, is_train=True):\n",
"        \"\"\"Store the configuration and draw the first epoch order.\"\"\"\n",
"        self.dim = dim\n",
"        self.batch_size = batch_size\n",
"        self.labels = labels\n",
"        self.is_train = (labels is not None) and is_train\n",
"        self.list_IDs = list_IDs\n",
"        self.n_channels = n_channels\n",
"        self.n_classes = n_classes\n",
"        self.shuffle = shuffle\n",
"        self.on_epoch_end()\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of full batches per epoch.\"\"\"\n",
"        return len(self.list_IDs) // self.batch_size\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return batch number ``index`` of the current epoch order.\"\"\"\n",
"        # Go through `self.indexes` so the order drawn in `on_epoch_end` is honoured;\n",
"        # slicing `list_IDs` directly made `shuffle=True` a no-op.\n",
"        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]\n",
"        list_IDs_temp = [self.list_IDs[k] for k in batch_indexes]\n",
"\n",
"        X = self.__data_generation(list_IDs_temp)\n",
"        if self.is_train:\n",
"            # np.asarray gives positional access whatever container the labels came in\n",
"            y = np.asarray(self.labels)[batch_indexes]\n",
"            return X, np.array(y, dtype='float64')\n",
"        return X\n",
"\n",
"    def on_epoch_end(self):\n",
"        \"\"\"Reset the bag order, reshuffling it when ``shuffle`` is set.\"\"\"\n",
"        self.indexes = np.arange(len(self.list_IDs))\n",
"        if self.shuffle:\n",
"            np.random.shuffle(self.indexes)\n",
"\n",
"    def __data_generation(self, list_IDs_temp):\n",
"        \"\"\"Build the ``(n_bags, *dim, n_channels)`` array for the given bag ids.\"\"\"\n",
"        # float32 matches the Keras floatx set at the top of the notebook and halves\n",
"        # the size of this (large) buffer compared with NumPy's float64 default.\n",
"        X = np.empty((len(list_IDs_temp), *self.dim, self.n_channels), dtype='float32')\n",
"        bags_dic = train_bags_dic if self.is_train else test_bags_dic\n",
"\n",
"        for i, ID in enumerate(list_IDs_temp):\n",
"            imgs = [self._load_instance(idx) for idx in bags_dic[ID]]\n",
"            # (instances, H, W, C) -> (H, W, instances, C)\n",
"            # assumes every stored array has `n_channels` channels -- TODO confirm\n",
"            X[i,] = np.transpose(imgs, [1,2,0,3])\n",
"\n",
"        return X\n",
"\n",
"    def _load_instance(self, idx):\n",
"        \"\"\"Return one instance image, resized to ``dim[:2]``, or zeros for padding.\"\"\"\n",
"        if idx == 'all_zeros':\n",
"            return np.zeros((self.dim[0], self.dim[1], self.n_channels))\n",
"        _dir = train_files_loc[idx] if self.is_train else test_images_dir\n",
"        img = np.load(_dir + idx[:-4] + '.npy')\n",
"        # cv2.resize takes the target size as (width, height)\n",
"        return cv2.resize(img, (self.dim[1], self.dim[0]))"
],
"metadata": {
"id": "PgPerL5LHItj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"########################\n",
"### TRAIN/TEST SPLIT ###\n",
"########################\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"N = len(train_bags.bag_name.unique())\n",
"# One row per bag; the bag label is the maximum over the bag's rows\n",
"bags = train_bags.groupby('bag_name').max()\n",
"\n",
"# Stratified 80/20 split at bag level, with a fixed seed for reproducibility\n",
"X_train, X_val, y_train, y_val = train_test_split(\n",
"    np.array(bags.index), bags.bag_label,\n",
"    test_size=0.20, random_state=0,\n",
"    stratify=bags.bag_label,\n",
")\n",
"\n",
"batch_size = 4\n",
"\n",
"# Creating the train dataset using DataGeneratorMIL\n",
"train_dataset = DataGeneratorMIL(X_train, y_train, batch_size=batch_size, dim=dim)\n",
"\n",
"# Creating the validation dataset using DataGeneratorMIL.\n",
"# DataGeneratorMIL has no `is_augment` argument (passing it raises TypeError);\n",
"# validation only needs a fixed order, hence `shuffle=False`.\n",
"val_dataset = DataGeneratorMIL(X_val, y_val, batch_size=batch_size, dim=dim, shuffle=False)\n",
"\n",
"# Calculating class weights for imbalanced classes: each class is weighted by the\n",
"# share of training bags that do NOT belong to it, so the rarer class weighs more.\n",
"classes, counts = np.unique(train_dataset.labels, return_counts=True)\n",
"class_weights = {int(k): float(1 - v / counts.sum()) for k, v in zip(classes, counts)}"
],
"metadata": {
"id": "tLyJt9LbHsdz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"####################\n",
"### MODEL ###\n",
"####################\n",
"\n",
"# MILAttentionLayer\n",
"class MILAttentionLayer(layers.Layer):\n",
"    \"\"\"Attention-based deep MIL layer.\n",
"\n",
"    Takes one feature vector per instance, stacked along the first axis as\n",
"    ``(n_bags * bag_size, feature_dim)``, and returns softmax-normalised attention\n",
"    weights of shape ``(n_bags, bag_size)``.\n",
"\n",
"    Args:\n",
"        weight_params_dim: width of the hidden attention projection.\n",
"        kernel_initializer: initializer used for all attention weights.\n",
"        kernel_regularizer: regularizer used for all attention weights.\n",
"        use_gated: add the sigmoid gate branch (extra weight ``u``).\n",
"        bag_size: number of instances per bag. When None, the module-level\n",
"            ``dim[2]`` is used, which was previously the only option.\n",
"        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        weight_params_dim,\n",
"        kernel_initializer=\"glorot_uniform\",\n",
"        kernel_regularizer=None,\n",
"        use_gated=False,\n",
"        bag_size=None,\n",
"        **kwargs,\n",
"    ):\n",
"        super().__init__(**kwargs)\n",
"\n",
"        self.weight_params_dim = weight_params_dim\n",
"        self.use_gated = use_gated\n",
"        self.bag_size = bag_size\n",
"\n",
"        self.kernel_initializer = keras.initializers.get(kernel_initializer)\n",
"        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)\n",
"\n",
"        # v, w and u all share the same initializer and regularizer\n",
"        self.v_init = self.kernel_initializer\n",
"        self.w_init = self.kernel_initializer\n",
"        self.u_init = self.kernel_initializer\n",
"\n",
"        self.v_regularizer = self.kernel_regularizer\n",
"        self.w_regularizer = self.kernel_regularizer\n",
"        self.u_regularizer = self.kernel_regularizer\n",
"\n",
"    def build(self, input_shape):\n",
"        \"\"\"Create ``v`` (input -> hidden), ``w`` (hidden -> score) and optionally ``u``.\"\"\"\n",
"        input_dim = input_shape[1]\n",
"\n",
"        self.v_weight_params = self.add_weight(\n",
"            shape=(input_dim, self.weight_params_dim),\n",
"            initializer=self.v_init,\n",
"            name=\"v\",\n",
"            regularizer=self.v_regularizer,\n",
"            trainable=True,\n",
"        )\n",
"\n",
"        self.w_weight_params = self.add_weight(\n",
"            shape=(self.weight_params_dim, 1),\n",
"            initializer=self.w_init,\n",
"            name=\"w\",\n",
"            regularizer=self.w_regularizer,\n",
"            trainable=True,\n",
"        )\n",
"\n",
"        if self.use_gated:\n",
"            self.u_weight_params = self.add_weight(\n",
"                shape=(input_dim, self.weight_params_dim),\n",
"                initializer=self.u_init,\n",
"                name=\"u\",\n",
"                regularizer=self.u_regularizer,\n",
"                trainable=True,\n",
"            )\n",
"        else:\n",
"            self.u_weight_params = None\n",
"\n",
"        self.input_built = True\n",
"\n",
"    def call(self, inputs):\n",
"        \"\"\"Return attention weights, normalised over the instances of each bag.\"\"\"\n",
"        n_instances = self.bag_size if self.bag_size is not None else dim[2]\n",
"        scores = self.compute_attention_scores(inputs)\n",
"        # One row of scores per bag, so the softmax runs within each bag\n",
"        scores = tf.reshape(scores, shape=(-1, n_instances))\n",
"        return tf.math.softmax(scores, axis=1)\n",
"\n",
"    def compute_attention_scores(self, instance):\n",
"        \"\"\"Return one unnormalised score per instance: ``tanh(x v) [* sigmoid(x u)] w``.\"\"\"\n",
"        hidden = tf.math.tanh(tf.tensordot(instance, self.v_weight_params, axes=1))\n",
"\n",
"        if self.use_gated:\n",
"            gate = tf.math.sigmoid(tf.tensordot(instance, self.u_weight_params, axes=1))\n",
"            hidden = hidden * gate\n",
"\n",
"        return tf.tensordot(hidden, self.w_weight_params, axes=1)\n",
"\n",
"\n",
"# Model\n",
"num_data, D = batch_size, bag_size\n",
"\n",
"# Convolution layers are created once at module level and applied inside VGG().\n",
"# Conv1: 16 filters, 5x5, 'same' padding.\n",
"Conv1 = layers.Conv2D(\n",
"    16, (5, 5),\n",
"    data_format=\"channels_last\",\n",
"    activation='relu',\n",
"    kernel_initializer='glorot_uniform',\n",
"    padding='same',\n",
")\n",
"# Conv2..Conv6: five separate 3x3 layers with 32 filters each.\n",
"Conv2, Conv3, Conv4, Conv5, Conv6 = (\n",
"    layers.Conv2D(32, (3,3), data_format=\"channels_last\", activation='relu')\n",
"    for _ in range(5)\n",
")\n",
"\n",
"def VGG(inp):\n",
"    \"\"\"Run the shared conv stack on every instance and return flat feature vectors.\n",
"\n",
"    Axis 3 of ``inp`` is moved next to the batch axis and folded into it, so the\n",
"    network sees a stack of ``(dim[0], dim[1], 3)`` images. Each stage is\n",
"    conv -> batch norm -> 2x2 max-pool; stages 2, 5 and 6 end with Dropout(0.3).\n",
"    \"\"\"\n",
"    per_instance = tf.transpose(inp, perm=(0,3,1,2,4))\n",
"    x = tf.reshape(per_instance, shape=(-1, dim[0], dim[1], 3))\n",
"\n",
"    # (conv layer, apply dropout after pooling?)\n",
"    stages = (\n",
"        (Conv1, False),\n",
"        (Conv2, True),\n",
"        (Conv3, False),\n",
"        (Conv4, False),\n",
"        (Conv5, True),\n",
"        (Conv6, True),\n",
"    )\n",
"    for conv, with_dropout in stages:\n",
"        x = conv(x)\n",
"        x = layers.BatchNormalization()(x)\n",
"        x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
"        if with_dropout:\n",
"            x = layers.Dropout(0.3)(x)\n",
"\n",
"    return layers.Flatten()(x)\n",
"\n",
"def build_model():\n",
"    \"\"\"Assemble the attention-MIL classifier.\n",
"\n",
"    Instance features from VGG() are weighted by a gated attention layer, summed\n",
"    per bag, and passed through Dropout -> Dense(128) -> Dense(1, sigmoid).\n",
"    \"\"\"\n",
"    bag_input = keras.Input(shape=(*dim, 3))\n",
"    features = VGG(bag_input)\n",
"\n",
"    attention_layer = MILAttentionLayer(\n",
"        weight_params_dim=64,\n",
"        kernel_regularizer=keras.regularizers.l2(0.01),\n",
"        use_gated=True,\n",
"        name=\"alpha\",\n",
"    )\n",
"    attention = attention_layer(features)\n",
"\n",
"    # Regroup the flat instance features by bag: (-1, dim[2], feature_dim)\n",
"    feature_dim = features.shape[1]\n",
"    bag_features = tf.reshape(features, shape=(-1, dim[2], feature_dim))\n",
"\n",
"    # Attention-weighted sum over the instances of each bag, done as a matmul\n",
"    attention_row = tf.expand_dims(attention, axis=1)\n",
"    pooled = tf.linalg.matmul(attention_row, bag_features)\n",
"    pooled = tf.squeeze(pooled, axis=1)\n",
"\n",
"    pooled = layers.Dropout(0.25)(pooled)\n",
"    hidden = layers.Dense(128)(pooled)\n",
"    prediction = layers.Dense(1, activation='sigmoid')(hidden)\n",
"    return keras.Model(inputs=bag_input, outputs=prediction)\n",
"\n",
"model = build_model()\n",
"\n",
"# Binary bag-level classification, tracked with AUC and accuracy\n",
"auc = tf.keras.metrics.AUC()\n",
"adam = tf.compat.v1.train.AdamOptimizer(learning_rate=5e-5)\n",
"model.compile(\n",
"    optimizer=adam,\n",
"    loss='binary_crossentropy',\n",
"    metrics=[auc, 'accuracy']\n",
")\n",
"# Stop a run once the validation loss has not improved for 8 epochs\n",
"earlyStopping = EarlyStopping(monitor='val_loss', patience=8, verbose=1, mode='min')\n",
"# summary() prints the table itself and returns None, so wrapping it in print()\n",
"# only added a stray 'None' line to the output\n",
"model.summary()"
],
"metadata": {
"id": "s_NklWE8HypS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"####################\n",
"### Train ###\n",
"####################\n",
"hist_csv_file = './log.csv'\n",
"run_histories = []\n",
"\n",
"for i in range(0, 5):\n",
"    checkpoint_path = \"./model/att_{}.ckpt\".format(i)\n",
"    checkpoint_dir = os.path.dirname(checkpoint_path)\n",
"    # Make sure the checkpoint folder exists before the first save\n",
"    os.makedirs(checkpoint_dir, exist_ok=True)\n",
"\n",
"    # Keep only the weights of the epoch with the lowest validation loss of this run\n",
"    cp_callback = keras.callbacks.ModelCheckpoint(\n",
"        filepath=checkpoint_path,\n",
"        monitor='val_loss',\n",
"        save_best_only=True,\n",
"        save_weights_only=True,\n",
"        verbose=1,\n",
"        mode='min'\n",
"    )\n",
"\n",
"    # NOTE(review): `model` is not rebuilt between iterations, so run i starts from\n",
"    # the weights left by run i-1 -- confirm this is intended rather than five\n",
"    # independent trainings.\n",
"    history = model.fit(\n",
"        train_dataset,\n",
"        validation_data=val_dataset,\n",
"        epochs=200,\n",
"        callbacks=[earlyStopping, cp_callback],\n",
"        class_weight=class_weights\n",
"    )\n",
"\n",
"    hist_df = pd.DataFrame(history.history)\n",
"\n",
"    # Rewriting ./log.csv with only the current run discarded every earlier history;\n",
"    # all runs are kept instead, told apart by the `run` column. The file is\n",
"    # refreshed after each run so an interrupted session still leaves a log.\n",
"    run_histories.append(hist_df.assign(run=i))\n",
"    pd.concat(run_histories).to_csv(hist_csv_file)"
],
"metadata": {
"id": "_4jZX1ChItFa"
},
"execution_count": null,
"outputs": []
}
]
}