"""Train a U-Net on blood-image segmentation data.

Reconstructed from a new-file patch (``U-Net/train_blood.py``).  Loads
images/masks from ``./data/train``, trains with CE/BCE + Dice loss,
periodically evaluates the validation Dice score, and keeps both
per-epoch checkpoints and the best-scoring weights.
"""
import argparse
import logging
import os
import random
import sys
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.models as models

from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from evaluate import evaluate
from unet.unet_model import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss

import segmentation_models_pytorch as smp

# NOTE(review): GPU index 4 is hard-coded — confirm this matches the
# intended device on the training machine.
device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')

dir_img = Path('./data/train/imgs/')
dir_mask = Path('./data/train/masks/')
dir_checkpoint = Path('./checkpoints')


def train_model(
        model, device, epochs, batch_size, learning_rate,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.5,
        gradient_clipping: float = 1.0
        ):
    """Run the full training loop and persist the best-performing weights.

    Args:
        model: network exposing ``n_channels`` and ``n_classes`` attributes.
        device: torch device to train on.
        epochs: number of training epochs.
        batch_size: samples per batch.
        learning_rate: Adam learning rate.
        val_percent: fraction of the dataset held out for validation.
        save_checkpoint: whether to save a checkpoint after every epoch.
        img_scale: downscaling factor applied by the dataset loaders.
        amp: enable mixed-precision training.
        weight_decay: Adam weight decay.
        momentum: unused — retained for interface compatibility
            (Adam takes no momentum argument).
        gradient_clipping: max gradient norm applied each step.
    """
    best_model_params = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0

    # 1. Create dataset.  The normalisation transform is only applied by
    # BasicDataset (the fallback path); CarvanaDataset ignores it.
    data_transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale, data_transform)

    # 2. Split into train / validation partitions (seeded for reproducibility)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val

    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        loss += dice_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # FIX: gradients must be unscaled before clipping, otherwise
                # the clip threshold is applied to AMP-scaled values and the
                # effective clipping is wrong whenever amp=True.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round: roughly 5 validation passes per epoch
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:

                        val_score = evaluate(model, val_loader, device, amp)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))

                        # Track the best validation model (not necessarily best on test)
                        if val_score > best_acc:
                            best_acc = val_score
                            best_epoch = epoch
                            best_model_params = copy.deepcopy(model.state_dict())
                            logging.info("Best model: ["
                                         f'epoch: {best_epoch}, acc: {best_acc:.4f}]')

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict['mask_values'] = dataset.mask_values
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')

    # Persist only the best weights (no mask_values key embedded here)
    torch.save(best_model_params, f'epoch_{best_epoch}_acc_{best_acc:.2f}_best_val_acc.pth')
    logging.info("Best model name : "
                 f'epoch_{best_epoch}_acc_{best_acc:.2f}_best_val_acc.pth')


def get_args():
    """Parse the training CLI arguments."""
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=16, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.001,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    #device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    """
    Change here to adapt to your data
    n_channels=3 for RGB images
    n_classes is the number of probabilities you want to get per pixel
    """
    # NOTE(review): the CLI flags --classes and --bilinear are ignored here;
    # the model is hard-wired to 1 input channel / 5 classes — confirm this
    # is intentional for the blood dataset.
    model = UNet(n_channels=1, n_classes=5, bilinear=True)

    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        # FIX: pop() instead of del — checkpoints saved via the best-model
        # path have no 'mask_values' key and del would raise KeyError.
        state_dict.pop('mask_values', None)
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(device=device)

    train_model(
        model=model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        device=device,
        img_scale=args.scale,
        val_percent=args.val / 100,
        amp=args.amp
    )