Diff of /src/evaluation.py [000000] .. [3475df]

Switch to unified view

a b/src/evaluation.py
1
from utils import ( # .py previamente desarrollado
2
                       get_loaders,
3
                       load_checkpoint,
4
                       save_checkpoint,
5
                       perf,
6
                       save_preds_as_imgs,
7
                       save_metrics_to_csv,
8
                       compute_jaccard_index,
9
                      )
10
import torch
11
from segmentation_models_pytorch.encoders import get_preprocessing_fn
12
import multiprocessing # para conocer el número de workers (núcleos)
13
import albumentations as A # para emplear data augmentation
14
from albumentations.pytorch import ToTensorV2 # para convertir imágenes y etiquetas en tensores de Pytorch
15
import torch.optim as optim # para optimización de parámetros
16
import segmentation_models_pytorch as smp
17
18
19
CHECKPOINT_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/checkpoints/MANET_my_checkpoint.pth.tar'
20
TEST_IMAGES = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/test_images'
21
TEST_MASK = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/test_masks'
22
TRAIN_IMG_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/images'
23
TRAIN_MAKS_DIR = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/2d_data/masks'
24
IMAGE_HEIGHT = 192 # 1280 px.
25
IMAGE_WIDTH = 192 # 1918 px.
26
PREDS_PATH = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Heart_Segmentation/saved_images_test'
27
28
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
29
30
def evaluate_training_set(model, loader, device):
31
    model.eval()
32
33
    dice_score = 0.0
34
    jaccard = 0.0
35
    num_samples = 0
36
37
    with torch.no_grad():
38
        for data, y in loader:
39
            data = data.to(device)
40
            y = y.float().unsqueeze(1).to(device)
41
            preds = model(data)
42
            preds = torch.sigmoid(model(data))
43
            
44
            preds = (preds > 0.5).float()
45
            
46
            if((preds.sum() == 0) and (y.sum() == 0)):
47
                 jaccard +=1
48
                 dice_score +=1
49
            else:
50
                jaccard += compute_jaccard_index(preds, y)
51
                dice_score += (2*(preds*y).sum())/((preds + y).sum() + 1e-10)
52
            
53
54
    dice_s = dice_score/len(loader) 
55
    jaccard_idx = jaccard/len(loader)
56
57
58
    return dice_s, jaccard_idx
59
60
def main():
61
    x_start = 50
62
    x_end = 462
63
    y_start = 50
64
    y_end = 462 
65
    val_transforms = A.Compose(
66
                               [A.Crop(x_start, y_start, x_end, y_end, always_apply= False),
67
                                A.Resize(height = IMAGE_HEIGHT, width = IMAGE_WIDTH), 
68
                                A.Normalize(
69
                                            mean = [0.0, 0.0, 0.0],
70
                                            std = [1.0, 1.0, 1.0], 
71
                                            max_pixel_value = 255.0
72
                                         ),
73
                              ToTensorV2(), # conversión a tensor de Pytorch
74
                             ],
75
                            )
76
    train_loader, test_loader = get_loaders( # cargar y preprocesar los datos de entrenamiento y validación con base en los formatos tensoriales
77
                                           TRAIN_IMG_DIR, 
78
                                           TRAIN_MAKS_DIR, 
79
                                           TEST_IMAGES, 
80
                                           TEST_MASK,
81
                                           16,
82
                                           None,
83
                                           val_transforms, 
84
                                           16,
85
                                           w_level=40, 
86
                                           w_width=350,
87
                                           pin_memory= True, 
88
                                           normalize=False, 
89
                                           
90
                                          )
91
    model = smp.MAnet(   
92
                encoder_name="mit_b3",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
93
                encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
94
                in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
95
                classes=1,
96
                decoder_use_batchnorm=True,
97
                ).to(DEVICE)
98
    load_checkpoint(torch.load(CHECKPOINT_PATH), model)
99
100
101
    dice_score, jaccard_index = evaluate_training_set(model, test_loader, DEVICE)
102
    print(f"Average Dice Score (Training Set): {dice_score:.4f}")
103
    print(f"Average Jaccard Index (Training Set): {jaccard_index:.4f}")
104
    save_preds_as_imgs(test_loader, model, DEVICE,PREDS_PATH)
105
106
if __name__ == '__main__': 
107
    main()