Diff of /train_sup.py [000000] .. [903821]

import os
import sys
from sklearn.utils import shuffle
from tqdm import tqdm
from tensorboardX import SummaryWriter
import argparse
import logging
import time
import random

import numpy as np  # needed for np.random.seed below
import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from networks.vnet import VNet
from utils import ramps, losses
from dataloaders.la_heart import *

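# Command-line configuration; the defaults correspond to the fully supervised baseline
# on the LA (left atrium) dataset: 25 labeled scans, batch size 2 per GPU, SGD with
# lr 0.01 for 6000 iterations.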
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset name')
parser.add_argument('--root_path', type=str, default='/data/omnisky/postgraduate/Yb/data_set/LASet/data',
                    help='dataset root directory')
parser.add_argument('--exp', type=str, default='vnet', help='experiment name')
parser.add_argument('--model', type=str, default='supervised', help='method name')
parser.add_argument('--max_iterations', type=int, default=6000, help='maximum number of training iterations')
parser.add_argument('--batch_size', type=int, default=2, help='batch size per GPU')
parser.add_argument('--labelnum', type=int, default=25, help='number of labeled training samples')
parser.add_argument('--max_samples', type=int, default=123, help='total number of training samples')
parser.add_argument('--base_lr', type=float, default=0.01, help='initial learning rate')
parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
parser.add_argument('--gpu', type=str, default='1', help='GPU(s) to use')
args = parser.parse_args()


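# Binary segmentation (background vs. left atrium) on 112x112x80 patches; checkpoints
# and logs are written to a snapshot directory derived from the arguments.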
num_classes = 2
patch_size = (112, 112, 80)
snapshot_path = "model/{}_{}_{}_labeled/{}".format(args.dataset_name, args.exp, args.labelnum, args.model)

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr

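# Reproducibility: disable cuDNN autotuning, enable deterministic kernels, and seed
# the Python, NumPy and PyTorch RNGs so runs with the same --seed are repeatable.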
if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)


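# Hard Dice between the argmax prediction and the binary target; eps keeps the score
# finite (and equal to 1) when both prediction and target are empty.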
def cal_dice(output, target, eps=1e-3):
    output = torch.argmax(output, dim=1)
    inter = torch.sum(output * target) + eps
    union = torch.sum(output) + torch.sum(target) + eps * 2
    dice = 2 * inter / union
    return dice


if __name__ == "__main__":
    # make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

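    # V-Net with dropout for single-channel volumes. The training split uses only the
    # --labelnum labeled scans with random rotation/flip and random cropping; the test
    # split is center-cropped to the same patch size.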
    model = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True).cuda()
    db_train = LAHeart(base_dir=args.root_path,
                       split='train',
                       num=args.labelnum,
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))

    db_test = LAHeart(base_dir=args.root_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(patch_size),
                          ToTensor()
                      ]))


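    # Seed each DataLoader worker deterministically so data augmentation is reproducible;
    # drop_last keeps every training batch at the full batch size.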
    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    train_loader = DataLoader(db_train, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True,
                              drop_last=True, worker_init_fn=worker_init_fn)
    test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(train_loader)))

    iter_num = 0
    best_dice = 0
    max_epoch = max_iterations // len(train_loader) + 1
    lr_ = base_lr
    model.train()
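    # Training loop: the budget is expressed in iterations, so the epoch count is derived
    # from max_iterations and the number of batches per epoch. Starting at iteration 800,
    # the model is evaluated on the test split every 200 iterations.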
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(train_loader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            outputs = model(volume_batch)

            # calculate the loss
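            # Supervised objective: voxel-wise cross-entropy on the logits plus a soft Dice
            # loss on the foreground softmax channel, combined with equal weight (0.5 each).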
            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = losses.dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            loss = 0.5 * (loss_seg + loss_seg_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)

            logging.info('iteration %d : loss : %f ' % (iter_num, loss.item()))

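            # Periodic evaluation: average the hard Dice score over the test set and keep a
            # checkpoint whenever the best score so far improves.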
            if iter_num >= 800 and iter_num % 200 == 0:
                model.eval()
                with torch.no_grad():
                    dice_sample = 0
                    for sampled_batch in test_loader:
                        img, lbl = sampled_batch['image'].cuda(), sampled_batch['label'].cuda()
                        outputs = model(img)
                        dice_once = cal_dice(outputs, lbl)
                        dice_sample += dice_once.item()  # accumulate as a plain float
                    dice_sample = dice_sample / len(test_loader)
                    print('Average center dice:{:.3f}'.format(dice_sample))

                if dice_sample > best_dice:
                    best_dice = dice_sample
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, best_dice))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)
                    logging.info("save best model to {}".format(save_mode_path))
                writer.add_scalar('Var_dice/Dice', dice_sample, iter_num)
                writer.add_scalar('Var_dice/Best_dice', best_dice, iter_num)
                model.train()

            if iter_num >= max_iterations:
                break
            time1 = time.time()
        if iter_num >= max_iterations:
            break
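    # Save a final checkpoint once the iteration budget is exhausted.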
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth')
    torch.save(model.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()