Diff of /src/train (1).py [000000] .. [3475df]

# -*- coding: utf-8 -*-
"""
Training script for 2D heart segmentation with a UNet++ model.
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"


os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch # PyTorch, to optimize DL workflows
import multiprocessing # to query the number of workers (cores)
import matplotlib.pyplot as plt # to plot the training/validation curves below
import albumentations as A # data augmentation
from albumentations.pytorch import ToTensorV2 # to convert images and labels into PyTorch tensors
import torch.optim as optim # parameter optimization
import torch.nn as nn # to create, define and customize layers, models and loss criteria for DL
from tqdm import tqdm # progress bars for long-running loops
from utils import ( # helper module developed previously
                       get_loaders,
                       load_checkpoint,
                       save_checkpoint,
                       perf,
                       save_preds_as_imgs,
                       save_metrics_to_csv,
                       compute_jaccard_index,
                       diceCoef
                      )
#from loss_fn import dice_coefficient_loss, dice_coefficient
import segmentation_models_pytorch as smp

TRAIN_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/images'
TRAIN_MAKS_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/masks'
VAL_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_images'
VAL_MASK_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_masks'
PREDS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/saved_images'
METRICS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/scores/UNETpp_.csv'
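
# Hyperparameters referenced throughout the script. The values below are a minimal
# sketch of assumed settings (the original definitions are not visible here); adjust
# them to the actual training configuration before running.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LEARNING_RATE = 1e-4                        # assumed Adam learning rate
BATCH_SIZE = 8                              # assumed batch size
NUM_EPOCHS = 50                             # assumed number of training epochs
NUM_WORKERS = multiprocessing.cpu_count()   # DataLoader workers, one per core
IMAGE_HEIGHT = 256                          # assumed resize height
IMAGE_WIDTH = 256                           # assumed resize width
PIN_MEMORY = True
LOAD_MODEL = False                          # set to True to resume from CHECKPOINT_PATH
CHECKPOINT_PATH = 'UNETpp_checkpoint.pth.tar'  # placeholder checkpoint file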

def train_func(train_loader, model, optimizer, loss_fn, scaler):

    p_bar = tqdm(train_loader) # progress bar
    running_loss = 0.0
    running_dice = 0.0
    running_jaccard = 0.0
    dice = 0
    jaccard = 0
    for batch_idx, (data, targets) in enumerate(p_bar):

        data = data.to(device = DEVICE)
        targets = targets.float().unsqueeze(1).to(device = DEVICE) # add a channel dimension at position 1 of the "targets" tensor
        #if not torch.all(targets==0):
        # Forward pass:
        with torch.cuda.amp.autocast(): # mixed-precision (16- and 32-bit) floating-point computation
            preds = model(data)
            loss = loss_fn(preds, targets)
            with torch.no_grad():
                p = torch.sigmoid(preds)
                p = (p > 0.5).float()
                if (p.sum() == 0) and (targets.sum() == 0):
                    jaccard = 1
                    dice = (2*(p*targets).sum() + 1)/((p + targets).sum() + 1) # smoothed Dice: equals 1 when both masks are empty
                else:
                    jaccard = compute_jaccard_index(p, targets)
                    dice = (2*(p*targets).sum() + 1)/((p + targets).sum() + 1)

        # Backward pass:
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        running_dice += dice
        running_jaccard += jaccard

        # Update the progress bar:
        p_bar.set_postfix(loss = loss.item(), dice = dice.item(), jaccard = jaccard)
        torch.cuda.empty_cache()

    epoch_loss = running_loss/len(train_loader)
    epoch_dice = running_dice/len(train_loader)
    epoch_jaccard = running_jaccard/len(train_loader)
    return epoch_loss, epoch_dice, epoch_jaccard

def main():
    x_start = 50
    x_end = 462
    y_start = 50
    y_end = 462
    train_transforms = A.Compose(
                                 [
                                  A.Crop(x_start, y_start, x_end, y_end, always_apply = True),
                                  A.CLAHE(p = 0.2),
                                  A.GaussNoise(p = 0.2),
                                  A.Resize(height = IMAGE_HEIGHT, width = IMAGE_WIDTH),
                                  A.Rotate(limit = 30, p = 0.3),
                                  A.HorizontalFlip(p = 0.3),
                                  A.VerticalFlip(p = 0.3),
                                  A.Normalize(
                                              mean = [0.0, 0.0, 0.0],
                                              std = [1.0, 1.0, 1.0],
                                              max_pixel_value = 255.0
                                             ),
                                  ToTensorV2(), # conversion to a PyTorch tensor
                                 ],
                                )

    val_transforms = A.Compose(
                               [A.Crop(x_start, y_start, x_end, y_end, always_apply = False),
                                A.Resize(height = IMAGE_HEIGHT, width = IMAGE_WIDTH),
                                A.Normalize(
                                            mean = [0.0, 0.0, 0.0],
                                            std = [1.0, 1.0, 1.0],
                                            max_pixel_value = 255.0
                                           ),
                                ToTensorV2(), # conversion to a PyTorch tensor
                               ],
                              )

    train_loader, val_loader = get_loaders( # load and preprocess the training and validation data into batched tensors
                                           TRAIN_IMG_DIR,
                                           TRAIN_MAKS_DIR,
                                           VAL_IMG_DIR,
                                           VAL_MASK_DIR,
                                           BATCH_SIZE,
                                           train_transforms,
                                           val_transforms,
                                           NUM_WORKERS,
                                           w_level = 40,
                                           w_width = 350,
                                           pin_memory = PIN_MEMORY,
                                           normalize = False,
                                          )

    #model = UNET(3,1).to(DEVICE)
    #preprocess_input = get_preprocessing_fn('mit_b3', pretrained='imagenet')
    model = smp.UnetPlusPlus(
                encoder_name = "timm-efficientnet-b5",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                encoder_weights = "advprop",            # use "advprop" pre-trained weights for encoder initialization
                in_channels = 3,                        # model input channels (1 for gray-scale images, 3 for RGB, etc.)
                classes = 1,
                decoder_use_batchnorm = True,
                ).to(DEVICE)

    if LOAD_MODEL: # a previously trained model exists
        load_checkpoint(torch.load(CHECKPOINT_PATH), model)

    optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
    loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, log_loss = True, eps = 1e-10, from_logits = True) # binary log-Dice loss computed from logits
    scaler = torch.cuda.amp.GradScaler() # dynamically scales gradients during mixed-precision training to reduce training time while preserving accuracy
    train_loss = []
    train_dice = []
    train_jaccard = []
    dice_score = []
    jaccard_score = []
    # Training:
    for epoch in tqdm(range(NUM_EPOCHS)):
        epoch_loss, dice_train, jaccard_train = train_func(train_loader, model, optimizer, loss_fn, scaler)
        train_loss.append(epoch_loss)
        train_dice.append(dice_train.detach().cpu().numpy())
        train_jaccard.append(jaccard_train)

        # Checkpoint:
        checkpoint = {'state_dict': model.state_dict(),     # parameter state
                      'optimizer': optimizer.state_dict(),  # optimizer (gradient) state
                     }
        save_checkpoint(checkpoint, CHECKPOINT_PATH)

        # Performance:
        dice, jaccard = perf(val_loader, model, DEVICE)
        dice_score.append(dice.detach().cpu().numpy())
        jaccard_score.append(jaccard)

        # Saving predictions:
        save_preds_as_imgs(val_loader, model, DEVICE, PREDS_PATH)
        save_metrics_to_csv(epoch, epoch_loss, dice_train.detach().cpu().numpy(), jaccard_train, dice.detach().cpu().numpy(), jaccard, METRICS_PATH)

    #train_jaccard = [tensor.cpu()  for tensor in train_jaccard]
    #dice_score = [tensor.cpu() for tensor in dice_score]
    #jaccard_score = [tensor.cpu() for tensor in jaccard_score]

    fig, axes = plt.subplots(1,3)
    fig.set_figheight(6)
    fig.set_figwidth(18)
    axes[0].plot(train_loss)
    axes[0].set_title('Training loss')
    axes[0].set_xlabel('Epochs')
    axes[0].set_ylabel('Loss')

    axes[1].plot(train_dice)
    axes[1].set_title('Dice Score in train set')
    axes[1].set_xlabel('Epochs')
    axes[1].set_ylabel('Dice Score')

    axes[2].plot(train_jaccard)
    axes[2].set_title('Jaccard Index in train set')
    axes[2].set_xlabel('Epochs')
    axes[2].set_ylabel('Jaccard Index')
    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Training_UNETpp.png')
    #plt.show()

    fig2, axes2 = plt.subplots(1,2)
    fig2.set_figheight(6)
    fig2.set_figwidth(18)

    axes2[0].plot(dice_score)
    axes2[0].set_title('Dice Score in validation set')
    axes2[0].set_xlabel('Epochs')
    axes2[0].set_ylabel('Dice Score')

    axes2[1].plot(jaccard_score)
    axes2[1].set_title('Jaccard Index in validation set')
    axes2[1].set_xlabel('Epochs')
    axes2[1].set_ylabel('Jaccard Index')
    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Validation_UNETpp.png')
    #plt.show()


if __name__ == '__main__':
    main()