{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training a model is very simple; follow this example to train your own model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#First import the training tool and the torchio library\n",
    "import sys\n",
    "sys.path.append('../Radiology_and_AI')\n",
    "from training.run_training import run_training\n",
    "import torchio as tio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Next define what transforms you want applied to the training data\n",
    "#Both the training and validation data must have the same normalization and data preparation steps\n",
    "#Only the training samples should have the augmentations applied\n",
    "#Any transforms found at https://torchio.readthedocs.io/transforms/transforms.html can be applied\n",
    "#Keep track of the normalization and data preparation steps performed; you will need to apply them to all data passed into the model in the future\n",
    "\n",
    "#These transforms are applied to data before it is used for training the model\n",
    "training_transform = tio.Compose([\n",
    "    #Normalization\n",
    "    tio.ZNormalization(masking_method=tio.ZNormalization.mean), \n",
    "    \n",
    "    #Augmentation\n",
    "    #Play around with different augmentations as you desire; refer to the torchio docs to see how they work\n",
    "    tio.RandomNoise(p=0.5),\n",
    "    tio.RandomGamma(log_gamma=(-0.3, 0.3)),\n",
    "    tio.RandomElasticDeformation(),\n",
    "    \n",
    "    #Preparation\n",
    "    tio.CropOrPad((240, 240, 160)), #Crop/pad the images to a dimension your model can handle; our default UNet3D model requires the dimensions to be multiples of 8 (240 = 8 x 30, 160 = 8 x 20)\n",
    "    tio.OneHot(num_classes=5), #Set num_classes to the max segmentation label + 1\n",
    "    \n",
    "])\n",
    "\n",
    "#These transforms are applied to data before it is used to determine the performance of the model on the validation set\n",
    "validation_transform = tio.Compose([\n",
    "    #Normalization\n",
    "    tio.ZNormalization(masking_method=tio.ZNormalization.mean),\n",
    "    \n",
    "    #Preparation\n",
    "    tio.CropOrPad((240, 240, 160)),        \n",
    "    tio.OneHot(num_classes=5)    \n",
    "    \n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#The run_training method applies the transforms you set and trains a model based on the parameters set here\n",
    "run_training(\n",
    "    #input_data_path must be set to the path to the folder containing the subfolders for each training example.\n",
    "    #Each subfolder should contain one nii.gz file for each of the imaging series and the segmentation for that example\n",
    "    #The name of each nii.gz file should be the name of the parent folder followed by the name of the imaging series type, or seg if it is the segmentation\n",
    "    #For example, MICCAI_BraTS2020_TrainingData contains ~300 folders, each corresponding to an input example;\n",
    "    # one folder, BraTS20_Training_001, contains five files: BraTS20_Training_001_flair.nii.gz, BraTS20_Training_001_seg.nii.gz, BraTS20_Training_001_t1.nii.gz, BraTS20_Training_001_t2.nii.gz, and BraTS20_Training_001_t1ce.nii.gz\n",
    "    input_data_path = '../../brats_new/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData',\n",
    "    \n",
    "    #Where you want your trained model to be saved after training is completed\n",
    "    output_model_path = '../Models/test_train_many_1e-3.pt',\n",
    "    \n",
    "    #The transforms you created previously\n",
    "    training_transform = training_transform,    \n",
    "    validation_transform = validation_transform,\n",
    "    \n",
    "    #The names of the modalities every example in your input data has\n",
    "    input_channels_list = ['flair','t1','t2','t1ce'],\n",
    "    \n",
    "    #Which of the labels in your segmentation you want to train your model to predict\n",
    "    seg_channels = [1,2,4],\n",
    "    \n",
    "    #The name of the type of model you want to train; currently UNet3D is the only available model\n",
    "    model_type = 'UNet3D',\n",
    "    \n",
    "    #The number of examples per training batch; reduce/increase this based on memory availability\n",
    "    batch_size = 1,\n",
    "    \n",
    "    #The number of CPUs you want to be available for loading the input data into the model\n",
    "    num_loading_cpus = 1,\n",
    "    \n",
    "    #The learning rate of the AdamW optimizer\n",
    "    learning_rate = 1e-3,\n",
    "    \n",
    "    #Whether or not you want to run wandb logging of your run; install wandb to use these parameters\n",
    "    wandb_logging = False,\n",
    "    wandb_project_name = None,\n",
    "    wandb_run_name = None,\n",
    "    \n",
    "    #The seed determines how your training and validation data will be randomly split\n",
    "    #training_split_ratio is the share of your input data you want to use for training the model; the remainder is used for the validation data\n",
    "    #Keep track of both the seed and ratio used if you want to be able to split your input data the same way in the future\n",
    "    seed=42,    \n",
    "    training_split_ratio = 0.9,\n",
    "    \n",
    "    #Any parameters which can be applied to a PyTorch Lightning trainer can also be applied; below is a selection of parameters you can apply\n",
    "    #Refer to https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-class-api to see the other parameters you could apply\n",
    "    max_epochs=10,\n",
    "    amp_backend = 'apex',\n",
    "    amp_level = 'O1',\n",
    "    precision=16,\n",
    "    check_val_every_n_epoch = 1,\n",
    "    log_every_n_steps=10,      \n",
    "    val_check_interval= 50,\n",
    "    progress_bar_refresh_rate=1,      \n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### If you want to evaluate your model on a certain test dataset in the future, follow the example below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#First import the evaluation tool and the torchio library\n",
    "import sys\n",
    "sys.path.append('../Radiology_and_AI')\n",
    "from training.run_training import run_eval\n",
    "import torchio as tio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Whatever normalization and data preparation steps you performed during training must also be applied here\n",
    "#Refer to the above for more info\n",
    "#These transforms are applied to data before it is used to determine the performance of the model on the test set\n",
    "test_transform = tio.Compose([\n",
    "    #Normalization\n",
    "    tio.ZNormalization(masking_method=tio.ZNormalization.mean),\n",
    "    \n",
    "    #Preparation\n",
    "    tio.CropOrPad((240, 240, 160)),        \n",
    "    tio.OneHot(num_classes=5)    \n",
    "    \n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#The run_eval method evaluates and prints your model's performance on a test dataset by averaging the Dice loss per batch\n",
    "run_eval(\n",
    "    #The path to the folder containing the data; refer to the training example for more info\n",
    "    input_data_path= '../../brats_new/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData',\n",
    "    \n",
    "    #The path to the saved model weights\n",
    "    model_path=\"../../randgamma.pt\",\n",
    "    \n",
    "    #The transforms you specified above\n",
    "    validation_transform=test_transform,\n",
    "    \n",
    "    #The names of the modalities every example in your input data has\n",
    "    input_channels_list = ['flair','t1','t2','t1ce'],\n",
    "    #Which of the labels in your segmentation you want to train your model to predict\n",
    "    seg_channels = [1,2,4],\n",
    "    #The name of the type of model you want to evaluate; currently UNet3D is the only available model\n",
    "    model_type = 'UNet3D',\n",
    "    \n",
    "    #If set to True, we only return the performance of the model on the examples which were not used for training, based on the training_split_ratio and seed\n",
    "    #If False, we evaluate on all data and ignore seed and training_split_ratio;\n",
    "    #set to False if input_data_path is set to a dataset you did not use during training\n",
    "    is_validation_data = True,\n",
    "    training_split_ratio=0.9,\n",
    "    seed=42,\n",
    "    \n",
    "    #The number of examples per batch; reduce/increase this based on memory availability\n",
    "    batch_size=1,\n",
    "    #The number of CPUs you want to be available for loading the input data into the model\n",
    "    num_loading_cpus = 1,   \n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Tools for generating gifs, slices, and nifti files from input data and model predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#First import the visualization tool and the torchio library\n",
    "import sys\n",
    "sys.path.append('../Radiology_and_AI')\n",
    "sys.path.append('../../MedicalZooPytorch')\n",
    "from visuals.run_visualization import gen_visuals\n",
    "import torchio as tio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Whatever normalization and data preparation steps you performed during training must also be applied here\n",
    "#Refer to the above for more info\n",
    "#These transforms are applied to the input data before it is passed to the model; they should be the same as the validation transforms used during training\n",
    "validation_transform = tio.Compose([\n",
    "    #Normalization\n",
    "    tio.ZNormalization(masking_method=tio.ZNormalization.mean),\n",
    "    \n",
    "    #Preparation\n",
    "    tio.CropOrPad((240, 240, 160)),        \n",
    "    tio.OneHot(num_classes=5)        \n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/cameron/storage/miniconda3/envs/cameronenv/lib/python3.8/site-packages/matplotlib/image.py:446: UserWarning: Warning: converting a masked element to nan.\n",
      "  dv = np.float64(self.norm.vmax) - np.float64(self.norm.vmin)\n",
      "/home/cameron/storage/miniconda3/envs/cameronenv/lib/python3.8/site-packages/matplotlib/image.py:453: UserWarning: Warning: converting a masked element to nan.\n",
      "  a_min = np.float64(newmin)\n",
      "/home/cameron/storage/miniconda3/envs/cameronenv/lib/python3.8/site-packages/matplotlib/image.py:458: UserWarning: Warning: converting a masked element to nan.\n",
      "  a_max = np.float64(newmax)\n"
     ]
    }
   ],
   "source": [
286
    "#The gen_visuals method can be used for generating gifs of the inpu\n",
287
    "gen_visuals(\n",
288
    "    #The path to the folder containing the nifti files for an example\n",
289
    "    image_path=\"../../brats_new/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_010\",\n",
290
    "    \n",
291
    "    #The transforms applied to the input should the same applied to the validation data during model training\n",
292
    "    transforms = validation_transform,\n",
293
    "    \n",
294
    "    #The path to the model to use for predictions \n",
295
    "    model_path =  \"../Models/test_train_many_1e-3.pt\",\n",
296
    "    \n",
297
    "    #Generate visuals using segmentations generated by the model\n",
298
    "    gen_pred = True,\n",
299
    "    #Generate visuals using annotated segmentations\n",
300
    "    gen_true = True,\n",
301
    "    \n",
302
    "    #The modalities your input example has\n",
303
    "    input_channels_list = ['flair','t1','t2','t1ce'],\n",
304
    "    #The labels your segmentation has\n",
305
    "    seg_channels = [1,2,4],\n",
306
    "\n",
307
    "    #Save a gif of the brain in 3D spinning on its vertical axis\n",
308
    "    gen_gif = False,\n",
309
    "    #Where to output the gif of the brain with segmentations either from the annotated labels or the predicted labels\n",
310
    "    true_gif_output_path = \"../../output/true\",\n",
311
    "    pred_gif_output_path = \"../../output/pred\",    \n",
312
    "    #Which segmentation labels to display in the gif\n",
313
    "    seg_channels_to_display_gif = [1,2,4],\n",
314
    "    #The angle from the horizontal axis you are looking down on the brain at as it is spinning\n",
315
    "    gif_view_angle = 30,\n",
316
    "    #How much the brain rotates between images of the gif\n",
317
    "    gif_angle_rotation = 20,\n",
318
    "    #fig size of the gif images\n",
319
    "    fig_size_gif = (50,25),\n",
320
    "\n",
321
    "    #Save an image of slices of the brain at different views and with segmentations\n",
322
    "    gen_slice = True,\n",
323
    "    #where to save the generated slice image\n",
324
    "    slice_output_path = \"../../output/slices\",\n",
325
    "    #Fig size of the slice images\n",
326
    "    fig_size_slice = (25,50),\n",
327
    "    #Which seg labels to display in the slice, they will be layered in this order on the image\n",
328
    "    seg_channels_to_display_slice = [2,4,1],\n",
329
    "    #Which slice to display for different views of the brain\n",
330
    "    sag_slice = None, #Sagittal\n",
331
    "    cor_slice = None, #Coronal\n",
332
    "    axi_slice = None, #Axial\n",
333
    "    disp_slice_base = True, #WHether or not to display the input image in the background\n",
334
    "    slice_title = None, #THe title of the slice images figure\n",
335
    "\n",
336
    "    gen_nifti = True, #Whether or not to generate nifti files for the input image and the segmentations\n",
337
    "    nifti_output_path = \"../../output/nifti\", #WHere to ssave the nifti files\n",
338
    ")"
339
   ]
340
  },
341
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cameronenvironment",
   "language": "python",
   "name": "cameronenvironment"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}