# File: code/config.py
"""
DeepSlide
Contains all hyperparameters for the entire repository.

Authors: Jason Wei, Behnaz Abdollahi, Saeed Hassanpour
"""
7
8
import argparse
9
from pathlib import Path
10
11
import torch
12
13
from compute_stats import compute_stats
14
from utils import (get_classes, get_log_csv_name)
15
16
# Source: https://stackoverflow.com/questions/12151306/argparse-way-to-include-default-values-in-help
parser = argparse.ArgumentParser(
    description="DeepSlide",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Spellings accepted by _str2bool. The empty string is kept on the false side
# because the previous ``type=bool`` converter also mapped "" to False.
_TRUE_STRINGS = ("true", "t", "yes", "y", "1", "on")
_FALSE_STRINGS = ("false", "f", "no", "n", "0", "off", "")


def _str2bool(value):
    """
    Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy, so a flag could never be
    switched off from the command line.

    Args:
        value: The raw command-line value (or an actual bool).

    Returns:
        True or False according to the spelling of ``value``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (e.g. True or False), got {value!r}")


###########################################
#               USER INPUTS               #
###########################################
# Input folders for training images.
# Must contain subfolders of images labelled by class.
# If your two classes are 'a' and 'n', you must have a/*.jpg with the images in class a and
# n/*.jpg with the images in class n.
parser.add_argument(
    "--all_wsi",
    type=Path,
    default=Path("all_wsi"),
    help="Location of the WSI organized in subfolders by class")
# For splitting into validation set.
parser.add_argument("--val_wsi_per_class",
                    type=int,
                    default=20,
                    help="Number of WSI per class to use in validation set")
# For splitting into testing set, remaining images used in train.
parser.add_argument("--test_wsi_per_class",
                    type=int,
                    default=30,
                    help="Number of WSI per class to use in test set")
# When splitting, do you want to move WSI or copy them?
parser.add_argument(
    "--keep_orig_copy",
    type=_str2bool,
    default=True,
    help=
    "Whether to move or copy the WSI when splitting into training, validation, and test sets"
)

#######################################
#               GENERAL               #
#######################################
# Number of processes to use.
parser.add_argument("--num_workers",
                    type=int,
                    default=8,
                    help="Number of workers to use for IO")
# Default shape for ResNet in PyTorch.
parser.add_argument("--patch_size",
                    type=int,
                    default=224,
                    help="Size of the patches extracted from the WSI")

##########################################
#               DATA SPLIT               #
##########################################
# The names of your to-be folders.
parser.add_argument("--wsi_train",
                    type=Path,
                    default=Path("wsi_train"),
                    help="Location to be created to store WSI for training")
parser.add_argument("--wsi_val",
                    type=Path,
                    default=Path("wsi_val"),
                    help="Location to be created to store WSI for validation")
parser.add_argument("--wsi_test",
                    type=Path,
                    default=Path("wsi_test"),
                    help="Location to be created to store WSI for testing")

# Where the CSV file labels will go.
parser.add_argument("--labels_train",
                    type=Path,
                    default=Path("labels_train.csv"),
                    help="Location to store the CSV file labels for training")
parser.add_argument(
    "--labels_val",
    type=Path,
    default=Path("labels_val.csv"),
    help="Location to store the CSV file labels for validation")
parser.add_argument("--labels_test",
                    type=Path,
                    default=Path("labels_test.csv"),
                    help="Location to store the CSV file labels for testing")

###############################################################
#               PROCESSING AND PATCH GENERATION               #
###############################################################
# This is the input for model training, automatically built.
parser.add_argument(
    "--train_folder",
    type=Path,
    default=Path("train_folder"),
    help="Location of the automatically built training input folder")

# Folders of patches by WSI in training set, used for finding training accuracy at WSI level.
parser.add_argument(
    "--patches_eval_train",
    type=Path,
    default=Path("patches_eval_train"),
    help=
    "Folders of patches by WSI in training set, used for finding training accuracy at WSI level"
)
# Folders of patches by WSI in validation set, used for finding validation accuracy at WSI level.
parser.add_argument(
    "--patches_eval_val",
    type=Path,
    default=Path("patches_eval_val"),
    help=
    "Folders of patches by WSI in validation set, used for finding validation accuracy at WSI level"
)
# Folders of patches by WSI in test set, used for finding test accuracy at WSI level.
parser.add_argument(
    "--patches_eval_test",
    type=Path,
    default=Path("patches_eval_test"),
    help=
    "Folders of patches by WSI in testing set, used for finding test accuracy at WSI level"
)

# Target number of training patches per class.
parser.add_argument("--num_train_per_class",
                    type=int,
                    default=80000,
                    help="Target number of training samples per class")

# Only looks for purple images and filters whitespace.
parser.add_argument(
    "--type_histopath",
    type=_str2bool,
    default=True,
    help="Only look for purple histopathology images and filter whitespace")

# Number of purple points for region to be considered purple.
parser.add_argument(
    "--purple_threshold",
    type=int,
    default=100,
    help="Number of purple points for region to be considered purple.")

# Scalar to use for reducing image to check for purple.
parser.add_argument(
    "--purple_scale_size",
    type=int,
    default=15,
    help="Scalar to use for reducing image to check for purple.")

# Sliding window overlap factor (for testing).
# For generating patches during the training phase, we slide a window to overlap by some factor.
# Must be an integer. 1 means no overlap, 2 means overlap by 1/2, 3 means overlap by 1/3.
# Recommend 2 for very high resolution, 3 for medium, and 5 for not extremely high resolution images.
# NOTE(review): the comment above mentions the training phase while the help
# text says testing phase - confirm which one is intended.
parser.add_argument("--slide_overlap",
                    type=int,
                    default=3,
                    help="Sliding window overlap factor for the testing phase")

# Overlap factor to use when generating validation patches.
parser.add_argument(
    "--gen_val_patches_overlap_factor",
    type=float,
    default=1.5,
    help="Overlap factor to use when generating validation patches.")

parser.add_argument("--image_ext",
                    type=str,
                    default="jpg",
                    help="Image extension for saving patches")

# Produce patches for testing and validation by folder.  The code only works
# for now when testing and validation are split by folder.
parser.add_argument(
    "--by_folder",
    type=_str2bool,
    default=True,
    help="Produce patches for testing and validation by folder.")

#########################################
#               TRANSFORM               #
#########################################
parser.add_argument(
    "--color_jitter_brightness",
    type=float,
    default=0.5,
    help=
    "Random brightness jitter to use in data augmentation for ColorJitter() transform"
)
parser.add_argument(
    "--color_jitter_contrast",
    type=float,
    default=0.5,
    help=
    "Random contrast jitter to use in data augmentation for ColorJitter() transform"
)
parser.add_argument(
    "--color_jitter_saturation",
    type=float,
    default=0.5,
    help=
    "Random saturation jitter to use in data augmentation for ColorJitter() transform"
)
parser.add_argument(
    "--color_jitter_hue",
    type=float,
    default=0.2,
    help=
    "Random hue jitter to use in data augmentation for ColorJitter() transform"
)

########################################
#               TRAINING               #
########################################
# Model hyperparameters.
parser.add_argument("--num_epochs",
                    type=int,
                    default=20,
                    help="Number of epochs for training")
# Choose from [18, 34, 50, 101, 152].
parser.add_argument(
    "--num_layers",
    type=int,
    default=18,
    help=
    "Number of layers to use in the ResNet model from [18, 34, 50, 101, 152]")
parser.add_argument("--learning_rate",
                    type=float,
                    default=0.001,
                    help="Learning rate to use for gradient descent")
parser.add_argument("--batch_size",
                    type=int,
                    default=16,
                    help="Mini-batch size to use for training")
parser.add_argument("--weight_decay",
                    type=float,
                    default=1e-4,
                    help="Weight decay (L2 penalty) to use in optimizer")
parser.add_argument("--learning_rate_decay",
                    type=float,
                    default=0.85,
                    help="Learning rate decay amount per epoch")
parser.add_argument("--resume_checkpoint",
                    type=_str2bool,
                    default=False,
                    help="Resume model from checkpoint file")
parser.add_argument("--save_interval",
                    type=int,
                    default=1,
                    help="Number of epochs between saving checkpoints")
# Where models are saved.
parser.add_argument("--checkpoints_folder",
                    type=Path,
                    default=Path("checkpoints"),
                    help="Directory to save model checkpoints to")

# Name of checkpoint file to load from.
parser.add_argument(
    "--checkpoint_file",
    type=Path,
    default=Path("xyz.pt"),
    help="Checkpoint file to load if resume_checkpoint is True")
# ImageNet pretrain?
parser.add_argument("--pretrain",
                    type=_str2bool,
                    default=False,
                    help="Use pretrained ResNet weights")
parser.add_argument("--log_folder",
                    type=Path,
                    default=Path("logs"),
                    help="Directory to save logs to")

##########################################
#               PREDICTION               #
##########################################
# Selecting the best model.
# Automatically select the model with the highest validation accuracy.
parser.add_argument(
    "--auto_select",
    type=_str2bool,
    default=True,
    help="Automatically select the model with the highest validation accuracy")
# Where to put the training prediction CSV files.
parser.add_argument(
    "--preds_train",
    type=Path,
    default=Path("preds_train"),
    help="Directory for outputting training prediction CSV files")
# Where to put the validation prediction CSV files.
parser.add_argument(
    "--preds_val",
    type=Path,
    default=Path("preds_val"),
    help="Directory for outputting validation prediction CSV files")
# Where to put the testing prediction CSV files.
parser.add_argument(
    "--preds_test",
    type=Path,
    default=Path("preds_test"),
    help="Directory for outputting testing prediction CSV files")

##########################################
#               EVALUATION               #
##########################################
# Folder for outputting WSI predictions based on each threshold.
parser.add_argument(
    "--inference_train",
    type=Path,
    default=Path("inference_train"),
    help=
    "Folder for outputting WSI training predictions based on each threshold")
parser.add_argument(
    "--inference_val",
    type=Path,
    default=Path("inference_val"),
    help=
    "Folder for outputting WSI validation predictions based on each threshold")
parser.add_argument(
    "--inference_test",
    type=Path,
    default=Path("inference_test"),
    help="Folder for outputting WSI testing predictions based on each threshold"
)

# For visualization.
parser.add_argument(
    "--vis_train",
    type=Path,
    default=Path("vis_train"),
    help="Folder for outputting the WSI training prediction visualizations")
parser.add_argument(
    "--vis_val",
    type=Path,
    default=Path("vis_val"),
    help="Folder for outputting the WSI validation prediction visualizations")
parser.add_argument(
    "--vis_test",
    type=Path,
    default=Path("vis_test"),
    help="Folder for outputting the WSI testing prediction visualizations")
350
351
#######################################################
#               ARGUMENTS FROM ARGPARSE               #
#######################################################
# Parse the command line once at import time; everything below is derived
# from the parsed values.
args = parser.parse_args()

# Device to use for PyTorch code.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Automatically read in the classes.
classes = get_classes(folder=args.all_wsi)
num_classes = len(classes)

# This is the input for model training, automatically built.
train_patches = args.train_folder.joinpath("train")
val_patches = args.train_folder.joinpath("val")

# Compute the mean and standard deviation of the image patches from the specified folder.
path_mean, path_std = compute_stats(folderpath=train_patches,
                                    image_ext=args.image_ext)

# Only used if resume_checkpoint is True.
resume_checkpoint_path = args.checkpoints_folder.joinpath(args.checkpoint_file)

# Named with date and time.
log_csv = get_log_csv_name(log_folder=args.log_folder)

# Does nothing if auto_select is True.
eval_model = args.checkpoints_folder.joinpath(args.checkpoint_file)
379
380
# Find the best threshold for filtering noise (discard patches with a confidence less than this threshold).
threshold_search = (0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)

# For visualization.
# This order is the same order as your sorted classes.
colors = ("red", "white", "blue", "green", "purple", "orange", "black", "pink",
          "yellow")
387
388
# Print the configuration.
# Source: https://stackoverflow.com/questions/44689546/how-to-print-out-a-dictionary-nicely-in-python/44689627
# chr(10) and chr(9) are ways of going around the f-string limitation of
# not allowing the '\' character inside.
# The first interpolated line dumps every parsed argument as "name:<TAB>value",
# one per line; the remaining lines report the values derived from them.
print(f"###############     CONFIGURATION     ###############\n"
      f"{chr(10).join(f'{k}:{chr(9)}{v}' for k, v in vars(args).items())}\n"
      f"device:\t{device}\n"
      f"classes:\t{classes}\n"
      f"num_classes:\t{num_classes}\n"
      f"train_patches:\t{train_patches}\n"
      f"val_patches:\t{val_patches}\n"
      f"path_mean:\t{path_mean}\n"
      f"path_std:\t{path_std}\n"
      f"resume_checkpoint_path:\t{resume_checkpoint_path}\n"
      f"log_csv:\t{log_csv}\n"
      f"eval_model:\t{eval_model}\n"
      f"threshold_search:\t{threshold_search}\n"
      f"colors:\t{colors}\n"
      f"\n#####################################################\n\n\n")