Diff of /src/train.py [000000] .. [3475df]

b/src/train.py
# -*- coding: utf-8 -*-
"""
Training script for 2D heart segmentation with a UNet++ model from
segmentation_models_pytorch.
"""
import os
# Cap the CUDA caching allocator's split size to reduce memory fragmentation:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# Allow duplicate OpenMP runtimes (common MKL/conda workaround):
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch  # core deep-learning framework
import multiprocessing  # to query the number of CPU cores (data-loading workers)
import albumentations as A  # data augmentation
from albumentations.pytorch import ToTensorV2  # converts images and labels to PyTorch tensors
import torch.optim as optim  # parameter optimizers
import torch.nn as nn  # layers, models, and loss criteria
from tqdm import tqdm  # progress bars for long-running loops
from utils import (  # helper module developed earlier in this project
    get_loaders,
    load_checkpoint,
    save_checkpoint,
    perf,
    save_preds_as_imgs,
    save_metrics_to_csv,
    compute_jaccard_index,
    diceCoef,
)
#from loss_fn import dice_coefficient_loss, dice_coefficient
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import matplotlib.pyplot as plt
from Unet import UNET

# Preliminary hyperparameters:
LEARNING_RATE = 0.8e-5
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 16
NUM_EPOCHS = 100
NUM_WORKERS = multiprocessing.cpu_count()
IMAGE_HEIGHT = 224  # resize target; slices are cropped before resizing
IMAGE_WIDTH = 224
PIN_MEMORY = True  # keep loaded batches in page-locked host memory for faster host-to-GPU copies
LOAD_MODEL = False
CHECKPOINT_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/checkpoints/UNETpp_my_checkpoint.pth.tar'
TRAIN_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/images'
TRAIN_MASK_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/masks'
VAL_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_images'
VAL_MASK_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_masks'
PREDS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/saved_images'
METRICS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/scores/UNETpp_.csv'

def train_func(train_loader, model, optimizer, loss_fn, scaler):

    p_bar = tqdm(train_loader)  # progress bar
    running_loss = 0.0
    running_dice = 0.0
    running_jaccard = 0.0
    dice = 0.0
    jaccard = 0.0
    for batch_idx, (data, targets) in enumerate(p_bar):

        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)  # add a channel dimension at position 1 of "targets"
        #if not torch.all(targets==0):
        # Forward pass:
        with torch.cuda.amp.autocast():  # mixed-precision (16- and 32-bit) floating-point computation
            preds = model(data)
            loss = loss_fn(preds, targets)
            with torch.no_grad():
                p = torch.sigmoid(preds)
                p = (p > 0.5).float()
                # The smoothed Dice formula is the same in both branches; only
                # the Jaccard index needs the empty-mask special case. Both
                # metrics are cast to plain Python floats so the running sums
                # (and the lists plotted later) never hold GPU tensors.
                dice = ((2 * (p * targets).sum() + 1) / ((p + targets).sum() + 1)).item()
                if (p.sum() == 0) and (targets.sum() == 0):
                    jaccard = 1.0  # an empty prediction on an empty mask counts as perfect overlap
                else:
                    jaccard = float(compute_jaccard_index(p, targets))
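                # For reference, with P the thresholded prediction and T the
                # target mask, the metrics above are:
                #   dice    = (2*|P ∩ T| + 1) / (|P| + |T| + 1)
                #   jaccard = |P ∩ T| / |P ∪ T|  (assumed to be what
                #             utils.compute_jaccard_index implements)
                # For the unsmoothed quantities J = D / (2 - D), so the two
                # curves should rise and fall together.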

        # Backward pass:
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        running_dice += dice
        running_jaccard += jaccard

        # Progress-bar update:
        p_bar.set_postfix(loss=loss.item(), dice=dice, jaccard=jaccard)
        torch.cuda.empty_cache()

    epoch_loss = running_loss / len(train_loader)
    epoch_dice = running_dice / len(train_loader)
    epoch_jaccard = running_jaccard / len(train_loader)
    return epoch_loss, epoch_dice, epoch_jaccard


def main():
    # Crop window applied before resizing (the raw slices are assumed to be larger than 462x462):
    x_start = 50
    x_end = 462
    y_start = 50
    y_end = 462
    train_transforms = A.Compose(
        [
            A.Crop(x_start, y_start, x_end, y_end, always_apply=True),
            A.CLAHE(p=0.2),
            A.GaussNoise(p=0.2),
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=30, p=0.3),
            A.HorizontalFlip(p=0.3),
            A.VerticalFlip(p=0.3),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),  # conversion to a PyTorch tensor
        ],
    )

    val_transforms = A.Compose(
        [
            A.Crop(x_start, y_start, x_end, y_end, always_apply=True),
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),  # conversion to a PyTorch tensor
        ],
    )
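    # Note: with mean 0 and std 1, A.Normalize simply divides by
    # max_pixel_value, i.e. it rescales 8-bit intensities to [0, 1] without
    # per-channel standardization.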

    # Build the training and validation DataLoaders (defined in utils.py):
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transforms,
        val_transforms,
        NUM_WORKERS,
        w_level=40,
        w_width=350,
        pin_memory=PIN_MEMORY,
        normalize=False,
    )
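    # w_level=40 / w_width=350 match a typical CT mediastinal window (in
    # Hounsfield units), so get_loaders presumably applies intensity
    # windowing to the slices before the albumentations pipeline runs.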

    #model = UNET(3, 1).to(DEVICE)
    #preprocess_input = get_preprocessing_fn('mit_b3', pretrained='imagenet')
    model = smp.UnetPlusPlus(
        encoder_name="timm-efficientnet-b5",  # encoder backbone, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights="advprop",            # AdvProp pre-trained weights for encoder initialization
        in_channels=3,                        # model input channels (1 for grayscale images, 3 for RGB, etc.)
        classes=1,                            # one output channel: binary segmentation
        decoder_use_batchnorm=True,
    ).to(DEVICE)
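    # With classes=1 and no activation argument, the model returns raw logits;
    # that is why the loss below is built with from_logits=True and the
    # metrics in train_func apply torch.sigmoid before thresholding.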

    if LOAD_MODEL:  # resume from a previously trained checkpoint
        load_checkpoint(torch.load(CHECKPOINT_PATH), model)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, log_loss=True, eps=1e-10, from_logits=True)  # Dice log-loss (not BCE)
    scaler = torch.cuda.amp.GradScaler()  # dynamically scales gradients so mixed-precision training keeps full-precision accuracy
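    # In segmentation_models_pytorch, DiceLoss(log_loss=True) optimizes
    # -log(dice) instead of 1 - dice, which yields larger gradients when the
    # overlap is poor; eps guards the log and division against zeros.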
    train_loss = []
    train_dice = []
    train_jaccard = []
    dice_score = []
    jaccard_score = []
    # Training loop:
    for epoch in tqdm(range(NUM_EPOCHS)):
        epoch_loss, dice_train, jaccard_train = train_func(train_loader, model, optimizer, loss_fn, scaler)
        train_loss.append(epoch_loss)
        train_dice.append(dice_train)
        train_jaccard.append(jaccard_train)

        # Checkpoint creation:
        checkpoint = {
            'state_dict': model.state_dict(),     # parameter state
            'optimizer': optimizer.state_dict(),  # optimizer (gradient) state
        }
        save_checkpoint(checkpoint, CHECKPOINT_PATH)

        # Validation performance:
        dice, jaccard = perf(val_loader, model, DEVICE)
        dice_score.append(float(dice))  # float() in case perf returns tensors
        jaccard_score.append(float(jaccard))

        # Save predictions and metrics:
        save_preds_as_imgs(val_loader, model, DEVICE, PREDS_PATH)
        save_metrics_to_csv(epoch, epoch_loss, dice_train, jaccard_train, float(dice), float(jaccard), METRICS_PATH)
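        # Values passed to the CSV per epoch: epoch index, train loss,
        # train Dice, train Jaccard, validation Dice, validation Jaccard.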

    # Metrics are accumulated as plain floats, so no .cpu() conversion is
    # needed before plotting.

    fig, axes = plt.subplots(1, 3)
    fig.set_figheight(6)
    fig.set_figwidth(18)
    axes[0].plot(train_loss)
    axes[0].set_title('Training loss')
    axes[0].set_xlabel('Epochs')
    axes[0].set_ylabel('Loss')

    axes[1].plot(train_dice)
    axes[1].set_title('Dice Score in train set')
    axes[1].set_xlabel('Epochs')
    axes[1].set_ylabel('Dice Score')

    axes[2].plot(train_jaccard)
    axes[2].set_title('Jaccard Index in train set')
    axes[2].set_xlabel('Epochs')
    axes[2].set_ylabel('Jaccard Index')
    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Training_UNETpp.png')
    #plt.show()

    fig2, axes2 = plt.subplots(1, 2)
    fig2.set_figheight(6)
    fig2.set_figwidth(18)

    axes2[0].plot(dice_score)
    axes2[0].set_title('Dice Score in validation set')
    axes2[0].set_xlabel('Epochs')
    axes2[0].set_ylabel('Dice Score')

    axes2[1].plot(jaccard_score)
    axes2[1].set_title('Jaccard Index in validation set')
    axes2[1].set_xlabel('Epochs')
    axes2[1].set_ylabel('Jaccard Index')
    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Validation_UNETpp.png')
    #plt.show()


if __name__ == '__main__':
    main()