|
b/src/train (1).py
|
|
# -*- coding: utf-8 -*-
"""
Training script for 2D heart segmentation with a UNet++ model from
segmentation_models_pytorch.
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch  # to optimize DL processes
import multiprocessing  # to query the number of workers (CPU cores)
import albumentations as A  # for data augmentation
from albumentations.pytorch import ToTensorV2  # to convert images and labels into PyTorch tensors
import torch.optim as optim  # for parameter optimization
import torch.nn as nn  # to create, define, and customize layers, models, and loss criteria in DL
from tqdm import tqdm  # to add progress bars to long-running loops
from utils import (  # helper .py module developed earlier
    get_loaders,
    load_checkpoint,
    save_checkpoint,
    perf,
    save_preds_as_imgs,
    save_metrics_to_csv,
    compute_jaccard_index,
    diceCoef
)
#from loss_fn import dice_coefficient_loss, dice_coefficient
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt  # for the training/validation plots below

# Hyperparameters (the concrete values below are placeholders; adjust them to
# your data and hardware):
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 16
NUM_EPOCHS = 100
NUM_WORKERS = multiprocessing.cpu_count()
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
PIN_MEMORY = True
LOAD_MODEL = False
CHECKPOINT_PATH = 'UNETpp_checkpoint.pth.tar'  # placeholder checkpoint filename

TRAIN_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/images'
TRAIN_MASKS_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/masks'
VAL_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_images'
VAL_MASK_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/val_masks'
PREDS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/saved_images'
METRICS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/scores/UNETpp_.csv'

def train_func(train_loader, model, optimizer, loss_fn, scaler):

    p_bar = tqdm(train_loader)  # progress bar
    running_loss = 0.0
    running_dice = 0.0
    running_jaccard = 0.0
    dice = 0
    jaccard = 0
    for batch_idx, (data, targets) in enumerate(p_bar):

        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)  # add a channel dimension at position 1 of the "targets" tensor
        #if not torch.all(targets==0):
        # Forward pass:
        with torch.cuda.amp.autocast():  # run mixed-precision (16- and 32-bit) floating-point computation
            preds = model(data)
            loss = loss_fn(preds, targets)
        with torch.no_grad():
            p = torch.sigmoid(preds)
            p = (p > 0.5).float()
            if (p.sum() == 0) and (targets.sum() == 0):
                jaccard = 1
                dice = (2*(p*targets).sum() + 1)/((p + targets).sum() + 1)
            else:
                jaccard = compute_jaccard_index(p, targets)
                dice = (2*(p*targets).sum() + 1)/((p + targets).sum() + 1)

        # Backward pass:
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        running_dice += dice
        running_jaccard += jaccard

        # Update the progress bar:
        p_bar.set_postfix(loss=loss.item(), dice=dice.item(), jaccard=jaccard)
        torch.cuda.empty_cache()

    epoch_loss = running_loss/len(train_loader)
    epoch_dice = running_dice/len(train_loader)
    epoch_jaccard = running_jaccard/len(train_loader)
    return epoch_loss, epoch_dice, epoch_jaccard
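# Note on the metrics above: the +1 smoothing term makes
# Dice = (2*|P∩T| + 1)/(|P| + |T| + 1) evaluate to 1 when both the prediction P
# and the target T are empty, so the empty-mask branch only needs to pin the
# Jaccard index to 1 explicitly.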
|
|
def main():
    x_start = 50
    x_end = 462
    y_start = 50
    y_end = 462
    train_transforms = A.Compose(
        [
            A.Crop(x_start, y_start, x_end, y_end, always_apply=True),
            A.CLAHE(p=0.2),
            A.GaussNoise(p=0.2),
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=30, p=0.3),
            A.HorizontalFlip(p=0.3),
            A.VerticalFlip(p=0.3),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0
            ),
            ToTensorV2(),  # conversion to a PyTorch tensor
        ],
    )

    val_transforms = A.Compose(
        [A.Crop(x_start, y_start, x_end, y_end, always_apply=False),
         A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
         A.Normalize(
             mean=[0.0, 0.0, 0.0],
             std=[1.0, 1.0, 1.0],
             max_pixel_value=255.0
         ),
         ToTensorV2(),  # conversion to a PyTorch tensor
        ],
    )
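    # The validation pipeline deliberately omits the stochastic augmentations
    # (CLAHE, Gaussian noise, rotations, flips) so that validation metrics are
    # computed on deterministic inputs.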
|
|
    train_loader, val_loader = get_loaders(  # load and preprocess the training and validation data into tensor form
        TRAIN_IMG_DIR,
        TRAIN_MASKS_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transforms,
        val_transforms,
        NUM_WORKERS,
        w_level=40,
        w_width=350,
        pin_memory=PIN_MEMORY,
        normalize=False,
    )
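    # w_level=40 / w_width=350 presumably select a soft-tissue CT window inside
    # get_loaders, i.e. intensities are clipped to roughly
    # [level - width/2, level + width/2] = [-135, 215] HU before scaling.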
|
|
    #model = UNET(3,1).to(DEVICE)
    #preprocess_input = get_preprocessing_fn('mit_b3', pretrained='imagenet')
    model = smp.UnetPlusPlus(
        encoder_name="timm-efficientnet-b5",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights="advprop",  # use `advprop` pre-trained weights for encoder initialization
        in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=1,  # one output channel for binary segmentation
        decoder_use_batchnorm=True,
    ).to(DEVICE)
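    # UNet++ replaces plain U-Net skip connections with nested, dense skip
    # pathways; the `advprop` weights are EfficientNet checkpoints trained with
    # adversarial examples (AdvProp), an alternative to plain ImageNet weights.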
|
|
    if LOAD_MODEL:  # a previously trained model exists
        load_checkpoint(torch.load(CHECKPOINT_PATH), model)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, log_loss=True, eps=1e-10, from_logits=True)  # log-Dice loss in binary mode, computed from raw logits
    scaler = torch.cuda.amp.GradScaler()  # dynamically rescale gradients during training to cut training time while preserving accuracy
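    # With log_loss=True, DiceLoss minimizes -log(dice) instead of 1 - dice,
    # which yields larger gradients when the overlap with the target is poor.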
|
|
    train_loss = []
    train_dice = []
    train_jaccard = []
    dice_score = []
    jaccard_score = []
    # Training:
    for epoch in tqdm(range(NUM_EPOCHS)):
        epoch_loss, dice_train, jaccard_train = train_func(train_loader, model, optimizer, loss_fn, scaler)
        train_loss.append(epoch_loss)
        train_dice.append(dice_train.detach().cpu().numpy())
        train_jaccard.append(jaccard_train)

        # Checkpoint:
        checkpoint = {
            'state_dict': model.state_dict(),  # parameter state
            'optimizer': optimizer.state_dict(),  # optimizer (gradient) state
        }
        save_checkpoint(checkpoint, CHECKPOINT_PATH)

        # Performance:
        dice, jaccard = perf(val_loader, model, DEVICE)
        dice_score.append(dice.detach().cpu().numpy())
        jaccard_score.append(jaccard)

        # Save predictions:
        save_preds_as_imgs(val_loader, model, DEVICE, PREDS_PATH)
        save_metrics_to_csv(epoch, epoch_loss, dice_train.detach().cpu().numpy(), jaccard_train, dice.detach().cpu().numpy(), jaccard, METRICS_PATH)
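        # Each CSV row records: epoch, train loss, train Dice, train Jaccard,
        # validation Dice, validation Jaccard.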
|
|
    #train_jaccard = [tensor.cpu() for tensor in train_jaccard]
    #dice_score = [tensor.cpu() for tensor in dice_score]
    #jaccard_score = [tensor.cpu() for tensor in jaccard_score]

    fig, axes = plt.subplots(1, 3)
    fig.set_figheight(6)
    fig.set_figwidth(18)
    axes[0].plot(train_loss)
    axes[0].set_title('Training loss')
    axes[0].set_xlabel('Epochs')
    axes[0].set_ylabel('Loss')

    axes[1].plot(train_dice)
    axes[1].set_title('Dice Score in train set')
    axes[1].set_xlabel('Epochs')
    axes[1].set_ylabel('Dice Score')

    axes[2].plot(train_jaccard)
    axes[2].set_title('Jaccard Index in train set')
    axes[2].set_xlabel('Epochs')
    axes[2].set_ylabel('Jaccard Index')
    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Training_UNETpp.png')
    #plt.show()
|
|
    fig2, axes2 = plt.subplots(1, 2)
    fig2.set_figheight(6)
    fig2.set_figwidth(18)

    axes2[0].plot(dice_score)
    axes2[0].set_title('Dice Score in validation set')
    axes2[0].set_xlabel('Epochs')
    axes2[0].set_ylabel('Dice Score')

    axes2[1].plot(jaccard_score)
    axes2[1].set_title('Jaccard Index in validation set')
    axes2[1].set_xlabel('Epochs')
    axes2[1].set_ylabel('Jaccard Index')
    plt.savefig('/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/plots/Validation_UNETpp.png')
    #plt.show()


if __name__ == '__main__':
    main()