sybil/parsing.py
import argparse
import torch
import os
import pwd
from pytorch_lightning import Trainer

EMPTY_NAME_ERR = (
    "Name of augmentation or one of its arguments can't be empty\n"
    'Use "name/arg1=value/arg2=value" format'
)
POSS_VAL_NOT_LIST = (
    "Flag {} has an invalid list of values: {}. Length of list must be >=1"
)


def parse_augmentations(raw_augmentations):
    """
    Parse the list of augmentations, given by the configuration, into a list of
    tuples of the augmentation name and a dictionary of additional arguments.

    Each augmentation is assumed to be of the form 'name/arg1=value/arg2=value'.

    :param raw_augmentations: list of strings (unparsed augmentations)
    :returns: list of parsed augmentations [list of (name, additional_args)]
    """
    augmentations = []
    for t in raw_augmentations:
        arguments = t.split("/")
        name = arguments[0]
        if name == "":
            raise Exception(EMPTY_NAME_ERR)

        kwargs = {}
        if len(arguments) > 1:
            for a in arguments[1:]:
                parts = a.split("=")
                var = parts[0]
                val = parts[1] if len(parts) > 1 else None
                if var == "":
                    raise Exception(EMPTY_NAME_ERR)

                kwargs[var] = val

        augmentations.append((name, kwargs))

    return augmentations
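
# Illustrative sketch of the expected behavior (augmentation names below are
# hypothetical examples, not from this repo):
#   parse_augmentations(["rotate/degrees=20", "flip"])
#   -> [("rotate", {"degrees": "20"}), ("flip", {})]
# Argument values stay raw strings; an argument without "=value" maps to None.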


def parse_dispatcher_config(config):
    """
    Parses an experiment config and creates jobs. For flags that are expected to
    be a single item but for which the config contains a list, this will return
    one job for each item in the list.

    :param config: experiment config

    :returns: jobs - a list of flag strings, each of which encapsulates one job.
        Example: --train --cuda --dropout=0.1 ...
    :returns: experiment_axes - axes that the grid search is searching over
    """
    jobs = [""]
    experiment_axes = []
    search_spaces = config["search_space"]

    # Support a list of search spaces; convert to a length-one list for backward compatibility
    if not isinstance(search_spaces, list):
        search_spaces = [search_spaces]

    for search_space in search_spaces:
        # Go through the tree of possible jobs and enumerate into a list of jobs
        for flag in search_space:
            possible_values = search_space[flag]
            if not isinstance(possible_values, list) or len(possible_values) == 0:
                raise Exception(POSS_VAL_NOT_LIST.format(flag, possible_values))
            if len(possible_values) > 1:
                experiment_axes.append(flag)

            children = []
            for value in possible_values:
                for parent_job in jobs:
                    if isinstance(value, bool):
                        if value:
                            new_job_str = "{} --{}".format(parent_job, flag)
                        else:
                            new_job_str = parent_job
                    elif isinstance(value, list):
                        val_list_str = " ".join([str(v) for v in value])
                        new_job_str = "{} --{} {}".format(
                            parent_job, flag, val_list_str
                        )
                    else:
                        new_job_str = "{} --{} {}".format(parent_job, flag, value)
                    children.append(new_job_str)
            jobs = children

    return jobs, experiment_axes
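
# Illustrative sketch (hypothetical config; flag names are examples only):
#   config = {"search_space": {"train": [True], "dropout": [0.1, 0.25]}}
#   parse_dispatcher_config(config)
#   -> ([" --train --dropout 0.1", " --train --dropout 0.25"], ["dropout"])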


def parse_args(args_strings=None):
    parser = argparse.ArgumentParser(
        description="Sandstone research repo. Supports Mammograms, CT Scans, Thermal Imaging, Cell Imaging and Chemistry."
    )
    # setup
    parser.add_argument(
        "--train",
        action="store_true",
        default=False,
        help="Whether or not to train model",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        default=False,
        help="Whether or not to run model on test set",
    )
    parser.add_argument(
        "--dev",
        action="store_true",
        default=False,
        help="Whether or not to run model on dev set",
    )
    parser.add_argument(
        "--fine_tune",
        action="store_true",
        default=False,
        help="Whether or not to fine-tune the model",
    )
    parser.add_argument(
        "--num_epochs_fine_tune",
        type=int,
        default=1,
        help="Number of epochs to fine-tune the model",
    )

    # data
    parser.add_argument(
        "--dataset",
        default="nlst",
        choices=[
            "sybil",
            "nlst",
            "nlst_risk_factors",
            "nlst_for_plco2012",
            "nlst_for_plco2019",
            "mgh",
        ],
        help="Name of dataset from dataset factory to use [default: nlst]",
    )
    parser.add_argument(
        "--img_size",
        type=int,
        nargs="+",
        default=[256, 256],
        help="Width and height of image in pixels. [default: [256,256]]",
    )
    parser.add_argument(
        "--num_chan", type=int, default=3, help="Number of channels for input image"
    )
    parser.add_argument(
        "--img_mean",
        type=float,
        nargs="+",
        default=[128.1722],
        help="Mean of image per channel",
    )
    parser.add_argument(
        "--img_std",
        type=float,
        nargs="+",
        default=[87.1849],
        help="Standard deviation of image per channel",
    )
    parser.add_argument(
        "--img_dir",
        type=str,
        default="/data/rsg/mammogram/NLST/nlst-ct-png",
        help="Dir of images. Note: image paths in dataset jsons should stem from here",
    )
    parser.add_argument(
        "--img_file_type",
        type=str,
        default="png",
        choices=["png", "dicom"],
        help="Type of image; one of [png, dicom]",
    )
    parser.add_argument(
        "--fix_seed_for_multi_image_augmentations",
        action="store_true",
        default=False,
        help="Use same seed for each slice of volume augmentations",
    )
    parser.add_argument(
        "--dataset_file_path",
        type=str,
        default="/Mounts/rbg-storage1/datasets/NLST/full_nlst_google.json",
        help="Path to dataset file either as json or csv",
    )
    parser.add_argument(
        "--num_classes", type=int, default=6, help="Number of classes to predict"
    )

    # Alternative training/testing schemes
    parser.add_argument(
        "--cross_val_seed",
        type=int,
        default=0,
        help="Seed used to generate the partition.",
    )
    parser.add_argument(
        "--assign_splits",
        action="store_true",
        default=False,
        help="Whether to assign different splits than those predetermined in dataset",
    )
    parser.add_argument(
        "--split_type",
        type=str,
        default="random",
        choices=["random", "institution_split"],
        help="How to split dataset if assign_splits = True. One of ['random', 'institution_split'].",
    )
    parser.add_argument(
        "--split_probs",
        type=float,
        nargs="+",
        default=[0.6, 0.2, 0.2],
        help="Split probs for datasets without fixed train dev test.",
    )

    # survival analysis setup
    parser.add_argument(
        "--max_followup", type=int, default=6, help="Max followup to predict over"
    )

    # risk factors
    parser.add_argument(
        "--use_risk_factors",
        action="store_true",
        default=False,
        help="Whether to feed risk factors into last FC of model.",
    )
    parser.add_argument(
        "--risk_factor_keys",
        nargs="*",
        default=[],
        help="List of risk factors to include in risk factor vector.",
    )

    # handling CT slices
    parser.add_argument(
        "--resample_pixel_spacing_prob",
        type=float,
        default=1,
        help="Probability of resampling pixel spacing into fixed dimensions. 1 when eval and using resampling",
    )
    parser.add_argument(
        "--num_images",
        type=int,
        default=200,
        help="In multi image setting, the number of images per single sample.",
    )
    parser.add_argument(
        "--min_num_images",
        type=int,
        default=0,
        help="In multi image setting, the min number of images per single sample.",
    )
    parser.add_argument(
        "--slice_thickness_filter",
        type=float,
        help="Slice thickness to use, if restricting to a specific thickness value.",
    )
    parser.add_argument(
        "--use_only_thin_cuts_for_ct",
        action="store_true",
        default=False,
        help="Whether to use image series with thinnest cuts only.",
    )

    # region annotations
    parser.add_argument(
        "--use_annotations",
        action="store_true",
        default=False,
        help="Whether to use image annotations (pixel labels) in modeling",
    )

    parser.add_argument(
        "--region_annotations_filepath", type=str, help="Path to annotations file"
    )
    parser.add_argument(
        "--annotation_loss_lambda",
        type=float,
        default=1,
        help="Weight of annotation losses",
    )
    parser.add_argument(
        "--image_attention_loss_lambda",
        type=float,
        default=1,
        help="Weight of loss for predicting image attention scores",
    )
    parser.add_argument(
        "--volume_attention_loss_lambda",
        type=float,
        default=1,
        help="Weight of loss for predicting volume attention scores",
    )

    # regularization
    parser.add_argument(
        "--primary_loss_lambda",
        type=float,
        default=1.0,
        help="Lambda to weigh the primary loss.",
    )
    parser.add_argument(
        "--adv_loss_lambda",
        type=float,
        default=1.0,
        help="Lambda to weigh the adversary loss.",
    )

    # learning
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for training [default: 32]",
    )
    parser.add_argument(
        "--init_lr",
        type=float,
        default=0.001,
        help="Initial learning rate [default: 0.001]",
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.25,
        help="Amount of dropout to apply on last hidden layer [default: 0.25]",
    )
    parser.add_argument(
        "--optimizer", type=str, default="adam", help="Optimizer to use [default: adam]"
    )
    parser.add_argument(
        "--momentum", type=float, default=0, help="Momentum to use with SGD"
    )
    parser.add_argument(
        "--lr_decay",
        type=float,
        default=0.1,
        help="Learning rate decay factor [default: 0.1]",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0,
        help="L2 regularization penalty [default: 0]",
    )
    parser.add_argument(
        "--adv_lr",
        type=float,
        default=0.001,
        help="Initial learning rate for adversary model [default: 0.001]",
    )

    # schedule
    parser.add_argument(
        "--patience",
        type=int,
        default=5,
        help="Number of epochs without improvement on dev before decaying learning rate and reloading best model [default: 5]",
    )
    parser.add_argument(
        "--num_adv_steps",
        type=int,
        default=1,
        help="Number of steps for domain adaptation discriminator per one step of encoding model [default: 1]",
    )
    parser.add_argument(
        "--tuning_metric",
        type=str,
        default="c_index",
        help="Criterion based on which model is saved [default: c_index]",
    )

    # model checkpointing
    parser.add_argument(
        "--turn_off_checkpointing",
        action="store_true",
        default=False,
        help="Do not save best model",
    )

    parser.add_argument(
        "--save_dir", type=str, default="snapshot", help="Where to dump the model"
    )

    parser.add_argument(
        "--snapshot",
        type=str,
        default=None,
        help="Filename of model snapshot to load [default: None]",
    )

    # system
    parser.add_argument(
        "--num_workers",
        type=int,
        default=8,
        help="Num workers for each data loader [default: 8]",
    )

    # storing results
    parser.add_argument(
        "--store_hiddens",
        action="store_true",
        default=False,
        help="Save hidden repr from each image to an npz based off results path, git hash and exam name",
    )
    parser.add_argument(
        "--save_predictions",
        action="store_true",
        default=False,
        help="Save model predictions based off results path, git hash and exam name",
    )
    parser.add_argument(
        "--hiddens_dir",
        type=str,
        default="hiddens/test_run",
        help="Dir to store hidden npy files when store_hiddens is true",
    )
    parser.add_argument(
        "--save_attention_scores",
        action="store_true",
        default=False,
        help="Whether to save attention scores when using attention mechanism",
    )
    parser.add_argument(
        "--results_path",
        type=str,
        default="logs/test.args",
        help="Where to save the result logs",
    )

    # cache
    parser.add_argument(
        "--cache_path", type=str, default=None, help="Dir to cache images."
    )
    parser.add_argument(
        "--cache_full_img",
        action="store_true",
        default=False,
        help="Cache full image locally as well as cacheable transforms",
    )

    # run
    parser = Trainer.add_argparse_args(parser)
    if args_strings is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(args_strings)
    args.lr = args.init_lr
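
    # Multi-GPU runs (e.g. --gpus "0,1" or --gpus 2) are dispatched with DDP;
    # in both branches PyTorch Lightning's automatic DDP sampler replacement is disabled.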
    if (isinstance(args.gpus, str) and len(args.gpus.split(",")) > 1) or (
        isinstance(args.gpus, int) and args.gpus > 1
    ):
        args.accelerator = "ddp"
        args.replace_sampler_ddp = False
    else:
        args.accelerator = None
        args.replace_sampler_ddp = False

    args.unix_username = pwd.getpwuid(os.getuid())[0]

    # learning initial state
    args.step_indx = 1

    return args
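

# Illustrative usage sketch (assumes a PyTorch Lightning version where
# Trainer.add_argparse_args injects flags such as --gpus; values below are examples):
#   args = parse_args(["--train", "--dataset", "nlst", "--batch_size", "16"])
#   args.train -> True, args.dataset -> "nlst", args.batch_size -> 16, args.lr -> 0.001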