|
a |
|
b/src/utils.py |
|
|
1 |
""" |
|
|
2 |
Developed by: Daniel Crovo |
|
|
3 |
|
|
|
4 |
""" |
|
|
5 |
import numpy as np |
|
|
6 |
import csv |
|
|
7 |
from hs_dataset import HSDataset |
|
|
8 |
from torch.utils.data import DataLoader |
|
|
9 |
import torch |
|
|
10 |
import os |
|
|
11 |
import torchvision |
|
|
12 |
from torchmetrics.classification import BinaryJaccardIndex |
|
|
13 |
import torchvision.transforms as transforms |
|
|
14 |
import torchvision.utils as vutils |
|
|
15 |
import torch |
|
|
16 |
from PIL import Image |
|
|
17 |
|
|
|
18 |
def get_loaders(train_dir, train_maskdir, val_dir, val_maskdir, batch_size, train_transform, |
|
|
19 |
val_transform, num_workers, w_level, w_width, pin_memory = True, normalize=False): |
|
|
20 |
"""_summary_ |
|
|
21 |
|
|
|
22 |
Args: |
|
|
23 |
train_dir (_type_): _description_ |
|
|
24 |
train_maskdir (_type_): _description_ |
|
|
25 |
val_dir (_type_): _description_ |
|
|
26 |
val_maskdir (_type_): _description_ |
|
|
27 |
batch_size (_type_): _description_ |
|
|
28 |
train_transform (_type_): _description_ |
|
|
29 |
val_transform (_type_): _description_ |
|
|
30 |
num_workers (_type_): _description_ |
|
|
31 |
pin_memory (bool, optional): _description_. Defaults to True. |
|
|
32 |
|
|
|
33 |
Returns: |
|
|
34 |
_type_: _description_ |
|
|
35 |
""" |
|
|
36 |
|
|
|
37 |
train_img_mask = HSDataset(image_dir = train_dir, |
|
|
38 |
mask_dir = train_maskdir, transform = train_transform, |
|
|
39 |
normalized=normalize, w_level=w_level, w_width=w_width) |
|
|
40 |
train_loader = DataLoader(train_img_mask, batch_size = batch_size, |
|
|
41 |
num_workers = num_workers, pin_memory = pin_memory, shuffle = True) |
|
|
42 |
val_img_mask = HSDataset(image_dir = val_dir, mask_dir = val_maskdir, |
|
|
43 |
transform = val_transform, normalized=normalize, |
|
|
44 |
w_level=w_level, w_width=w_width,) |
|
|
45 |
val_loader = DataLoader(val_img_mask, batch_size = batch_size, |
|
|
46 |
num_workers = num_workers, pin_memory = pin_memory, shuffle = False) |
|
|
47 |
|
|
|
48 |
return train_loader, val_loader |
|
|
49 |
|
|
|
50 |
|
|
|
51 |
def load_checkpoint(checkpoint, model): |
|
|
52 |
try: |
|
|
53 |
model.load_state_dict(checkpoint['state_dict']) |
|
|
54 |
print('\nCheckpoint importado exitosamente.') |
|
|
55 |
except: |
|
|
56 |
print('Error en la importación del Checkpoint.') |
|
|
57 |
|
|
|
58 |
def save_checkpoint(state, filename = 'my_checkpoint.pth.tar'): |
|
|
59 |
try: |
|
|
60 |
torch.save(state, filename) |
|
|
61 |
print('\nCheckpoint almacenado exitosamente.') |
|
|
62 |
except: |
|
|
63 |
print('Error en la importación del Checkpoint.') |
|
|
64 |
|
|
|
65 |
def compute_jaccard_index(preds, targets): |
|
|
66 |
intersection = torch.logical_and(preds, targets).sum() |
|
|
67 |
union = torch.logical_or(preds, targets).sum() |
|
|
68 |
jaccard = intersection.item() / (union.item()+ 1e-10) |
|
|
69 |
return jaccard |
|
|
70 |
def diceCoef(y_true, y_pred, smooth=1.): |
|
|
71 |
y_true_f = y_true.flatten() |
|
|
72 |
y_pred_f = y_pred.flatten() |
|
|
73 |
intersection = np.sum(y_true_f * y_pred_f) |
|
|
74 |
dice = (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth) |
|
|
75 |
return round(float(dice), 3) |
|
|
76 |
|
|
|
77 |
def perf(loader, model, device): |
|
|
78 |
dice_score = 0.0 |
|
|
79 |
jaccard = 0.0 |
|
|
80 |
|
|
|
81 |
model.eval() |
|
|
82 |
|
|
|
83 |
with torch.no_grad(): # deshabilitar el cálculo y almacenamiento de gradientes en el grafo computacional de PyTorch |
|
|
84 |
for x, y in loader: |
|
|
85 |
x = x.to(device = device) |
|
|
86 |
y = y.to(device).unsqueeze(1) |
|
|
87 |
preds = torch.sigmoid(model(x)) |
|
|
88 |
|
|
|
89 |
preds = (preds > 0.5).float() |
|
|
90 |
if((preds.sum() == 0) and (y.sum() == 0)): |
|
|
91 |
jaccard +=1 |
|
|
92 |
dice_score +=1 |
|
|
93 |
else: |
|
|
94 |
jaccard += compute_jaccard_index(preds, y) |
|
|
95 |
dice_score += (2*(preds*y).sum())/((preds + y).sum() + 1e-10) |
|
|
96 |
|
|
|
97 |
dice_s = dice_score/len(loader) |
|
|
98 |
jaccard_idx = jaccard/len(loader) |
|
|
99 |
|
|
|
100 |
print('\nDice score: {}'.format(dice_s)) |
|
|
101 |
print('Jaccard index: {}\n'.format(jaccard_idx)) |
|
|
102 |
|
|
|
103 |
model.train() |
|
|
104 |
return dice_s, jaccard_idx |
|
|
105 |
|
|
|
106 |
def save_preds_as_imgs(loader, model, device, folder = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Proyecto/saved_images'): |
|
|
107 |
model.eval() |
|
|
108 |
for idx, (x, y) in enumerate(loader): |
|
|
109 |
x = x.to(device = device) |
|
|
110 |
with torch.no_grad(): # deshabilitar el cálculo y almacenamiento de gradientes en el grafo computacional de PyTorch |
|
|
111 |
preds = torch.sigmoid(model(x)) |
|
|
112 |
preds = (preds > 0.5).float() |
|
|
113 |
|
|
|
114 |
torchvision.utils.save_image(preds, f'{folder}/y_hat_{idx}.png') # almacenamiento de máscaras predichas |
|
|
115 |
y = torch.unsqueeze(y, 1).to(torch.float32) |
|
|
116 |
torchvision.utils.save_image(y, f'{folder}/y_{idx}.png') # almacenamiento de máscaras reales |
|
|
117 |
torchvision.utils.save_image(x, f'{folder}/x_{idx}.png') # almacenamiento de máscaras reales |
|
|
118 |
|
|
|
119 |
#masked_img= add_mask_to_rgb_image(x,preds) |
|
|
120 |
|
|
|
121 |
#torchvision.utils.save_image(masked_img, f'{folder}/masked_{idx}.png') # almacenamiento de máscaras reales |
|
|
122 |
|
|
|
123 |
|
|
|
124 |
model.train() |
|
|
125 |
|
|
|
126 |
def add_mask_to_rgb_image(rgb_image_tensor, mask_tensor): |
|
|
127 |
# Apply the mask to the RGB image tensor |
|
|
128 |
masked_image_tensor = torch.where(mask_tensor > 0, torch.tensor([32,255, 0, 0]), rgb_image_tensor) |
|
|
129 |
|
|
|
130 |
return masked_image_tensor #print(dice_score) |
|
|
131 |
#train_loss = [tensor.cpu() for tensor in train_loss] |
|
|
132 |
#train_dice = [tensor.cpu() for tensor in train_dice] |
|
|
133 |
|
|
|
134 |
|
|
|
135 |
def save_metrics_to_csv(epoch, train_loss, train_dice, train_jaccard, dice_score, jaccard_score, filename): |
|
|
136 |
metrics = { |
|
|
137 |
'epoch': epoch, |
|
|
138 |
'train_loss': train_loss, |
|
|
139 |
'train_dice': train_dice, |
|
|
140 |
'train_jaccard': train_jaccard, |
|
|
141 |
'dice_score': dice_score, |
|
|
142 |
'jaccard_score': jaccard_score |
|
|
143 |
} |
|
|
144 |
|
|
|
145 |
file_exists = os.path.isfile(filename) |
|
|
146 |
|
|
|
147 |
with open(filename, mode='a', newline='') as file: |
|
|
148 |
writer = csv.DictWriter(file, fieldnames=metrics.keys()) |
|
|
149 |
|
|
|
150 |
if not file_exists: |
|
|
151 |
writer.writeheader() |
|
|
152 |
|
|
|
153 |
writer.writerow(metrics) |
|
|
154 |
|
|
|
155 |
|
|
|
156 |
def save_preds_as_imgs2(loader, model, device, folder='/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Proyecto/saved_images'): |
|
|
157 |
model.eval() |
|
|
158 |
for idx, (x, y) in enumerate(loader): |
|
|
159 |
x = x.to(device=device) |
|
|
160 |
with torch.no_grad(): |
|
|
161 |
preds = torch.sigmoid(model(x)) |
|
|
162 |
preds = (preds > 0.5).float() |
|
|
163 |
|
|
|
164 |
# Convert single-channel mask to RGB format |
|
|
165 |
y_rgb = y.repeat(1, 3, 1, 1) |
|
|
166 |
|
|
|
167 |
# Merge predicted mask with input image |
|
|
168 |
merged_img = x.clone() |
|
|
169 |
merged_img[:, :3] = torch.where(preds > 0, torch.tensor([1.0, 0.0, 0.0]), merged_img[:, :3]) |
|
|
170 |
|
|
|
171 |
# Save the merged image |
|
|
172 |
vutils.save_image(merged_img, f'{folder}/merged_{idx}.png') |
|
|
173 |
|