Diff of /train.py [000000] .. [903821]

Switch to side-by-side view

--- a
+++ b/train.py
@@ -0,0 +1,141 @@
+import os
+import torch
+import argparse
+import torch.optim as optim
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from networks.vnet import VNet
+from loss import Loss,cal_dice
+from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor
+
+
def train_loop(model, optimizer, criterion, train_loader, device):
    """Run one training epoch; return the mean loss and dice over all batches."""
    model.train()
    progress = tqdm(train_loader)
    total_loss = 0.0
    total_dice = 0.0

    for batch in progress:
        images = batch['image'].to(device)
        labels = batch['label'].to(device)

        predictions = model(images)
        batch_loss = criterion(predictions, labels)
        batch_dice = cal_dice(predictions, labels)

        # Standard SGD step: clear stale grads, backprop, update weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
        total_dice += batch_dice.item()
        progress.set_postfix(loss="{:.3f}".format(batch_loss.item()),
                             dice="{:.3f}".format(batch_dice.item()))

    num_batches = len(train_loader)
    return {'loss': total_loss / num_batches, 'dice': total_dice / num_batches}
+
+
def eval_loop(model, criterion, valid_loader, device):
    """Evaluate the model; return the mean loss and dice over the validation set."""
    model.eval()
    progress = tqdm(valid_loader)
    total_loss = 0.0
    total_dice = 0.0

    # No gradients needed during evaluation.
    with torch.no_grad():
        for batch in progress:
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            predictions = model(images)
            batch_loss = criterion(predictions, labels)
            batch_dice = cal_dice(predictions, labels)

            total_loss += batch_loss.item()
            total_dice += batch_dice.item()
            progress.set_postfix(loss="{:.3f}".format(batch_loss.item()),
                                 dice="{:.3f}".format(batch_dice.item()))

    num_batches = len(valid_loader)
    return {'loss': total_loss / num_batches, 'dice': total_dice / num_batches}
+
+
def train(args, model, optimizer, criterion, train_loader, valid_loader, epochs,
          device, train_log, loss_min=999.0):
    """Full training driver: run `epochs` epochs of train + validation,
    append per-epoch metrics to `train_log`, and checkpoint the model
    whenever the validation loss improves on the best seen so far."""
    for epoch_idx in range(epochs):
        train_metrics = train_loop(model, optimizer, criterion, train_loader, device)
        valid_metrics = eval_loop(model, criterion, valid_loader, device)

        info1 = "Epoch:[{}/{}] train_loss: {:.3f} valid_loss: {:.3f}".format(
            epoch_idx + 1, epochs, train_metrics["loss"], valid_metrics['loss'])
        info2 = "train_dice: {:.3f} valid_dice: {:.3f}".format(
            train_metrics['dice'], valid_metrics['dice'])

        summary = info1 + '\n' + info2
        print(summary)
        with open(train_log, 'a') as log_file:
            log_file.write(summary + '\n')

        # Save only when the validation loss sets a new best.
        if valid_metrics['loss'] < loss_min:
            loss_min = valid_metrics['loss']
            torch.save(model.state_dict(), args.save_path)
    print("Finished Training!")
+
+
def main(args):
    """Build the LA-Heart datasets, model, loss and optimizer, optionally
    resume from a checkpoint, and launch supervised training."""
    # Seed CPU and every GPU so results are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.deterministic = True
    # BUGFIX: benchmark=True lets cuDNN auto-tune (non-deterministic) algorithms,
    # which contradicts deterministic=True and the seeding above; disable it.
    torch.backends.cudnn.benchmark = False
    # NOTE(review): this only takes effect if CUDA has not been initialized yet;
    # prefer setting CUDA_VISIBLE_DEVICES before importing torch.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Datasets: random augmentation for training, deterministic center crop for test.
    db_train = LAHeart(base_dir=args.train_path,
                       split='train',
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(args.patch_size),
                           ToTensor(),
                       ]))
    db_test = LAHeart(base_dir=args.train_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(args.patch_size),
                          ToTensor()
                      ]))
    print('Using {} images for training, {} images for testing.'.format(len(db_train), len(db_test)))
    trainloader = DataLoader(db_train, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True,
                             drop_last=True)
    # Batch size 1 for evaluation: volumes are evaluated one at a time.
    testloader = DataLoader(db_test, batch_size=1, num_workers=4, pin_memory=True)
    model = VNet(n_channels=1, n_classes=args.num_classes, normalization='batchnorm', has_dropout=True).to(device)

    criterion = Loss(n_classes=args.num_classes).to(device)
    optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=args.lr, weight_decay=1e-4)

    # Resume from a previous checkpoint if one exists.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True on recent PyTorch versions).
    if os.path.exists(args.weight_path):
        weight_dict = torch.load(args.weight_path, map_location=device)
        model.load_state_dict(weight_dict)
        print('Successfully loading checkpoint.')

    train(args, model, optimizer, criterion, trainloader, testloader, args.epochs, device, train_log=args.train_log)
+
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--seed', type=int, default=21)
    parser.add_argument('--epochs', type=int, default=160)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.01)
    # BUGFIX: was type=float with a tuple default — any CLI override produced a
    # single float instead of a (W, H, D) triple, breaking the crop transforms.
    # nargs=3 parses three ints; the default is unchanged.
    parser.add_argument('--patch_size', type=int, nargs=3, default=(112, 112, 80))
    parser.add_argument('--train_path', type=str, default='/***data_set/LASet/data')
    parser.add_argument('--train_log', type=str, default='results/VNet_sup.txt')
    parser.add_argument('--weight_path', type=str, default='results/VNet_sup.pth')  # checkpoint to resume from
    parser.add_argument('--save_path', type=str, default='results/VNet_sup.pth')  # checkpoint to save
    args = parser.parse_args()

    main(args)
\ No newline at end of file