Diff of /train.py [000000] .. [70e190]

Switch to unified view

a b/train.py
1
import argparse
2
import os
3
from glob import glob
4
import pandas as pd
5
import yaml
6
from utils import str2bool, write_csv
7
from collections import OrderedDict
8
from sklearn.model_selection import train_test_split
9
10
from trainer import trainer, validate
11
from dataset import CustomDataset
12
import torch
13
from torch.utils.data import DataLoader
14
import torch.optim as optim
15
from torch.nn.modules.loss import CrossEntropyLoss
16
from metrics import Dice, IOU, HD
17
18
from networks.RotCAtt_TransUNet_plusplus.RotCAtt_TransUNet_plusplus import RotCAtt_TransUNet_plusplus
19
from networks.RotCAtt_TransUNet_plusplus.config import get_config as rot_config
20
21
22
23
def parse_args():
    """Build the training command-line interface and parse the arguments.

    Options are grouped as: training pipeline, network, dataset, criterion,
    optimizer and scheduler.

    Returns:
        argparse.Namespace: the parsed command-line options.
    """
    parser = argparse.ArgumentParser()

    # Training pipeline
    parser.add_argument('--name', default=None, help='model name')
    # Converted with str2bool like the other boolean flags: without a type,
    # any value given on the command line (even "False") stays a non-empty,
    # therefore truthy, string.
    parser.add_argument('--pretrained', default=False, type=str2bool,
                        help='pretrained or not (default: False)')
    parser.add_argument('--epochs', default=600, type=int, metavar='N',
                        help='number of epochs for training')
    parser.add_argument('--batch_size', default=6, type=int, metavar='N',
                        help='mini-batch size')
    parser.add_argument('--seed', type=int, default=1234, help='random seed')
    parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
    parser.add_argument('--num_workers', default=3, type=int)
    parser.add_argument('--val_mode', default=True, type=str2bool)

    # Network
    parser.add_argument('--network', default='RotCAtt_TransUNet_plusplus')
    parser.add_argument('--input_channels', default=1, type=int,
                        help='input channels')
    parser.add_argument('--patch_size', default=16, type=int,
                        help='input patch size')
    parser.add_argument('--num_classes', default=12, type=int,
                        help='number of classes')
    parser.add_argument('--img_size', default=512, type=int,
                        help='input image img_size')

    # Dataset
    parser.add_argument('--dataset', default='VHSCDD', help='dataset name')
    parser.add_argument('--ext', default='.npy', help='file extension')
    parser.add_argument('--range', default=None, type=int, help='dataset size')

    # Criterion
    parser.add_argument('--loss', default='Dice Iou Cross entropy')

    # Optimizer
    # The help text now reports the real default (SGD) and keeps a space
    # between the list of choices and the default.
    parser.add_argument('--optimizer', default='SGD', choices=['Adam', 'SGD'],
                        help='optimizer: ' + ' | '.join(['Adam', 'SGD'])
                        + ' (default: SGD)')
    parser.add_argument('--base_lr', '--learning_rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=0.0001, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='nesterov')

    # Scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau',
                                 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=1e-5, type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=0.1, type=float)
    parser.add_argument('--patience', default=2, type=int)
    parser.add_argument('--milestones', default='1,2', type=str)
    parser.add_argument('--gamma', default=2/3, type=float)
    parser.add_argument('--early_stopping', default=-1, type=int,
                        metavar='N', help='early stopping (default: -1)')

    return parser.parse_args()
85
86
def output_config(config):
    """Print each ``key: value`` entry of ``config`` between two dashed rules.

    Args:
        config: mapping of option names to values (``train`` passes
            ``vars(namespace)``).
    """
    rule = '-' * 20
    print(rule)
    for name, value in config.items():
        print(f'{name}: {value}')
    print(rule)
91
   
