--- a +++ b/train.py @@ -0,0 +1,141 @@ +import os +import torch +import argparse +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader +from torchvision import transforms +from networks.vnet import VNet +from loss import Loss,cal_dice +from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor + + +def train_loop(model, optimizer, criterion, train_loader, device): + model.train() + running_loss = 0 + pbar = tqdm(train_loader) + dice_train = 0 + + for sampled_batch in pbar: + volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] + volume_batch, label_batch = volume_batch.to(device), label_batch.to(device) + # print(volume_batch.shape,label_batch.shape) + outputs = model(volume_batch) + # print(outputs.shape) + loss = criterion(outputs, label_batch) + dice = cal_dice(outputs, label_batch) + dice_train += dice.item() + pbar.set_postfix(loss="{:.3f}".format(loss.item()), dice="{:.3f}".format(dice.item())) + + running_loss += loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss = running_loss / len(train_loader) + dice = dice_train / len(train_loader) + return {'loss': loss, 'dice': dice} + + +def eval_loop(model, criterion, valid_loader, device): + model.eval() + running_loss = 0 + pbar = tqdm(valid_loader) + dice_valid = 0 + + with torch.no_grad(): + for sampled_batch in pbar: + volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] + volume_batch, label_batch = volume_batch.to(device), label_batch.to(device) + + outputs = model(volume_batch) + + loss = criterion(outputs, label_batch) + dice = cal_dice(outputs, label_batch) + running_loss += loss.item() + dice_valid += dice.item() + pbar.set_postfix(loss="{:.3f}".format(loss.item()), dice="{:.3f}".format(dice.item())) + + loss = running_loss / len(valid_loader) + dice = dice_valid / len(valid_loader) + return {'loss': loss, 'dice': dice} + + +def train(args, model, optimizer, criterion, train_loader, valid_loader, epochs, + device, train_log, loss_min=999.0): + for e in range(epochs): + # train for epoch + train_metrics = train_loop(model, optimizer, criterion, train_loader, device) + valid_metrics = eval_loop(model, criterion, valid_loader, device) + + # eval for epoch + info1 = "Epoch:[{}/{}] train_loss: {:.3f} valid_loss: {:.3f}".format(e + 1, epochs, train_metrics["loss"], + valid_metrics['loss']) + info2 = "train_dice: {:.3f} valid_dice: {:.3f}".format(train_metrics['dice'], valid_metrics['dice']) + + print(info1 + '\n' + info2) + with open(train_log, 'a') as f: + f.write(info1 + '\n' + info2 + '\n') + + if valid_metrics['loss'] < loss_min: + loss_min = valid_metrics['loss'] + torch.save(model.state_dict(), args.save_path) + print("Finished Training!") + + +def main(args): + torch.manual_seed(args.seed) # 为CPU设置种子用于生成随机数,以使得结果是确定的 + torch.cuda.manual_seed_all(args.seed) # 为所有的GPU设置种子,以使得结果是确定的 + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # data info + db_train = LAHeart(base_dir=args.train_path, + split='train', + transform=transforms.Compose([ + RandomRotFlip(), + RandomCrop(args.patch_size), + ToTensor(), + ])) + db_test = LAHeart(base_dir=args.train_path, + split='test', + transform=transforms.Compose([ + CenterCrop(args.patch_size), + ToTensor() + ])) + print('Using {} images for training, {} images for testing.'.format(len(db_train), len(db_test))) + trainloader = DataLoader(db_train, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, + drop_last=True) + testloader = DataLoader(db_test, batch_size=1, num_workers=4, pin_memory=True) + model = VNet(n_channels=1,n_classes=args.num_classes, normalization='batchnorm', has_dropout=True).to(device) + + criterion = Loss(n_classes=args.num_classes).to(device) + optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=args.lr, weight_decay=1e-4) + + # 加载训练模型 + if os.path.exists(args.weight_path): + weight_dict = torch.load(args.weight_path, map_location=device) + model.load_state_dict(weight_dict) + print('Successfully loading checkpoint.') + + train(args, model, optimizer, criterion, trainloader, testloader, args.epochs, device, train_log=args.train_log) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--num_classes', type=int, default=2) + parser.add_argument('--seed', type=int, default=21) + parser.add_argument('--epochs', type=int, default=160) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--lr', type=float, default=0.01) + parser.add_argument('--patch_size', type=float, default=(112, 112, 80)) + parser.add_argument('--train_path', type=str, default='/***data_set/LASet/data') + parser.add_argument('--train_log', type=str, default='results/VNet_sup.txt') + parser.add_argument('--weight_path', type=str, default='results/VNet_sup.pth') # 加载 + parser.add_argument('--save_path', type=str, default='results/VNet_sup.pth') # 保存 + args = parser.parse_args() + + main(args) \ No newline at end of file