Diff of /config_segmentation.py [000000] .. [a49583]

Switch to unified view

a b/config_segmentation.py
1
# SEGMENTATION MODEL CONFIG
2
import argparse
3
4
import utils
5
6
7
# test: get inference on new data
8
# precision: compute MS COCO mean average precision
9
def get_config_pars(stage):
10
    parser_ = argparse.ArgumentParser(
11
        description='arguments for training Mask R-CNN segmentation for 2 classes '
12
                    '(Ground Glass Opacity and Consolidation) on CNCB dataset')
13
14
    parser_.add_argument("--backbone_name", type=str, default='resnet50', help="One of resnet50, resnet34, reesnet18")
15
    parser_.add_argument("--truncation", type=str, default="0", help="One of 0,1,2 for no truncation, last block, two last blocks")
16
    parser_.add_argument("--device", type=str, default='cpu')
17
    parser_.add_argument("--model_name", type=str, default=None)
18
    parser_.add_argument("--rpn_nms_th", type=float, default=0.75, help="Both at train and test stages.")
19
    parser_.add_argument("--mask_type", type=str, default=None, help="currently accepts three types: both"
20
                                                                     "for 2 classes, GGO and C, ggo for only "
21
                                                                     "GGO mask, and merge for merged GGO and C "
22
                                                                     "classes. These inputs are used both in the "
23
                                                                     "dataset and model interfaces.")
24
25
    if stage == "trainval":
26
        parser_.add_argument("--start_epoch", type=str, default=0)
27
        parser_.add_argument("--pretrained_model", type=str, help="Pretrained model, must be a checkpoint with keys:"
28
                                                       "model_weights, anchor_generator, optimizer_state, model_name",
29
                             default=None)
30
        parser_.add_argument("--num_epochs", type=int, default=50)
31
        parser_.add_argument("--use_pretrained_resnet_backbone", type=utils.str_to_bool, default=False,
32
                             help="Use the ResNetw/FPN weights from Torchvision repository")
33
        parser_.add_argument("--save_dir", type=str, default="saved_models",
34
                             help="Directory to save checkpoints")
35
        parser_.add_argument("--train_data_dir", type=str, default='../covid_data/train',
36
                             help="Path to the training data. Must contain images and binary masks")
37
        parser_.add_argument("--val_data_dir", type=str, default='../covid_data/test',
38
                             help="Path to the validation data. Must contain images and binary masks")
39
        parser_.add_argument("--gt_dir", type=str, default='masks',
40
                             help="Path to directory with binary masks. Must be in the data directory.")
41
        parser_.add_argument("--imgs_dir", type=str, default='imgs')
42
        parser_.add_argument("--batch_size", type=int, default=1, help="Implemented only for batch size = 1")
43
        parser_.add_argument("--save_every", type=int, default=10)
44
        parser_.add_argument("--lrate", type=float, default=1e-5, help="Learning rate")
45
46
    elif stage == "test" or stage == "precision":
47
        parser_.add_argument("--ckpt", type=str,
48
                             help="Checkpoint file in .pth format. "
49
                                  "Must contain the following keys: model_weights, optimizer_state, anchor_generator")
50
        parser_.add_argument("--test_data_dir", type=str, default='../covid_data/test',
51
                             help="Path to the test data. Must contain images and may contain binary masks")
52
        parser_.add_argument("--test_imgs_dir", type=str, default='imgs', help="Directory with images. "
53
                                                                               "Must be in the test data directory.")
54
        parser_.add_argument("--save_dir", type=str, default='eval_imgs',
55
                             help="Directory to save segmentation results.")
56
        parser_.add_argument("--confidence_th", type=float, default=0.05, help="Lower confidence score threshold on "
57
                                                                               "positive prediction from RoI.")
58
        parser_.add_argument("--mask_logits_th", type=float, default=0.5, help="Lower threshold for positve mask "
59
                                                                               "prediction at "
60
                                                                               "pixel level.")
61
        parser_.add_argument("--gt_dir", type=str, default='masks',
62
                             help="Path to directory with binary masks. Must be in the data directory. If the stage == "
63
                                  "'precision', this path must be provided.")
64
        parser_.add_argument("--roi_nms_th", type=float, default=0.5, help="Only at test stage.")
65
    model_args = parser_.parse_args()
66
    return model_args