92
93
def loading_2D_data(config):
    """Create the training and validation data loaders for a 2D dataset.

    Files are collected from ``data/<dataset>/images`` and
    ``data/<dataset>/labels``, optionally truncated to the first
    ``config.range`` entries, split 80/20 with a fixed random state and
    wrapped in ``CustomDataset`` instances.

    Args:
        config: namespace providing ``dataset``, ``range``, ``num_classes``,
            ``img_size``, ``batch_size`` and ``num_workers``. An optional
            ``ext`` attribute selects the file extension (falls back to
            ``.npy`` when absent).

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    ext = getattr(config, 'ext', '.npy')

    # glob() returns entries in arbitrary, OS-dependent order, and images and
    # labels are paired purely by list position. Sorting both lists makes the
    # pairing deterministic.
    # NOTE(review): assumes an image and its label share the same file name
    # in the two folders - confirm against the dataset layout.
    image_paths = sorted(glob(f"data/{config.dataset}/images/*{ext}"))
    label_paths = sorted(glob(f"data/{config.dataset}/labels/*{ext}"))

    if config.range is not None:
        image_paths = image_paths[:config.range]
        label_paths = label_paths[:config.range]

    train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
        image_paths, label_paths, test_size=0.2, random_state=41)
    train_ds = CustomDataset(config.num_classes, train_image_paths, train_label_paths, img_size=config.img_size)
    val_ds = CustomDataset(config.num_classes, val_image_paths, val_label_paths, img_size=config.img_size)

    # Both loaders share the same settings; neither shuffles nor drops the
    # last incomplete batch.
    loader_kwargs = dict(
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        drop_last=False,
    )
    train_loader = DataLoader(train_ds, **loader_kwargs)
    val_loader = DataLoader(val_ds, **loader_kwargs)
    return train_loader, val_loader
121
122
    
123
def load_pretrained_model(model_path):
    """Load a previously saved model object from ``model_path``.

    Args:
        model_path: path of the checkpoint file written by ``torch.save``.

    Returns:
        The object stored in the checkpoint.

    Raises:
        SystemExit: with exit status 1 when ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        print("No pretrained exists")
        # A missing checkpoint is a failure: exit with a non-zero status
        # (the bare exit() used before reported success to the shell).
        raise SystemExit(1)

    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    return torch.load(model_path)
130
        
131
132
def load_network(config):
    """Instantiate the network selected by ``config.network`` on the GPU.

    Only ``'RotCAtt_TransUNet_plusplus'`` is wired in; any other name prints
    a hint and terminates the process with exit status 1.

    Args:
        config: namespace providing ``network``, ``img_size`` and
            ``num_classes``.

    Returns:
        The constructed model, moved to CUDA.
    """
    # Guard clause: bail out early for networks that are not registered here.
    if config.network != 'RotCAtt_TransUNet_plusplus':
        print("Add the custom network to the training pipeline please")
        exit(1)

    # Start from the network's default configuration and override the
    # input size and class count with the command-line values.
    net_cfg = rot_config()
    net_cfg.img_size = config.img_size
    net_cfg.num_classes = config.num_classes
    return RotCAtt_TransUNet_plusplus(config=net_cfg).cuda()
144
145
 
146
def rlog(value):
    """Return ``value`` rounded to three decimal places for console logs."""
    decimals = 3
    return round(value, decimals)
148
        
