{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "h2430ByQ7FUA",
"outputId": "b92d16bd-df4e-46ed-cd43-e718f6b971c2"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Num GPUs Available: 0\n"
]
}
],
"source": [
"####################\n",
"### LIBRARIES ####\n",
"####################\n",
"\n",
"import numpy as np\n",
"import warnings\n",
"import pandas as pd\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"\n",
"# Silence TensorFlow's C++ logging; this must be set before TensorFlow is imported\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
"\n",
"# Import TensorFlow and Keras for neural network operations\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"from tensorflow.keras.losses import Loss\n",
"\n",
"# Run in TF1-style graph mode. Use the public tf.compat.v1 entry point instead of\n",
"# the private tensorflow.python.framework.ops module, which is not a stable API.\n",
"tf.compat.v1.disable_eager_execution()\n",
"\n",
"# Set the default float type for Keras layers to \"float32\"\n",
"tf.keras.backend.set_floatx(\"float32\")\n",
"\n",
"# Report how many GPUs TensorFlow can see (0 means TensorFlow will use the CPU)\n",
"print(f\"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}\")"
]
},
{
"cell_type": "code",
"source": [
"####################\n",
"### DATA LOADING ###\n",
"####################\n",
"\n",
"print('Starting preprocessing of bags')\n",
"\n",
"# Directory holding the preprocessed test images (one .npy array per .dcm instance)\n",
"test_images_dir = './test/'\n",
"\n",
"# Only .npy arrays can be loaded by the data generator, so ignore any other files\n",
"# (hidden files, stray exports, ...) that happen to live in the directory.\n",
"test_files = [f for f in os.listdir(test_images_dir) if f.endswith('.npy')]\n",
"\n",
"# Bag/instance table; later cells use bag_name, instance_name, instance_label, bag_label\n",
"test_bags = pd.read_csv(\"./tables/testing_example.csv\")\n",
"\n",
"# The CSV names instances by their .dcm file; keep only rows whose .npy exists on disk\n",
"test_files_dcm = {os.path.splitext(k)[0] + '.dcm' for k in test_files}\n",
"test_bags = test_bags[test_bags.instance_name.isin(test_files_dcm)]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mtZQI9Hs7Juq",
"outputId": "d2d52a49-853d-40b9-b9ae-ebe18ba7cf8f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Starting preprocessing of bags\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"##########################\n",
"### BAGS PREPROCESSING ###\n",
"##########################\n",
"\n",
"# Every bag is fed to the network as a fixed-size stack of bag_size instances\n",
"bag_size = 57\n",
"\n",
"# Pad each bag up to bag_size with placeholder rows: a copy of the bag's first row\n",
"# with instance_label=0 and instance_name='all_zeros' (the data generator turns\n",
"# 'all_zeros' into a blank image). Padding is collected in a list and concatenated\n",
"# once: DataFrame.append was removed in pandas 2.0 and grew the frame quadratically.\n",
"padding_frames = []\n",
"for bag_name, bag_rows in test_bags.groupby('bag_name', sort=False):\n",
"    num_add = bag_size - len(bag_rows)\n",
"    if num_add < 0:\n",
"        # Oversized bags cannot be stacked into the fixed (H, W, bag_size) network input\n",
"        raise ValueError(f\"Bag {bag_name!r} has {len(bag_rows)} instances, more than bag_size={bag_size}\")\n",
"    if num_add == 0:\n",
"        continue\n",
"    placeholder = bag_rows.iloc[0].copy()\n",
"    placeholder['instance_label'] = 0\n",
"    placeholder['instance_name'] = 'all_zeros'\n",
"    padding_frames.append(pd.DataFrame([placeholder] * num_add))\n",
"\n",
"if padding_frames:\n",
"    test_bags = pd.concat([test_bags, *padding_frames])\n",
"\n",
"# bag_name -> ordered instance names (real instances first, then padding);\n",
"# groupby(sort=False) keeps bags in first-appearance order\n",
"test_bags_dic = {name: list(rows.instance_name) for name, rows in test_bags.groupby('bag_name', sort=False)}"
],
"metadata": {
"id": "5Hhr38Dx7JxO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"####################\n",
"### DATALOADER ###\n",
"####################\n",
"# Per-bag input shape: (height, width, instances per bag)\n",
"dim=(512,512,bag_size)\n",
"\n",
"class DataGeneratorMIL(keras.utils.Sequence):\n",
"    \"\"\"Keras Sequence yielding batches of MIL bags as stacked instance images.\n",
"\n",
"    Each entry of `list_IDs` is a bag name. A bag becomes an array of shape\n",
"    (dim[0], dim[1], dim[2], n_channels): every instance listed for it in the global\n",
"    `train_bags_dic` / `test_bags_dic` is loaded, resized to dim[:2] and stacked on\n",
"    the third axis. Instances named 'all_zeros' are padding and become blank images.\n",
"\n",
"    Args:\n",
"        list_IDs: bag names (list or 1-D array).\n",
"        labels: per-bag labels aligned with list_IDs, or None for inference.\n",
"        batch_size: bags per batch; the last batch may be smaller.\n",
"        dim: (height, width, instances_per_bag).\n",
"        n_channels: channels per instance image.\n",
"        n_classes: stored but not used by this class.\n",
"        shuffle: reshuffle bag order after each epoch (applied only when training).\n",
"        is_train: yield (X, y) pairs; forced to False when labels is None.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, list_IDs, labels=None, batch_size=256, dim=(512,512,512), n_channels=3,\n",
"                 n_classes=2, shuffle=True, is_train=True):\n",
"        self.dim = dim\n",
"        self.batch_size = batch_size\n",
"        self.labels = labels\n",
"        self.is_train = (labels is not None) and is_train\n",
"        self.list_IDs = list_IDs\n",
"        self.n_channels = n_channels\n",
"        self.n_classes = n_classes\n",
"        self.shuffle = shuffle\n",
"        self.on_epoch_end()\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of batches per epoch, counting a final partial batch.\"\"\"\n",
"        return int(np.ceil(len(self.list_IDs) / self.batch_size))\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return batch `index` as X (inference) or (X, y) (training).\"\"\"\n",
"        # Go through self.indexes so that on_epoch_end's shuffling actually applies\n",
"        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]\n",
"        list_IDs_temp = [self.list_IDs[k] for k in batch_indexes]\n",
"\n",
"        X = self.__data_generation(list_IDs_temp)\n",
"        if self.is_train:\n",
"            y = np.asarray(self.labels)[batch_indexes]\n",
"            return X, np.array(y, dtype='float64')\n",
"        return X\n",
"\n",
"    def on_epoch_end(self):\n",
"        \"\"\"Rebuild the bag order; shuffle only when training so that inference\n",
"        outputs stay aligned with list_IDs.\"\"\"\n",
"        self.indexes = np.arange(len(self.list_IDs))\n",
"        if self.shuffle and self.is_train:\n",
"            np.random.shuffle(self.indexes)\n",
"\n",
"    def _load_instance(self, idx):\n",
"        \"\"\"Load one instance (named '<name>.dcm') from its .npy file, resized to dim[:2].\"\"\"\n",
"        if idx == 'all_zeros':\n",
"            return np.zeros((self.dim[0], self.dim[1], self.n_channels))\n",
"        if self.is_train:\n",
"            path = train_files_loc[idx] + idx[:-4] + '.npy'\n",
"        else:\n",
"            path = test_images_dir + idx[:-4] + '.npy'\n",
"        return cv2.resize(np.load(path), (self.dim[1], self.dim[0]))\n",
"\n",
"    def __data_generation(self, list_IDs_temp):\n",
"        \"\"\"Build X with shape (len(list_IDs_temp), *dim, n_channels).\"\"\"\n",
"        # Size X to the actual batch so a short final batch has no uninitialised rows\n",
"        X = np.empty((len(list_IDs_temp), *self.dim, self.n_channels))\n",
"        bags_dic = train_bags_dic if self.is_train else test_bags_dic\n",
"        for i, ID in enumerate(list_IDs_temp):\n",
"            imgs = [self._load_instance(idx) for idx in bags_dic[ID]]\n",
"            # (instances, H, W, C) -> (H, W, instances, C)\n",
"            X[i,] = np.transpose(imgs, [1,2,0,3])\n",
"        return X"
],
"metadata": {
"id": "pWlaSxnt7Jz3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"########################\n",
"### Test Generator ###\n",
"########################\n",
"batch_size = 1\n",
"\n",
"# One row per bag (column-wise max over the bag's rows); bags2.bag_label is the\n",
"# bag-level target used by the evaluation cells\n",
"bags2 = test_bags.groupby('bag_name').max()\n",
"\n",
"# shuffle=False keeps bags in bags2 order so predictions line up with bags2.bag_label\n",
"test_dataset = DataGeneratorMIL(np.array(bags2.index), bags2.bag_label, batch_size=batch_size,\n",
"                                dim=dim, shuffle=False, is_train=False)"
],
"metadata": {
"id": "w_XGFBbI7J2e"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"####################\n",
"### MODEL ###\n",
"####################\n",
"\n",
"class MILAttentionLayer(layers.Layer):\n",
"    \"\"\"Attention-based Deep MIL layer producing per-bag instance weights.\n",
"\n",
"    Input: flattened instance features, shape (n_bags * dim[2], input_dim).\n",
"    Output: attention weights, shape (n_bags, dim[2]); a softmax makes each row sum to 1.\n",
"    The bag size dim[2] is read from the global `dim` defined in the dataloader cell.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        weight_params_dim,\n",
"        kernel_initializer=\"glorot_uniform\",\n",
"        kernel_regularizer=None,\n",
"        use_gated=False,\n",
"        **kwargs,\n",
"    ):\n",
"        super().__init__(**kwargs)\n",
"\n",
"        self.weight_params_dim = weight_params_dim\n",
"        self.use_gated = use_gated\n",
"\n",
"        self.kernel_initializer = keras.initializers.get(kernel_initializer)\n",
"        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)\n",
"\n",
"        # Give V, W and U separate initializer instances: one unseeded instance\n",
"        # returns identical values on every call, so V and U (same shape) would\n",
"        # otherwise start out equal.\n",
"        def fresh_initializer():\n",
"            init = self.kernel_initializer\n",
"            if not hasattr(init, \"get_config\"):\n",
"                return init\n",
"            return init.__class__.from_config(init.get_config())\n",
"\n",
"        self.v_init = fresh_initializer()\n",
"        self.w_init = fresh_initializer()\n",
"        self.u_init = fresh_initializer()\n",
"\n",
"        self.v_regularizer = self.kernel_regularizer\n",
"        self.w_regularizer = self.kernel_regularizer\n",
"        self.u_regularizer = self.kernel_regularizer\n",
"\n",
"    def build(self, input_shape):\n",
"        \"\"\"Create V (input_dim x d), W (d x 1) and, when gated, U (input_dim x d).\"\"\"\n",
"        input_dim = input_shape[1]\n",
"\n",
"        self.v_weight_params = self.add_weight(\n",
"            shape=(input_dim, self.weight_params_dim),\n",
"            initializer=self.v_init,\n",
"            name=\"v\",\n",
"            regularizer=self.v_regularizer,\n",
"            trainable=True,\n",
"        )\n",
"\n",
"        self.w_weight_params = self.add_weight(\n",
"            shape=(self.weight_params_dim, 1),\n",
"            initializer=self.w_init,\n",
"            name=\"w\",\n",
"            regularizer=self.w_regularizer,\n",
"            trainable=True,\n",
"        )\n",
"\n",
"        if self.use_gated:\n",
"            self.u_weight_params = self.add_weight(\n",
"                shape=(input_dim, self.weight_params_dim),\n",
"                initializer=self.u_init,\n",
"                name=\"u\",\n",
"                regularizer=self.u_regularizer,\n",
"                trainable=True,\n",
"            )\n",
"        else:\n",
"            self.u_weight_params = None\n",
"\n",
"        # Mark the layer as built through the Keras base-class hook\n",
"        super().build(input_shape)\n",
"\n",
"    def call(self, inputs):\n",
"        scores = self.compute_attention_scores(inputs)\n",
"        # Regroup the flat per-instance scores into one row per bag, then normalise\n",
"        scores = tf.reshape(scores, shape=(-1, dim[2]))\n",
"        return tf.math.softmax(scores, axis=1)\n",
"\n",
"    def compute_attention_scores(self, instance):\n",
"        \"\"\"Unnormalised score per instance: (tanh(h V) [* sigmoid(h U) if gated]) W.\"\"\"\n",
"        original_instance = instance\n",
"        instance = tf.math.tanh(tf.tensordot(instance, self.v_weight_params, axes=1))\n",
"\n",
"        if self.use_gated:\n",
"            instance = instance * tf.math.sigmoid(\n",
"                tf.tensordot(original_instance, self.u_weight_params, axes=1)\n",
"            )\n",
"\n",
"        return tf.tensordot(instance, self.w_weight_params, axes=1)\n",
"\n",
"\n",
"# Convolutional layers of the per-instance feature extractor\n",
"Conv1 = layers.Conv2D(16, (5, 5), data_format=\"channels_last\", activation='relu', kernel_initializer='glorot_uniform', padding='same')\n",
"Conv2 = layers.Conv2D(32, (3,3), data_format=\"channels_last\", activation='relu')\n",
"Conv3 = layers.Conv2D(32, (3,3), data_format=\"channels_last\", activation='relu')\n",
"Conv4 = layers.Conv2D(32, (3,3), data_format=\"channels_last\", activation='relu')\n",
"Conv5 = layers.Conv2D(32, (3,3), data_format=\"channels_last\", activation='relu')\n",
"Conv6 = layers.Conv2D(32, (3,3), data_format=\"channels_last\", activation='relu')\n",
"\n",
"def VGG(inp):\n",
"    \"\"\"Per-instance CNN feature extractor.\n",
"\n",
"    Reorders a batch of bags (batch, H, W, bag, 3) into individual instance images\n",
"    (batch * bag, H, W, 3) and returns their flattened convolutional features.\n",
"    \"\"\"\n",
"    inp = tf.reshape(tf.transpose(inp, perm=(0,3,1,2,4)), shape=(-1, dim[0], dim[1], 3))\n",
"    x = Conv1(inp)\n",
"    x = layers.BatchNormalization()(x)\n",
"    x = layers.MaxPool2D((2, 2), data_format=\"channels_last\", strides=(2, 2))(x)\n",
"    x = Conv2(x)\n",
"    x = layers.BatchNormalization()(x)\n",
"    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
"    x = layers.Dropout(0.3)(x)\n",
"\n",
"    x = Conv3(x)\n",
"    x = layers.BatchNormalization()(x)\n",
"    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
"    x = Conv4(x)\n",
"    x = layers.BatchNormalization()(x)\n",
"    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
"\n",
"    x = Conv5(x)\n",
"    x = layers.BatchNormalization()(x)\n",
"    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
"    x = layers.Dropout(0.3)(x)\n",
"\n",
"    x = Conv6(x)\n",
"    x = layers.BatchNormalization()(x)\n",
"    x = layers.MaxPool2D((2, 2), strides=(2, 2), data_format=\"channels_last\")(x)\n",
"    x = layers.Dropout(0.3)(x)\n",
"\n",
"    return layers.Flatten()(x)\n",
"\n",
"def build_model():\n",
"    \"\"\"Bag classifier: attention-weighted sum of instance features -> Dense head -> sigmoid.\"\"\"\n",
"    inp = keras.Input(shape=(*dim, 3))\n",
"    H = VGG(inp)\n",
"    A = MILAttentionLayer(\n",
"        weight_params_dim=64,\n",
"        kernel_regularizer=keras.regularizers.l2(0.01),\n",
"        use_gated=True,\n",
"        name=\"alpha\",\n",
"    )(H)\n",
"    # (batch * bag, F) -> (batch, bag, F) so each bag can be pooled with its weights\n",
"    H = tf.reshape(H, shape=(-1, dim[2], H.shape[1]))\n",
"    A = tf.expand_dims(A, axis=1)                    # (batch, 1, bag)\n",
"    intermediate = tf.linalg.matmul(A, H)            # (batch, 1, F)\n",
"    intermediate = tf.squeeze(intermediate, axis=1)  # (batch, F)\n",
"    intermediate = layers.Dropout(0.25)(intermediate)\n",
"    intermediate = layers.Dense(128)(intermediate)\n",
"    out = layers.Dense(1, activation='sigmoid')(intermediate)\n",
"    return keras.Model(inputs=inp, outputs=out)\n",
"\n",
"model = build_model()\n",
"# summary() prints the table itself and returns None, so it is not wrapped in print()\n",
"model.summary()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "T-m_DhYo7J48",
"outputId": "65b2c285-9aae-45b5-e17f-ecc75d17808b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.10/dist-packages/keras/layers/normalization/batch_normalization.py:581: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n",
"/usr/local/lib/python3.10/dist-packages/keras/initializers/initializers.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n",
" warnings.warn(\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_1 (InputLayer) [(None, 512, 512, 5 0 [] \n",
" 7, 3)] \n",
" \n",
" tf_op_layer_transpose (TensorF [(None, 57, 512, 51 0 ['input_1[0][0]'] \n",
" lowOpLayer) 2, 3)] \n",
" \n",
" tf_op_layer_Reshape (TensorFlo [(None, 512, 512, 3 0 ['tf_op_layer_transpose[0][0]'] \n",
" wOpLayer) )] \n",
" \n",
" conv2d (Conv2D) (None, 512, 512, 16 1216 ['tf_op_layer_Reshape[0][0]'] \n",
" ) \n",
" \n",
" batch_normalization (BatchNorm (None, 512, 512, 16 64 ['conv2d[0][0]'] \n",
" alization) ) \n",
" \n",
" max_pooling2d (MaxPooling2D) (None, 256, 256, 16 0 ['batch_normalization[0][0]'] \n",
" ) \n",
" \n",
" conv2d_1 (Conv2D) (None, 254, 254, 32 4640 ['max_pooling2d[0][0]'] \n",
" ) \n",
" \n",
" batch_normalization_1 (BatchNo (None, 254, 254, 32 128 ['conv2d_1[0][0]'] \n",
" rmalization) ) \n",
" \n",
" max_pooling2d_1 (MaxPooling2D) (None, 127, 127, 32 0 ['batch_normalization_1[0][0]'] \n",
" ) \n",
" \n",
" dropout (Dropout) (None, 127, 127, 32 0 ['max_pooling2d_1[0][0]'] \n",
" ) \n",
" \n",
" conv2d_2 (Conv2D) (None, 125, 125, 32 9248 ['dropout[0][0]'] \n",
" ) \n",
" \n",
" batch_normalization_2 (BatchNo (None, 125, 125, 32 128 ['conv2d_2[0][0]'] \n",
" rmalization) ) \n",
" \n",
" max_pooling2d_2 (MaxPooling2D) (None, 62, 62, 32) 0 ['batch_normalization_2[0][0]'] \n",
" \n",
" conv2d_3 (Conv2D) (None, 60, 60, 32) 9248 ['max_pooling2d_2[0][0]'] \n",
" \n",
" batch_normalization_3 (BatchNo (None, 60, 60, 32) 128 ['conv2d_3[0][0]'] \n",
" rmalization) \n",
" \n",
" max_pooling2d_3 (MaxPooling2D) (None, 30, 30, 32) 0 ['batch_normalization_3[0][0]'] \n",
" \n",
" conv2d_4 (Conv2D) (None, 28, 28, 32) 9248 ['max_pooling2d_3[0][0]'] \n",
" \n",
" batch_normalization_4 (BatchNo (None, 28, 28, 32) 128 ['conv2d_4[0][0]'] \n",
" rmalization) \n",
" \n",
" max_pooling2d_4 (MaxPooling2D) (None, 14, 14, 32) 0 ['batch_normalization_4[0][0]'] \n",
" \n",
" dropout_1 (Dropout) (None, 14, 14, 32) 0 ['max_pooling2d_4[0][0]'] \n",
" \n",
" conv2d_5 (Conv2D) (None, 12, 12, 32) 9248 ['dropout_1[0][0]'] \n",
" \n",
" batch_normalization_5 (BatchNo (None, 12, 12, 32) 128 ['conv2d_5[0][0]'] \n",
" rmalization) \n",
" \n",
" max_pooling2d_5 (MaxPooling2D) (None, 6, 6, 32) 0 ['batch_normalization_5[0][0]'] \n",
" \n",
" dropout_2 (Dropout) (None, 6, 6, 32) 0 ['max_pooling2d_5[0][0]'] \n",
" \n",
" flatten (Flatten) (None, 1152) 0 ['dropout_2[0][0]'] \n",
" \n",
" alpha (MILAttentionLayer) (None, 57) 147520 ['flatten[0][0]'] \n",
" \n",
" tf_op_layer_ExpandDims (Tensor [(None, 1, 57)] 0 ['alpha[0][0]'] \n",
" FlowOpLayer) \n",
" \n",
" tf_op_layer_Reshape_1 (TensorF [(None, 57, 1152)] 0 ['flatten[0][0]'] \n",
" lowOpLayer) \n",
" \n",
" tf_op_layer_MatMul (TensorFlow [(None, 1, 1152)] 0 ['tf_op_layer_ExpandDims[0][0]', \n",
" OpLayer) 'tf_op_layer_Reshape_1[0][0]'] \n",
" \n",
" tf_op_layer_Squeeze (TensorFlo [(None, 1152)] 0 ['tf_op_layer_MatMul[0][0]'] \n",
" wOpLayer) \n",
" \n",
" dropout_3 (Dropout) (None, 1152) 0 ['tf_op_layer_Squeeze[0][0]'] \n",
" \n",
" dense (Dense) (None, 128) 147584 ['dropout_3[0][0]'] \n",
" \n",
" dense_1 (Dense) (None, 1) 129 ['dense[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 338,785\n",
"Trainable params: 338,433\n",
"Non-trainable params: 352\n",
"__________________________________________________________________________________________________\n",
"None\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"####################\n",
"### Evaluate ###\n",
"####################\n",
"from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score\n",
"import time\n",
"\n",
"# Bag-level ground truth, in the same order the unshuffled test generator yields bags\n",
"target = bags2.bag_label\n",
"\n",
"for fold in range(5):\n",
"    # NOTE(review): this cell loads \"./att_{}.ckpt\" while the slice-prediction cell\n",
"    # loads \"./models/att_{}.ckpt\" -- confirm which location holds the checkpoints.\n",
"    checkpoint_path = \"./att_{}.ckpt\".format(fold)\n",
"    model.load_weights(checkpoint_path)\n",
"\n",
"    # Wall-clock inference time; process_time() would report summed CPU time instead\n",
"    start = time.perf_counter()\n",
"    preds = model.predict(test_dataset)\n",
"    print('wall time (s):', time.perf_counter() - start)\n",
"\n",
"    assert len(preds) == len(target), 'predictions are not aligned with bag labels'\n",
"\n",
"    # AUC is undefined when the test set contains only one class\n",
"    if target.nunique() == 2:\n",
"        print('AUC:', roc_auc_score(target, preds[:, 0]))\n",
"    preds_value = (preds[:, 0] > 0.5).astype(int)\n",
"    print('Accuracy:', accuracy_score(target, preds_value))\n",
"    print('Precision:', precision_score(target, preds_value, zero_division=0))\n",
"    print('Recall:', recall_score(target, preds_value, zero_division=0))\n",
"    print('F1 score:', f1_score(target, preds_value, zero_division=0))"
],
"metadata": {
"id": "_W8tkSYT7J9U"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"####################\n",
"### Slice Labels Predictions ###\n",
"####################\n",
"# Expose the attention weights (\"alpha\" layer) alongside the bag prediction. The\n",
"# feature model shares its layers with `model`, so it is built once and sees the\n",
"# weights that load_weights() puts into `model` on each fold.\n",
"weights_layer = model.get_layer(\"alpha\")\n",
"feature_model = keras.Model(model.inputs, [weights_layer.output, model.output])\n",
"\n",
"for fold in range(5):\n",
"    # NOTE(review): this cell loads \"./models/att_{}.ckpt\" while the Evaluate cell\n",
"    # loads \"./att_{}.ckpt\" -- confirm which location holds the checkpoints.\n",
"    checkpoint_path = \"./models/att_{}.ckpt\".format(fold)\n",
"    model.load_weights(checkpoint_path)\n",
"\n",
"    weights, pred = feature_model.predict(test_dataset)\n",
"    assert len(pred) == len(bags2), 'predictions are not aligned with bags2'\n",
"\n",
"    # One row per instance slot (padding 'all_zeros' slots included): the instance,\n",
"    # its bag, the bag probability and the instance's attention weight.\n",
"    rows = [\n",
"        (instance_name, bag_name, pred[bag_pos][0], weights[bag_pos, slot])\n",
"        for bag_pos, bag_name in enumerate(np.array(bags2.index))\n",
"        for slot, instance_name in enumerate(test_bags_dic[bag_name])\n",
"    ]\n",
"    df_rest = pd.DataFrame(rows, columns=['instance_name', 'bag_name', 'bag_cnn_probability', 'cnn_prediction'])\n",
"\n",
"    # One file per fold; a single fixed filename would be overwritten by every fold\n",
"    df_rest.to_csv('test_visual_{}.csv'.format(fold))"
],
"metadata": {
"id": "8aVbdSyIq2mu"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "bqwNkhzoxxvU"
},
"execution_count": null,
"outputs": []
}
]
}