|
a |
|
b/src/evaluation.py |
|
|
1 |
from utils import ( # .py previamente desarrollado |
|
|
2 |
get_loaders, |
|
|
3 |
load_checkpoint, |
|
|
4 |
save_checkpoint, |
|
|
5 |
perf, |
|
|
6 |
save_preds_as_imgs, |
|
|
7 |
save_metrics_to_csv, |
|
|
8 |
compute_jaccard_index, |
|
|
9 |
) |
|
|
10 |
import torch |
|
|
11 |
from segmentation_models_pytorch.encoders import get_preprocessing_fn |
|
|
12 |
import multiprocessing # para conocer el número de workers (núcleos) |
|
|
13 |
import albumentations as A # para emplear data augmentation |
|
|
14 |
from albumentations.pytorch import ToTensorV2 # para convertir imágenes y etiquetas en tensores de Pytorch |
|
|
15 |
import torch.optim as optim # para optimización de parámetros |
|
|
16 |
import segmentation_models_pytorch as smp |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
CHECKPOINT_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/checkpoints/MANET_my_checkpoint.pth.tar' |
|
|
20 |
TEST_IMAGES = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/test_images' |
|
|
21 |
TEST_MASK = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/test_masks' |
|
|
22 |
TRAIN_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/images' |
|
|
23 |
TRAIN_MAKS_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/masks' |
|
|
24 |
IMAGE_HEIGHT = 192 # 1280 px. |
|
|
25 |
IMAGE_WIDTH = 192 # 1918 px. |
|
|
26 |
PREDS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/saved_images_test' |
|
|
27 |
|
|
|
28 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
29 |
|
|
|
30 |
def evaluate_training_set(model, loader, device): |
|
|
31 |
model.eval() |
|
|
32 |
|
|
|
33 |
dice_score = 0.0 |
|
|
34 |
jaccard = 0.0 |
|
|
35 |
num_samples = 0 |
|
|
36 |
|
|
|
37 |
with torch.no_grad(): |
|
|
38 |
for data, y in loader: |
|
|
39 |
data = data.to(device) |
|
|
40 |
y = y.float().unsqueeze(1).to(device) |
|
|
41 |
preds = model(data) |
|
|
42 |
preds = torch.sigmoid(model(data)) |
|
|
43 |
|
|
|
44 |
preds = (preds > 0.5).float() |
|
|
45 |
|
|
|
46 |
if((preds.sum() == 0) and (y.sum() == 0)): |
|
|
47 |
jaccard +=1 |
|
|
48 |
dice_score +=1 |
|
|
49 |
else: |
|
|
50 |
jaccard += compute_jaccard_index(preds, y) |
|
|
51 |
dice_score += (2*(preds*y).sum())/((preds + y).sum() + 1e-10) |
|
|
52 |
|
|
|
53 |
|
|
|
54 |
dice_s = dice_score/len(loader) |
|
|
55 |
jaccard_idx = jaccard/len(loader) |
|
|
56 |
|
|
|
57 |
|
|
|
58 |
return dice_s, jaccard_idx |
|
|
59 |
|
|
|
60 |
def main(): |
|
|
61 |
x_start = 50 |
|
|
62 |
x_end = 462 |
|
|
63 |
y_start = 50 |
|
|
64 |
y_end = 462 |
|
|
65 |
val_transforms = A.Compose( |
|
|
66 |
[A.Crop(x_start, y_start, x_end, y_end, always_apply= False), |
|
|
67 |
A.Resize(height = IMAGE_HEIGHT, width = IMAGE_WIDTH), |
|
|
68 |
A.Normalize( |
|
|
69 |
mean = [0.0, 0.0, 0.0], |
|
|
70 |
std = [1.0, 1.0, 1.0], |
|
|
71 |
max_pixel_value = 255.0 |
|
|
72 |
), |
|
|
73 |
ToTensorV2(), # conversión a tensor de Pytorch |
|
|
74 |
], |
|
|
75 |
) |
|
|
76 |
train_loader, test_loader = get_loaders( # cargar y preprocesar los datos de entrenamiento y validación con base en los formatos tensoriales |
|
|
77 |
TRAIN_IMG_DIR, |
|
|
78 |
TRAIN_MAKS_DIR, |
|
|
79 |
TEST_IMAGES, |
|
|
80 |
TEST_MASK, |
|
|
81 |
16, |
|
|
82 |
None, |
|
|
83 |
val_transforms, |
|
|
84 |
16, |
|
|
85 |
w_level=40, |
|
|
86 |
w_width=350, |
|
|
87 |
pin_memory= True, |
|
|
88 |
normalize=False, |
|
|
89 |
|
|
|
90 |
) |
|
|
91 |
model = smp.MAnet( |
|
|
92 |
encoder_name="mit_b3", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 |
|
|
93 |
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization |
|
|
94 |
in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.) |
|
|
95 |
classes=1, |
|
|
96 |
decoder_use_batchnorm=True, |
|
|
97 |
).to(DEVICE) |
|
|
98 |
load_checkpoint(torch.load(CHECKPOINT_PATH), model) |
|
|
99 |
|
|
|
100 |
|
|
|
101 |
dice_score, jaccard_index = evaluate_training_set(model, test_loader, DEVICE) |
|
|
102 |
print(f"Average Dice Score (Training Set): {dice_score:.4f}") |
|
|
103 |
print(f"Average Jaccard Index (Training Set): {jaccard_index:.4f}") |
|
|
104 |
save_preds_as_imgs(test_loader, model, DEVICE,PREDS_PATH) |
|
|
105 |
|
|
|
106 |
if __name__ == '__main__': |
|
|
107 |
main() |