Diff of /U-Net/train_blood.py [000000] .. [6f3ba0]

--- a
+++ b/U-Net/train_blood.py
@@ -0,0 +1,243 @@
+import argparse
+import logging
+import os
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+
+from pathlib import Path
+from torch import optim
+from torch.utils.data import DataLoader, random_split
+from tqdm import tqdm
+
+from evaluate import evaluate
+from unet.unet_model import UNet
+from utils.data_loading import BasicDataset, CarvanaDataset
+from utils.dice_score import dice_loss
+
+# Select GPU index 4 when CUDA is available; adjust to your machine
+device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
+
+dir_img = Path('./data/train/imgs/')
+dir_mask = Path('./data/train/masks/')
+dir_checkpoint = Path('./checkpoints')
+
+def train_model(
+        model, device, epochs, batch_size, learning_rate,
+        val_percent: float = 0.1,
+        save_checkpoint: bool = True,
+        img_scale: float = 0.5,
+        amp: bool = False,
+        weight_decay: float = 1e-8,
+        gradient_clipping: float = 1.0
+    ):
+    
+    best_model_params = copy.deepcopy(model.state_dict())
+    best_acc = 0.0
+    best_epoch = 0
+
+    # 1. Create dataset    
+    data_transform = transforms.Compose([
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
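+    # Note: these mean/std values are the standard ImageNet RGB statistics; for
+    # single-channel inputs they would need to be replaced with one-element lists
+    # matching the data (assumption: BasicDataset applies this transform to the images).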
+
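+    # CarvanaDataset assumes Carvana-style mask filenames (reportedly a '_mask'
+    # suffix); when that lookup fails, fall back to the generic BasicDataset,
+    # which here also receives the normalization transform above.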
+    try:
+        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
+    except (AssertionError, RuntimeError, IndexError):
+        dataset = BasicDataset(dir_img, dir_mask, img_scale, data_transform)
+
+    # 2. Split into train / validation partitions
+    n_val = int(len(dataset) * val_percent)
+    n_train = len(dataset) - n_val
+    
+    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
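+    # A fixed generator seed makes the train/val split reproducible across runs,
+    # so checkpoints remain comparable between training sessions.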
+
+    # 3. Create data loaders
+    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
+    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
+    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
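+    # drop_last=True discards a final partial validation batch so every batch has
+    # the same size; note this skips up to batch_size - 1 samples during evaluation.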
+
+    logging.info(f'''Starting training:
+        Epochs:          {epochs}
+        Batch size:      {batch_size}
+        Learning rate:   {learning_rate}
+        Training size:   {n_train}
+        Validation size: {n_val}
+        Checkpoints:     {save_checkpoint}
+        Device:          {device.type}
+        Images scaling:  {img_scale}
+        Mixed Precision: {amp}
+    ''')
+
+    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
+    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
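+    # Loss choice: CrossEntropyLoss takes raw logits of shape (N, C, H, W) with
+    # integer class targets, while BCEWithLogitsLoss fuses sigmoid and binary
+    # cross-entropy for the single-class case, which is numerically more stable
+    # than applying a separate sigmoid followed by BCELoss.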
+    global_step = 0
+    val_score = 0.0  # initialized so the per-epoch best-model check below cannot hit a NameError
+
+    # 5. Begin training
+    for epoch in range(1, epochs + 1):
+        model.train()
+        epoch_loss = 0
+        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
+            for batch in train_loader:
+                images, true_masks = batch['image'], batch['mask']
+
+                assert images.shape[1] == model.n_channels, \
+                    f'Network has been defined with {model.n_channels} input channels, ' \
+                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
+                    'the images are loaded correctly.'
+
+                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
+                true_masks = true_masks.to(device=device, dtype=torch.long)
+
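+                # Fall back to CPU autocast on Apple MPS devices, where torch.autocast
+                # does not support the 'mps' device type in the PyTorch versions this
+                # script appears to target.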
+                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
+                    masks_pred = model(images)
+                    if model.n_classes == 1:
+                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
+                        loss += dice_loss(torch.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
+                    else:
+                        loss = criterion(masks_pred, true_masks)
+                        loss += dice_loss(
+                            F.softmax(masks_pred, dim=1).float(),
+                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
+                            multiclass=True
+                        )
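+                # dice_loss (from utils.dice_score) is assumed to implement the usual
+                # soft Dice loss, roughly 1 - (2*sum(p*t) + eps) / (sum(p) + sum(t) + eps),
+                # averaged over classes when multiclass=True; adding it to CE/BCE is a
+                # common recipe for class-imbalanced segmentation masks.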
+
+                optimizer.zero_grad(set_to_none=True)
+                grad_scaler.scale(loss).backward()
+                grad_scaler.unscale_(optimizer)  # unscale grads first so clipping applies to true magnitudes
+                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
+                grad_scaler.step(optimizer)
+                grad_scaler.update()
+
+                pbar.update(images.shape[0])
+                global_step += 1
+                epoch_loss += loss.item()
+                pbar.set_postfix(**{'loss (batch)': loss.item()})
+
+                # Evaluation round: run validation roughly 5 times per epoch
+                division_step = (n_train // (5 * batch_size))
+                if division_step > 0 and global_step % division_step == 0:
+                    val_score = evaluate(model, val_loader, device, amp)
+                    scheduler.step(val_score)
+                    logging.info(f'Validation Dice score: {val_score}')
+
+            # Track the best validation Dice score (best on validation, not necessarily on test)
+            if val_score > best_acc:
+                best_acc = val_score
+                best_epoch = epoch
+                best_model_params = copy.deepcopy(model.state_dict())
+            logging.info(f'Best model so far: [epoch: {best_epoch}, acc: {best_acc:.4f}]')
+
+
+        if save_checkpoint:
+            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
+            state_dict = model.state_dict()
+            state_dict['mask_values'] = dataset.mask_values
+            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
+            logging.info(f'Checkpoint {epoch} saved!')
+        
+    # Save only the best model weights (state_dict, without optimizer state)
+    best_name = f'epoch_{best_epoch}_acc_{best_acc:.2f}_best_val_acc.pth'
+    torch.save(best_model_params, best_name)
+    logging.info(f'Best model saved as {best_name}')
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
+    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs')
+    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=16, help='Batch size')
+    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.001,
+                        help='Learning rate', dest='lr')
+    parser.add_argument('--load', '-f', type=str, default=None, help='Load model from a .pth file')
+    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
+    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
+                        help='Percent of the data that is used as validation (0-100)')
+    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
+    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
+    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = get_args()
+
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+    logging.info(f'Using device {device}')
+
+    """
+    Change here to adapt to your data
+    n_channels=3 for RGB images
+    n_classes is the number of probabilities you want to get per pixel
+    """
+    model = UNet(n_channels=1, n_classes=5, bilinear=True)
+
+    model = model.to(memory_format=torch.channels_last)
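+    # channels_last (NHWC) is a memory-layout hint that can speed up convolutions on
+    # recent GPUs, especially with AMP enabled; it does not change numerical results.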
+
+    logging.info(f'Network:\n'
+                 f'\t{model.n_channels} input channels\n'
+                 f'\t{model.n_classes} output channels (classes)\n'
+                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')
+
+    if args.load:
+        state_dict = torch.load(args.load, map_location=device)
+        # Epoch checkpoints store 'mask_values' alongside the weights; pop it safely,
+        # since the best-model file saved above does not contain it
+        state_dict.pop('mask_values', None)
+        model.load_state_dict(state_dict)
+        logging.info(f'Model loaded from {args.load}')
+
+    model.to(device=device)
+    
+    train_model(
+        model=model,
+        epochs=args.epochs,
+        batch_size=args.batch_size,
+        learning_rate=args.lr,
+        device=device,
+        img_scale=args.scale,
+        val_percent=args.val / 100,
+        amp=args.amp
+    )
+