|
a |
|
b/config_segmentation.py |
|
|
1 |
# SEGMENTATION MODEL CONFIG |
|
|
2 |
import argparse |
|
|
3 |
|
|
|
4 |
import utils |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
# test: get inference on new data |
|
|
8 |
# precision: compute MS COCO mean average precision |
|
|
9 |
def get_config_pars(stage): |
|
|
10 |
parser_ = argparse.ArgumentParser( |
|
|
11 |
description='arguments for training Mask R-CNN segmentation for 2 classes ' |
|
|
12 |
'(Ground Glass Opacity and Consolidation) on CNCB dataset') |
|
|
13 |
|
|
|
14 |
parser_.add_argument("--backbone_name", type=str, default='resnet50', help="One of resnet50, resnet34, reesnet18") |
|
|
15 |
parser_.add_argument("--truncation", type=str, default="0", help="One of 0,1,2 for no truncation, last block, two last blocks") |
|
|
16 |
parser_.add_argument("--device", type=str, default='cpu') |
|
|
17 |
parser_.add_argument("--model_name", type=str, default=None) |
|
|
18 |
parser_.add_argument("--rpn_nms_th", type=float, default=0.75, help="Both at train and test stages.") |
|
|
19 |
parser_.add_argument("--mask_type", type=str, default=None, help="currently accepts three types: both" |
|
|
20 |
"for 2 classes, GGO and C, ggo for only " |
|
|
21 |
"GGO mask, and merge for merged GGO and C " |
|
|
22 |
"classes. These inputs are used both in the " |
|
|
23 |
"dataset and model interfaces.") |
|
|
24 |
|
|
|
25 |
if stage == "trainval": |
|
|
26 |
parser_.add_argument("--start_epoch", type=str, default=0) |
|
|
27 |
parser_.add_argument("--pretrained_model", type=str, help="Pretrained model, must be a checkpoint with keys:" |
|
|
28 |
"model_weights, anchor_generator, optimizer_state, model_name", |
|
|
29 |
default=None) |
|
|
30 |
parser_.add_argument("--num_epochs", type=int, default=50) |
|
|
31 |
parser_.add_argument("--use_pretrained_resnet_backbone", type=utils.str_to_bool, default=False, |
|
|
32 |
help="Use the ResNetw/FPN weights from Torchvision repository") |
|
|
33 |
parser_.add_argument("--save_dir", type=str, default="saved_models", |
|
|
34 |
help="Directory to save checkpoints") |
|
|
35 |
parser_.add_argument("--train_data_dir", type=str, default='../covid_data/train', |
|
|
36 |
help="Path to the training data. Must contain images and binary masks") |
|
|
37 |
parser_.add_argument("--val_data_dir", type=str, default='../covid_data/test', |
|
|
38 |
help="Path to the validation data. Must contain images and binary masks") |
|
|
39 |
parser_.add_argument("--gt_dir", type=str, default='masks', |
|
|
40 |
help="Path to directory with binary masks. Must be in the data directory.") |
|
|
41 |
parser_.add_argument("--imgs_dir", type=str, default='imgs') |
|
|
42 |
parser_.add_argument("--batch_size", type=int, default=1, help="Implemented only for batch size = 1") |
|
|
43 |
parser_.add_argument("--save_every", type=int, default=10) |
|
|
44 |
parser_.add_argument("--lrate", type=float, default=1e-5, help="Learning rate") |
|
|
45 |
|
|
|
46 |
elif stage == "test" or stage == "precision": |
|
|
47 |
parser_.add_argument("--ckpt", type=str, |
|
|
48 |
help="Checkpoint file in .pth format. " |
|
|
49 |
"Must contain the following keys: model_weights, optimizer_state, anchor_generator") |
|
|
50 |
parser_.add_argument("--test_data_dir", type=str, default='../covid_data/test', |
|
|
51 |
help="Path to the test data. Must contain images and may contain binary masks") |
|
|
52 |
parser_.add_argument("--test_imgs_dir", type=str, default='imgs', help="Directory with images. " |
|
|
53 |
"Must be in the test data directory.") |
|
|
54 |
parser_.add_argument("--save_dir", type=str, default='eval_imgs', |
|
|
55 |
help="Directory to save segmentation results.") |
|
|
56 |
parser_.add_argument("--confidence_th", type=float, default=0.05, help="Lower confidence score threshold on " |
|
|
57 |
"positive prediction from RoI.") |
|
|
58 |
parser_.add_argument("--mask_logits_th", type=float, default=0.5, help="Lower threshold for positve mask " |
|
|
59 |
"prediction at " |
|
|
60 |
"pixel level.") |
|
|
61 |
parser_.add_argument("--gt_dir", type=str, default='masks', |
|
|
62 |
help="Path to directory with binary masks. Must be in the data directory. If the stage == " |
|
|
63 |
"'precision', this path must be provided.") |
|
|
64 |
parser_.add_argument("--roi_nms_th", type=float, default=0.5, help="Only at test stage.") |
|
|
65 |
model_args = parser_.parse_args() |
|
|
66 |
return model_args |