--- /dev/null
+++ b/src/train (1).py
@@ -0,0 +1,213 @@
+# -*- coding: utf-8 -*-
+"""
+Training script for 2D heart segmentation with a UNet++ model from
+segmentation_models_pytorch.
+"""
+import os
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+
+import torch                                   # to optimize DL processes
+import multiprocessing                         # to query the number of workers (cores)
+import albumentations as A                     # data augmentation
+from albumentations.pytorch import ToTensorV2  # converts images and labels into PyTorch tensors
+import torch.optim as optim                    # parameter optimization
+import torch.nn as nn                          # layers, models and loss criteria for DL
+import matplotlib.pyplot as plt                # training/validation curves (used below)
+from tqdm import tqdm                          # progress bars for long-running loops
+from utils import (                            # helper .py developed previously
+    get_loaders,
+    load_checkpoint,
+    save_checkpoint,
+    perf,
+    save_preds_as_imgs,
+    save_metrics_to_csv,
+    compute_jaccard_index,
+    diceCoef
+)
+#from loss_fn import dice_coefficient_loss, dice_coefficient
+import segmentation_models_pytorch as smp      # pre-built segmentation architectures
+
+# NOTE: the hyperparameter values below are placeholders inferred from how the
+# names are used in this script; adjust them to your setup.
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+LEARNING_RATE = 1e-4
+BATCH_SIZE = 16
+NUM_EPOCHS = 100
+NUM_WORKERS = multiprocessing.cpu_count()
+IMAGE_HEIGHT = 256
+IMAGE_WIDTH = 256
+PIN_MEMORY = True
+LOAD_MODEL = False
+CHECKPOINT_PATH = 'UNETpp_checkpoint.pth.tar'  # placeholder checkpoint path
+
+TRAIN_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/images'
+TRAIN_MASKS_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/masks'
+VAL_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_images'
+VAL_MASK_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_masks'
+PREDS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/saved_images'
+METRICS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/scores/UNETpp_.csv'
+
+
+def train_func(train_loader, model, optimizer, loss_fn, scaler):
+    p_bar = tqdm(train_loader)  # progress bar
+    running_loss = 0.0
+    running_dice = 0.0
+    running_jaccard = 0.0
+    dice = 0
+    jaccard = 0
+    for batch_idx, (data, targets) in enumerate(p_bar):
+        data = data.to(device=DEVICE)
+        targets = targets.float().unsqueeze(1).to(device=DEVICE)  # add a channel dimension at position 1 of "targets"
+        #if not torch.all(targets==0):
+        # Forward pass, with mixed-precision (16- and 32-bit) floating-point autocasting:
+        with torch.cuda.amp.autocast():
+            preds = model(data)
+            loss = loss_fn(preds, targets)
+        with torch.no_grad():
+            p = torch.sigmoid(preds)
+            p = (p > 0.5).float()
+            if (p.sum() == 0) and (targets.sum() == 0):
+                jaccard = 1  # both masks empty: perfect overlap by convention
+                dice = (2*(p*targets).sum() + 1)/((p + targets).sum() + 1)
+            else:
+                jaccard = compute_jaccard_index(p, targets)
+                dice = (2*(p*targets).sum() + 1)/((p + targets).sum() + 1)
+
+        # Backward pass:
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        running_loss += loss.item()
+        running_dice += dice
+        running_jaccard += jaccard
+
+        # Progress-bar update:
+        p_bar.set_postfix(loss=loss.item(), dice=dice.item(), jaccard=jaccard)
+        torch.cuda.empty_cache()
+
+    epoch_loss = running_loss/len(train_loader)
+    epoch_dice = running_dice/len(train_loader)
+    epoch_jaccard = running_jaccard/len(train_loader)
+    return epoch_loss, epoch_dice, epoch_jaccard
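+
+# For reference, a minimal standalone sketch of the smoothed Dice coefficient
+# computed inline in train_func above: the "+1" terms act as Laplace smoothing,
+# so two empty masks score 1.0 instead of dividing by zero. This helper is
+# illustrative only (the script computes the metric inline, and smp's DiceLoss
+# is what drives the actual optimization):
+def smoothed_dice(pred, target, smooth=1.0):
+    """Dice = (2*intersection + smooth) / (|P| + |T| + smooth) on binarized masks."""
+    intersection = (pred * target).sum()
+    return (2.0*intersection + smooth) / (pred.sum() + target.sum() + smooth)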
+
+def main():
+    # Crop window applied before resizing:
+    x_start = 50
+    x_end = 462
+    y_start = 50
+    y_end = 462
+    train_transforms = A.Compose(
+        [
+            A.Crop(x_start, y_start, x_end, y_end, always_apply=True),
+            A.CLAHE(p=0.2),
+            A.GaussNoise(p=0.2),
+            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
+            A.Rotate(limit=30, p=0.3),
+            A.HorizontalFlip(p=0.3),
+            A.VerticalFlip(p=0.3),
+            A.Normalize(
+                mean=[0.0, 0.0, 0.0],
+                std=[1.0, 1.0, 1.0],
+                max_pixel_value=255.0
+            ),
+            ToTensorV2(),  # conversion to a PyTorch tensor
+        ],
+    )
+
+    val_transforms = A.Compose(
+        [
+            A.Crop(x_start, y_start, x_end, y_end, always_apply=False),
+            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
+            A.Normalize(
+                mean=[0.0, 0.0, 0.0],
+                std=[1.0, 1.0, 1.0],
+                max_pixel_value=255.0
+            ),
+            ToTensorV2(),  # conversion to a PyTorch tensor
+        ],
+    )
+
+    # Load and preprocess the training and validation data as tensors:
+    train_loader, val_loader = get_loaders(
+        TRAIN_IMG_DIR,
+        TRAIN_MASKS_DIR,
+        VAL_IMG_DIR,
+        VAL_MASK_DIR,
+        BATCH_SIZE,
+        train_transforms,
+        val_transforms,
+        NUM_WORKERS,
+        w_level=40,    # CT window level
+        w_width=350,   # CT window width
+        pin_memory=PIN_MEMORY,
+        normalize=False,
+    )
+
+    #model = UNET(3,1).to(DEVICE)
+    #preprocess_input = get_preprocessing_fn('mit_b3', pretrained='imagenet')
+    model = smp.UnetPlusPlus(
+        encoder_name="timm-efficientnet-b5",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
+        encoder_weights="advprop",            # AdvProp pre-trained weights for encoder initialization
+        in_channels=3,                        # model input channels (1 for grayscale images, 3 for RGB, etc.)
+        classes=1,                            # single output channel for binary segmentation
+        decoder_use_batchnorm=True,
+    ).to(DEVICE)
+
+    if LOAD_MODEL:  # a previously trained checkpoint exists
+        load_checkpoint(torch.load(CHECKPOINT_PATH), model)
+
+    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+    loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, log_loss=True, eps=1e-10, from_logits=True)  # log-Dice loss computed from logits
+    scaler = torch.cuda.amp.GradScaler()  # dynamically rescales gradients during mixed-precision training to cut training time while preserving accuracy
+    train_loss = []
+    train_dice = []
+    train_jaccard = []
+    dice_score = []
+    jaccard_score = []
+
+    # Training (the train_func call and the three appends below are
+    # reconstructed from how these variables are used later in the script):
+    for epoch in tqdm(range(NUM_EPOCHS)):
+        epoch_loss, dice_train, jaccard_train = train_func(train_loader, model, optimizer, loss_fn, scaler)
+        train_loss.append(epoch_loss)
+        train_dice.append(dice_train.detach().cpu().numpy())
+        train_jaccard.append(jaccard_train)
+
+        # Checkpoint:
+        checkpoint = {
+            'state_dict': model.state_dict(),     # parameter state
+            'optimizer': optimizer.state_dict(),  # gradient state
+        }
+        save_checkpoint(checkpoint, CHECKPOINT_PATH)
+
+        # Validation performance:
+        dice, jaccard = perf(val_loader, model, DEVICE)
+        dice_score.append(dice.detach().cpu().numpy())
+        jaccard_score.append(jaccard)
+
+        # Save predictions and per-epoch metrics:
+        save_preds_as_imgs(val_loader, model, DEVICE, PREDS_PATH)
+        save_metrics_to_csv(epoch, epoch_loss, dice_train.detach().cpu().numpy(), jaccard_train, dice.detach().cpu().numpy(), jaccard, METRICS_PATH)
+
+    #train_jaccard = [tensor.cpu() for tensor in train_jaccard]
+    #dice_score = [tensor.cpu() for tensor in dice_score]
+    #jaccard_score = [tensor.cpu() for tensor in jaccard_score]
+
+    # Training curves:
+    fig, axes = plt.subplots(1, 3)
+    fig.set_figheight(6)
+    fig.set_figwidth(18)
+    axes[0].plot(train_loss)
+    axes[0].set_title('Training loss')
+    axes[0].set_xlabel('Epochs')
+    axes[0].set_ylabel('Loss')
+
+    axes[1].plot(train_dice)
+    axes[1].set_title('Dice Score in train set')
+    axes[1].set_xlabel('Epochs')
+    axes[1].set_ylabel('Dice Score')
+
+    axes[2].plot(train_jaccard)
+    axes[2].set_title('Jaccard Index in train set')
+    axes[2].set_xlabel('Epochs')
+    axes[2].set_ylabel('Jaccard Index')
+    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Training_UNETpp.png')
+    #plt.show()
+
+    # Validation curves:
+    fig2, axes2 = plt.subplots(1, 2)
+    fig2.set_figheight(6)
+    fig2.set_figwidth(18)
+
+    axes2[0].plot(dice_score)
+    axes2[0].set_title('Dice Score in validation set')
+    axes2[0].set_xlabel('Epochs')
+    axes2[0].set_ylabel('Dice Score')
+
+    axes2[1].plot(jaccard_score)
+    axes2[1].set_title('Jaccard Index in validation set')
+    axes2[1].set_xlabel('Epochs')
+    axes2[1].set_ylabel('Jaccard Index')
+    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Validation_UNETpp.png')
+    #plt.show()
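+
+# The w_level/w_width arguments passed to get_loaders in main() suggest that CT
+# intensity windowing is applied inside utils.get_loaders (level 40 / width 350
+# is a common soft-tissue window). A minimal sketch of such a transform, as a
+# hypothetical helper — the project's actual implementation lives in utils.py:
+def window_ct(image, level=40.0, width=350.0):
+    lo, hi = level - width/2.0, level + width/2.0  # window bounds in Hounsfield units
+    return (image.clip(lo, hi) - lo) / (hi - lo) * 255.0  # map window to [0, 255]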
+
+
+if __name__ == '__main__':
+    main()
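+
+# Usage sketch (adjust the placeholder hyperparameters and paths above first):
+#   python "src/train (1).py"
+# The __main__ guard is required when NUM_WORKERS > 0 on platforms that spawn
+# DataLoader worker processes (e.g. Windows/macOS), since each worker
+# re-imports this module.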