|
a |
|
b/config_classifier.py |
|
|
1 |
# SEGMENTATION MODEL CONFIG |
|
|
2 |
import argparse |
|
|
3 |
|
|
|
4 |
import utils |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def get_config_pars_classifier(stage): |
|
|
8 |
parser_ = argparse.ArgumentParser( |
|
|
9 |
description='arguments for training COVID-CT-Mask-Net on CNCB dataset') |
|
|
10 |
|
|
|
11 |
parser_.add_argument("--device", type=str, default='cpu') |
|
|
12 |
parser_.add_argument("--model_name", type=str, default=None) |
|
|
13 |
parser_.add_argument("--num_classes", type=int, default=None, help="This refers to the number of classes in the segmentation mode, so either 2 or 3") |
|
|
14 |
parser_.add_argument("--rpn_nms_th", type=float, default=0.75, help="Both at train and test stages") |
|
|
15 |
parser_.add_argument("--roi_nms_th", type=float, default=0.75, help="Both at train and test stages") |
|
|
16 |
parser_.add_argument("--backbone_name", type=str, default='resnet50', help="One of resnet50, resnet34, reesnet18") |
|
|
17 |
parser_.add_argument("--truncation", type=str, default="0", help="One of 0,1,2 for no truncation, last block, two last blocks") |
|
|
18 |
parser_.add_argument("--roi_batch_size", type=int, default=256, help="RoI batch size output, input in the S classifier module") |
|
|
19 |
parser_.add_argument("--s_features", type=int, default=None, help="Number of features in the S classification module") |
|
|
20 |
|
|
|
21 |
if stage == "trainval": |
|
|
22 |
parser_.add_argument("--start_epoch", type=str, default=0) |
|
|
23 |
parser_.add_argument("--num_epochs", type=int, default=50) |
|
|
24 |
parser_.add_argument("--pretrained_classification_model", type=str, default=None) |
|
|
25 |
parser_.add_argument("--pretrained_segmentation_model", type=str, default=None, |
|
|
26 |
help="Either this or pretrained classifier must be defined!") |
|
|
27 |
parser_.add_argument("--save_dir", type=str, default="saved_models", |
|
|
28 |
help="Directory to save checkpoints") |
|
|
29 |
parser_.add_argument("--train_data_dir", type=str, default='../covid_data/cncb/train_large', |
|
|
30 |
help="Path to the training data. Must contain images and binary masks") |
|
|
31 |
parser_.add_argument("--val_data_dir", type=str, default='../covid_data/cncb/new_/val', |
|
|
32 |
help="Path to the validation data. Must contain images and binary masks") |
|
|
33 |
parser_.add_argument("--batch_size", type=int, default=8, help="Implemented only for batch size = 1") |
|
|
34 |
parser_.add_argument("--save_every", type=int, default=10) |
|
|
35 |
parser_.add_argument("--lrate", type=float, default=1e-5, help="Learning rate") |
|
|
36 |
|
|
|
37 |
elif stage == "test": |
|
|
38 |
parser_.add_argument("--ckpt", type=str, |
|
|
39 |
help="Checkpoint file in .pth format. " |
|
|
40 |
"Must contain the following keys: model_weights, optimizer_state, anchor_generator") |
|
|
41 |
parser_.add_argument("--test_data_dir", type=str, default='../covid_data/cncb/ncov-ai.big.ac.cn/download/test', |
|
|
42 |
help="Path to the test data. Must contain images and may contain binary masks") |
|
|
43 |
|
|
|
44 |
model_args = parser_.parse_args() |
|
|
45 |
return model_args |