Diff of /config_classifier.py [000000] .. [a49583]

Switch to unified view

a b/config_classifier.py
1
# SEGMENTATION MODEL CONFIG
2
import argparse
3
4
import utils
5
6
7
def get_config_pars_classifier(stage):
8
    parser_ = argparse.ArgumentParser(
9
        description='arguments for training COVID-CT-Mask-Net on CNCB dataset')
10
11
    parser_.add_argument("--device", type=str, default='cpu')
12
    parser_.add_argument("--model_name", type=str, default=None)
13
    parser_.add_argument("--num_classes", type=int, default=None, help="This refers to the number of classes in the segmentation mode, so either 2 or 3")
14
    parser_.add_argument("--rpn_nms_th", type=float, default=0.75, help="Both at train and test stages")
15
    parser_.add_argument("--roi_nms_th", type=float, default=0.75, help="Both at train and test stages")
16
    parser_.add_argument("--backbone_name", type=str, default='resnet50', help="One of resnet50, resnet34, reesnet18")
17
    parser_.add_argument("--truncation", type=str, default="0", help="One of 0,1,2 for no truncation, last block, two last blocks")
18
    parser_.add_argument("--roi_batch_size", type=int, default=256, help="RoI batch size output, input in the S classifier module")
19
    parser_.add_argument("--s_features", type=int, default=None, help="Number of features in the S classification module")
20
21
    if stage == "trainval":
22
        parser_.add_argument("--start_epoch", type=str, default=0)
23
        parser_.add_argument("--num_epochs", type=int, default=50)
24
        parser_.add_argument("--pretrained_classification_model", type=str, default=None)
25
        parser_.add_argument("--pretrained_segmentation_model", type=str, default=None,
26
                             help="Either this or pretrained classifier must be defined!")
27
        parser_.add_argument("--save_dir", type=str, default="saved_models",
28
                             help="Directory to save checkpoints")
29
        parser_.add_argument("--train_data_dir", type=str, default='../covid_data/cncb/train_large',
30
                             help="Path to the training data. Must contain images and binary masks")
31
        parser_.add_argument("--val_data_dir", type=str, default='../covid_data/cncb/new_/val',
32
                             help="Path to the validation data. Must contain images and binary masks")
33
        parser_.add_argument("--batch_size", type=int, default=8, help="Implemented only for batch size = 1")
34
        parser_.add_argument("--save_every", type=int, default=10)
35
        parser_.add_argument("--lrate", type=float, default=1e-5, help="Learning rate")
36
37
    elif stage == "test":
38
        parser_.add_argument("--ckpt", type=str,
39
                             help="Checkpoint file in .pth format. "
40
                                  "Must contain the following keys: model_weights, optimizer_state, anchor_generator")
41
        parser_.add_argument("--test_data_dir", type=str, default='../covid_data/cncb/ncov-ai.big.ac.cn/download/test',
42
                             help="Path to the test data. Must contain images and may contain binary masks")
43
44
    model_args = parser_.parse_args()
45
    return model_args