|
a |
|
b/config/config.py |
|
|
1 |
import os |
|
|
2 |
from config.augm import train_transform, val_transform |
|
|
3 |
from config.paths import train_images_folder, train_labels_folder, train_images, train_labels |
|
|
4 |
from semseg.data_loader import SemSegConfig |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
class SemSegMRIConfig(SemSegConfig): |
|
|
8 |
train_images = [os.path.join(train_images_folder, train_image) |
|
|
9 |
for train_image in train_images] |
|
|
10 |
train_labels = [os.path.join(train_labels_folder, train_label) |
|
|
11 |
for train_label in train_labels] |
|
|
12 |
val_images = None |
|
|
13 |
val_labels = None |
|
|
14 |
do_normalize = True |
|
|
15 |
batch_size = 16 |
|
|
16 |
num_workers = 0 |
|
|
17 |
pad_ref = (48, 64, 48) |
|
|
18 |
lr = 0.01 |
|
|
19 |
epochs = 100 |
|
|
20 |
low_lr_epoch = epochs // 3 |
|
|
21 |
val_epochs = epochs // 5 |
|
|
22 |
cuda = True |
|
|
23 |
num_outs = 3 |
|
|
24 |
do_crossval = True |
|
|
25 |
num_folders = 5 |
|
|
26 |
num_channels = 8 |
|
|
27 |
transform_train = train_transform |
|
|
28 |
transform_val = val_transform |
|
|
29 |
net = "vnet" |
|
|
30 |
|
|
|
31 |
|
|
|
32 |
LEARNING_RATE_REDUCTION_FACTOR = 2 |