149
def train(config):
    """Run the training pipeline described by ``config``.

    Steps: derive the run name, build (or reload) the model, create the data
    loaders, then for every epoch run ``trainer`` (and ``validate`` when
    ``config.val_mode`` is set), append the metrics to
    ``outputs/<name>/epo_log.csv`` and save the model whenever every tracked
    score improves at once.

    Args:
        config: namespace returned by ``parse_args``. ``config.name`` is
            overwritten, and ``config.dataset`` gets an ``_<img_size>``
            suffix when it equals ``'VHSCDD'``.

    Raises:
        ValueError: if ``config.optimizer`` is neither ``'Adam'`` nor
            ``'SGD'``.
    """
    # vars() returns the namespace's live __dict__, so later assignments to
    # config (e.g. config.name) are visible through config_dict too.
    config_dict = vars(config)
    print(config.network)

    # Config name
    config.name = f"{config.dataset}_{config.network}_bs{config.batch_size}_ps{config.patch_size}_epo{config.epochs}_hw{config.img_size}"

    # Model
    print(f"=> Initialize model: {config.network}")
    if not config.pretrained:
        model = load_network(config)
        output_config(config_dict)
        print(f"=> Initialize output: {config.name}")
        model_path = f"outputs/{config.name}"
        # The configuration is dumped only when the output folder is created.
        if not os.path.exists(model_path):
            os.makedirs(model_path)
            with open(f"{model_path}/config.yml", "w") as f:
                yaml.dump(config_dict, f)
    else:
        model = load_pretrained_model(f'outputs/{config.name}/model.pth')

    # Data loading
    if config.dataset == 'VHSCDD':
        config.dataset += f'_{config.img_size}'
    train_loader, val_loader = loading_2D_data(config)

    # Logging: (CSV column suffix, key in the trainer/validate result dict)
    metric_columns = (
        ('loss', 'loss'),
        ('ce loss', 'ce_loss'),
        ('dice score', 'dice_score'),
        ('dice loss', 'dice_loss'),
        ('iou score', 'iou_score'),
        ('iou loss', 'iou_loss'),
        ('hausdorff', 'hausdorff'),
    )
    log = OrderedDict([('epoch', []), ('lr', [])])
    for split in ('Train', 'Val'):
        for column, _ in metric_columns:
            log[f'{split} {column}'] = []

    if config.pretrained:
        # Resume: continue appending to the previously written epoch log.
        pre_log = pd.read_csv(f'outputs/{config.name}/epo_log.csv')
        print(pre_log)
        log = OrderedDict(
            (column, pre_log[column].tolist()) for column in pre_log.columns
        )

    # When resuming, keep numbering epochs after the ones already logged
    # instead of restarting the 'epoch' column at 0.
    epoch_offset = len(log['epoch'])

    # Optimizer
    params = filter(lambda p: p.requires_grad, model.parameters())
    if config.optimizer == 'Adam':
        optimizer = optim.Adam(params, lr=config.base_lr, weight_decay=config.weight_decay)
    elif config.optimizer == 'SGD':
        optimizer = optim.SGD(params, lr=config.base_lr, momentum=config.momentum,
                              nesterov=config.nesterov, weight_decay=config.weight_decay)
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(f"Unsupported optimizer: {config.optimizer}")

    # Criterion
    ce = CrossEntropyLoss()
    dice = Dice(config.num_classes)
    iou = IOU(config.num_classes)
    hd = HD()

    # Training loop
    # NOTE(review): best scores restart from 0 on a resumed run, so the first
    # resumed epoch can overwrite a better checkpoint - confirm this is wanted.
    best_train_iou = 0
    best_train_dice_score = 0
    best_val_iou = 0
    best_val_dice_score = 0

    fieldnames = ['CE Loss', 'Dice Score', 'Dice Loss', 'IoU Score', 'IoU Loss', 'HausDorff Distance', 'Total Loss']
    iter_log_file = f'outputs/{config.name}/iter_log.csv'
    if not os.path.exists(iter_log_file):
        write_csv(iter_log_file, fieldnames)

    for epoch in range(config.epochs):
        print(f"Epoch: {epoch+1}/{config.epochs}")
        train_log = trainer(config, train_loader, optimizer, model, ce, dice, iou, hd)
        val_log = None
        if config.val_mode:
            val_log = validate(config, val_loader, model, ce, dice, iou, hd)

        print(
            f"Train loss: {rlog(train_log['loss'])}"
            f" - Train ce loss: {rlog(train_log['ce_loss'])}"
            f" - Train dice score: {rlog(train_log['dice_score'])}"
            f" - Train dice loss: {rlog(train_log['dice_loss'])}"
            f" - Train iou Score: {rlog(train_log['iou_score'])}"
            f" - Train iou loss: {rlog(train_log['iou_loss'])}"
            f" - Train hausdorff: {rlog(train_log['hausdorff'])}"
        )
        if config.val_mode:
            print(
                f"Val loss: {rlog(val_log['loss'])}"
                f" - Val ce loss: {rlog(val_log['ce_loss'])}"
                f" - Val dice score: {rlog(val_log['dice_score'])}"
                f" - Val dice loss: {rlog(val_log['dice_loss'])}"
                f" - Val iou Score: {rlog(val_log['iou_score'])}"
                f" - Val iou loss: {rlog(val_log['iou_loss'])}"
                f" - Val hausdorff: {rlog(val_log['hausdorff'])}"
            )

        log['epoch'].append(epoch_offset + epoch)
        log['lr'].append(config.base_lr)
        for column, key in metric_columns:
            log[f'Train {column}'].append(train_log[key])
            # Validation columns are padded with None when validation is off
            # so that every column keeps the same length.
            log[f'Val {column}'].append(val_log[key] if config.val_mode else None)

        pd.DataFrame(log).to_csv(f'outputs/{config.name}/epo_log.csv', index=False)

        # Save best model: every tracked score must improve at the same time.
        # Validation scores take part only when validation actually ran;
        # previously val_log was read unconditionally, which raised a
        # NameError with --val_mode False.
        improved = (
            train_log['iou_score'] > best_train_iou
            and train_log['dice_score'] > best_train_dice_score
        )
        if config.val_mode:
            improved = (
                improved
                and val_log['iou_score'] > best_val_iou
                and val_log['dice_score'] > best_val_dice_score
            )

        if improved:
            best_train_iou = train_log['iou_score']
            best_train_dice_score = train_log['dice_score']
            if config.val_mode:
                best_val_iou = val_log['iou_score']
                best_val_dice_score = val_log['dice_score']

            torch.save(model, f"outputs/{config.name}/model.pth")

        print(f'BEST TRAIN DICE: {best_train_dice_score} - BEST TRAIN IOU: {best_train_iou} - BEST VAL DICE SCORE: {best_val_dice_score} - BEST VAL IOU: {best_val_iou}')
280
            
281
if __name__ == '__main__':
    # Script entry point: read the command-line options, then start training.
    config = parse_args()
    train(config)