Diff of /src/train (1).py [000000] .. [3475df]

Switch to side-by-side view

--- a
+++ b/src/train (1).py
@@ -0,0 +1,213 @@
+# -*- coding: utf-8 -*-
+"""
+
+"""
+import os
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
+
+
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+import torch # para optimizar procesos de DL
+import multiprocessing # para conocer el número de workers (núcleos)
+import albumentations as A # para emplear data augmentation
+from albumentations.pytorch import ToTensorV2 # para convertir imágenes y etiquetas en tensores de Pytorch
+import torch.optim as optim # para optimización de parámetros
+import torch.nn as nn # para crear, definir y personalizar diferentes tipos de capas, modelos y criterios de pérdida en DL
+from tqdm import tqdm # para agregar barras de progreso a bucles o iteraciones de larga duración
+from utils import ( # .py previamente desarrollado
+                       get_loaders,
+                       load_checkpoint,
+                       save_checkpoint,
+                       perf,
+                       save_preds_as_imgs,
+                       save_metrics_to_csv,
+                       compute_jaccard_index,
+                       diceCoef
+                      )
+#from loss_fn import dice_coefficient_loss, dice_coefficient
+import segmentation_models_pytorch as o01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/images'
+TRAIN_MAKS_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/masks'
+VAL_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_images'
+VAL_MASK_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_masks'
+PREDS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/saved_images'
+METRICS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/scores/UNETpp_.csv'
+
+def train_func(train_loader, model, optimizer, loss_fn, scaler): 
+    
+    p_bar = tqdm(train_loader) # progress bar
+    running_loss = 0.0
+    running_dice = 0.0
+    running_jaccard = 0.0
+    dice = 0
+    jaccard = 0
+    for batch_idx, (data, targets) in enumerate(p_bar):
+        
+        data = data.to(device = DEVICE)
+        targets = targets.float().unsqueeze(1).to(device = DEVICE) # agregar una nueva dimensión en la posición 1 del tensor "targets"
+        #if not torch.all(targets==0):
+        # Forward pass: 
+        with torch.cuda.amp.autocast(): #para realizar operaciones de cálculo de punto flotante de precisión mixta (16 y 32 bits) en modelos de DL
+            preds = model(data)
+            loss = loss_fn(preds, targets)
+            with torch.no_grad():
+                p = torch.sigmoid(preds)
+                p = (p > 0.5).float()
+                if((p.sum() == 0) and (targets.sum() == 0)):
+                    jaccard = 1
+                    dice = (2*(p*targets).sum()+1)/((p + targets).sum() + 1)
+                else:
+                    jaccard = compute_jaccard_index(p, targets)
+                    dice = (2*(p*targets).sum()+1)/((p + targets).sum() + 1)
+            
+        
+        # Backward pass: 
+        optimizer.zero_grad() 
+        scaler.scale(loss).backward() 
+        scaler.step(optimizer) 
+        scaler.update() 
+        running_loss += loss.item()
+        running_dice += dice
+        running_jaccard += jaccard
+        
+        # Actualización progress bar: 
+        p_bar.set_postfix(loss = loss.item(), dice=dice.item(), jaccard=jaccard)
+        torch.cuda.empty_cache()
+
+    epoch_loss = running_loss/len(train_loader)
+    epoch_dice = running_dice/len(train_loader)
+    epoch_jaccard = running_jaccard/len(train_loader)
+    return epoch_loss, epoch_dice, epoch_jaccard
+        
+def main():
+    x_start = 50
+    x_end = 462
+    y_start = 50
+    y_end = 462 
+    train_transforms = A.Compose(
+                                 [
+                                  A.Crop(x_start, y_start, x_end, y_end, always_apply= True),
+                                  A.CLAHE(p=0.2),
+                                  A.GaussNoise(p=0.2),
+                                  A.Resize(height = IMAGE_HEIGHT, width = IMAGE_WIDTH), 
+                                  A.Rotate(limit = 30, p = 0.3), 
+                                  A.HorizontalFlip(p = 0.3), 
+                                  A.VerticalFlip(p = 0.3), 
+                                  A.Normalize(
+                                              mean = [0.0, 0.0, 0.0],
+                                              std = [1.0, 1.0, 1.0], 
+                                              max_pixel_value = 255.0
+                                             ),
+                                  ToTensorV2(), # conversión a tensor de Pytorch
+                                 ],
+                                )
+    
+    val_transforms = A.Compose(
+                               [A.Crop(x_start, y_start, x_end, y_end, always_apply= False),
+                                A.Resize(height = IMAGE_HEIGHT, width = IMAGE_WIDTH), 
+                                A.Normalize(
+                                            mean = [0.0, 0.0, 0.0],
+                                            std = [1.0, 1.0, 1.0], 
+                                            max_pixel_value = 255.0
+                                         ),
+                              ToTensorV2(), # conversión a tensor de Pytorch
+                             ],
+                            )
+    
+    train_loader, val_loader = get_loaders( # cargar y preprocesar los datos de entrenamiento y validación con base en los formatos tensoriales
+                                           TRAIN_IMG_DIR, 
+                                           TRAIN_MAKS_DIR, 
+                                           VAL_IMG_DIR, 
+                                           VAL_MASK_DIR,
+                                           BATCH_SIZE,
+                                           train_transforms,
+                                           val_transforms, 
+                                           NUM_WORKERS,
+                                           w_level=40, 
+                                           w_width=350,
+                                           pin_memory= PIN_MEMORY, 
+                                           normalize=False, 
+                                           
+                                          )
+
+    #model = UNET(3,1).to(DEVICE)
+    #preprocess_input = get_preprocessing_fn('mit_b3', pretrained='imagenet')
+    model = smp.UnetPlusPlus(   
+                encoder_name="timm-efficientnet-b5",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
+                encoder_weights="advprop",     # use `imagenet` pre-trained weights for encoder initialization
+                in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
+                classes=1,
+                decoder_use_batchnorm=True,
+                ).to(DEVICE)
+    
+    if LOAD_MODEL == True: # existe un modelo entrenado preliminarmente
+        load_checkpoint(torch.load(CHECKPOINT_PATH), model)
+                                      
+    optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
+    loss_fn =  smp.losses.DiceLoss(smp.losses.BINARY_MODE, log_loss=True, eps=1e-10, from_logits=True) # Binary Cross Entropy with logits loss   
+    scaler = torch.cuda.amp.GradScaler() # ajustar de manera dinámica la escala de los gradientes durante el entrenamiento para reducir el tiempo de entrenamiento mientras se mantiene la precisión de los resultados                        
+    train_loss = []
+    train_dice = []
+    train_jaccard = []
+    dice_score =[]
+    jaccard_score = []
+    # Entrenamiento: 
+    for epoch in tqdm(range(NUM_EPOCHSodel.state_dict(), # estado de los parámetros 
+                      'optimizer': optimizer.state_dict(), # estado de los gradientes            
+                     }
+        save_checkpoint(checkpoint, CHECKPOINT_PATH)
+        
+        # Rendimiento: 
+        dice, jaccard = perf(val_loader, model, DEVICE)
+        dice_score.append(dice.detach().cpu().numpy())
+        jaccard_score.append(jaccard)
+        
+        # Almacenamiento de predicciones: 
+        save_preds_as_imgs(val_loader, model, DEVICE,PREDS_PATH)
+        save_metrics_to_csv(epoch,epoch_loss, dice_train.detach().cpu().numpy(), jaccard_train, dice.detach().cpu().numpy(), jaccard, METRICS_PATH)
+
+
+
+    #train_jaccard = [tensor.cpu()  for tensor in train_jaccard]
+    #dice_score = [tensor.cpu() for tensor in dice_score]
+    #jaccard_score = [tensor.cpu() for tensor in jaccard_score]
+
+    fig, axes = plt.subplots(1,3)
+    fig.set_figheight(6)
+    fig.set_figwidth(18)
+    axes[0].plot(train_loss)
+    axes[0].set_title('Training loss')
+    axes[0].set_xlabel('Epochs')
+    axes[0].set_ylabel('Loss')
+
+    axes[1].plot(train_dice)
+    axes[1].set_title('Dice Score in train set')
+    axes[1].set_xlabel('Epochs')
+    axes[1].set_ylabel('Dice Score')
+    
+    axes[2].plot(train_jaccard)
+    axes[2].set_title('Jaccard Index in train set')
+    axes[2].set_xlabel('Epochs')
+    axes[2].set_ylabel('Jaccard Index')
+    plt.savefig('/home/danielcroovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Training_UNETpp.png')
+    #plt.show()
+
+
+    fig2, axes2 = plt.subplots(1,2)
+    fig2.set_figheight(6)
+    fig2.set_figwidth(18)
+
+    axes2[0].plot(dice_score)
+    axes2[0].set_title('Dice Score in validation set')
+    axes2[0].set_xlabel('Epochs')
+    axes2[0].set_ylabel('Dice Score')
+    
+    axes2[1].plot(jaccard_score)
+    axes2[1].set_title('Jaccard Index in validation set')
+    axes2[1].set_xlabel('Epochs')
+    axes2[1].set_ylabel('Jaccard Index')
+    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Validation_UNETpp.png')
+    #plt.show()
+
+
+if __name__ == '__main__': 
+    main()
\ No newline at end of file