Diff of /src/utils.py [000000] .. [3475df]

Switch to unified view

a b/src/utils.py
1
"""
2
Developed by: Daniel Crovo
3
4
"""
5
import numpy as np
6
import csv
7
from hs_dataset import HSDataset
8
from torch.utils.data import DataLoader
9
import torch
10
import os
11
import torchvision
12
from torchmetrics.classification import BinaryJaccardIndex
13
import torchvision.transforms as transforms
14
import torchvision.utils as vutils
15
import torch
16
from PIL import Image
17
18
def get_loaders(train_dir, train_maskdir, val_dir, val_maskdir, batch_size, train_transform, 
19
                val_transform, num_workers, w_level, w_width, pin_memory = True, normalize=False):
20
    """_summary_
21
22
    Args:
23
        train_dir (_type_): _description_
24
        train_maskdir (_type_): _description_
25
        val_dir (_type_): _description_
26
        val_maskdir (_type_): _description_
27
        batch_size (_type_): _description_
28
        train_transform (_type_): _description_
29
        val_transform (_type_): _description_
30
        num_workers (_type_): _description_
31
        pin_memory (bool, optional): _description_. Defaults to True.
32
33
    Returns:
34
        _type_: _description_
35
    """
36
37
    train_img_mask = HSDataset(image_dir = train_dir, 
38
                               mask_dir = train_maskdir, transform = train_transform, 
39
                               normalized=normalize, w_level=w_level, w_width=w_width)
40
    train_loader = DataLoader(train_img_mask, batch_size = batch_size, 
41
                              num_workers = num_workers, pin_memory = pin_memory, shuffle = True)
42
    val_img_mask = HSDataset(image_dir = val_dir, mask_dir = val_maskdir, 
43
                             transform = val_transform, normalized=normalize,
44
                             w_level=w_level, w_width=w_width,)
45
    val_loader = DataLoader(val_img_mask, batch_size = batch_size, 
46
                            num_workers = num_workers, pin_memory = pin_memory, shuffle = False)
47
    
48
    return train_loader, val_loader
49
50
51
def load_checkpoint(checkpoint, model): 
52
    try: 
53
        model.load_state_dict(checkpoint['state_dict'])
54
        print('\nCheckpoint importado exitosamente.')
55
    except: 
56
        print('Error en la importación del Checkpoint.')
57
    
58
def save_checkpoint(state, filename = 'my_checkpoint.pth.tar'): 
59
    try: 
60
        torch.save(state, filename)
61
        print('\nCheckpoint almacenado exitosamente.')
62
    except: 
63
        print('Error en la importación del Checkpoint.')
64
65
def compute_jaccard_index(preds, targets):
66
    intersection = torch.logical_and(preds, targets).sum()
67
    union = torch.logical_or(preds, targets).sum()
68
    jaccard = intersection.item() / (union.item()+ 1e-10)
69
    return jaccard
70
def diceCoef(y_true, y_pred, smooth=1.):
71
  y_true_f = y_true.flatten()
72
  y_pred_f = y_pred.flatten()
73
  intersection = np.sum(y_true_f * y_pred_f)
74
  dice = (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)
75
  return round(float(dice), 3)
76
77
def perf(loader, model, device): 
78
    dice_score = 0.0
79
    jaccard = 0.0
80
81
    model.eval()
82
    
83
    with torch.no_grad(): # deshabilitar el cálculo y almacenamiento de gradientes en el grafo computacional de PyTorch
84
        for x, y in loader: 
85
            x = x.to(device = device)
86
            y = y.to(device).unsqueeze(1)
87
            preds = torch.sigmoid(model(x))
88
            
89
            preds = (preds > 0.5).float()
90
            if((preds.sum() == 0) and (y.sum() == 0)):
91
                jaccard +=1
92
                dice_score +=1
93
            else:
94
                jaccard += compute_jaccard_index(preds, y)
95
                dice_score += (2*(preds*y).sum())/((preds + y).sum() + 1e-10)
96
97
    dice_s = dice_score/len(loader) 
98
    jaccard_idx = jaccard/len(loader)
99
100
    print('\nDice score: {}'.format(dice_s))
101
    print('Jaccard index: {}\n'.format(jaccard_idx))
102
103
    model.train()
104
    return dice_s, jaccard_idx
105
    
106
def save_preds_as_imgs(loader, model, device, folder = '/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Proyecto/saved_images'): 
107
    model.eval()
108
    for idx, (x, y) in enumerate(loader):
109
        x = x.to(device = device)
110
        with torch.no_grad(): # deshabilitar el cálculo y almacenamiento de gradientes en el grafo computacional de PyTorch
111
            preds = torch.sigmoid(model(x))
112
            preds = (preds > 0.5).float()
113
        
114
        torchvision.utils.save_image(preds, f'{folder}/y_hat_{idx}.png') # almacenamiento de máscaras predichas
115
        y = torch.unsqueeze(y, 1).to(torch.float32)
116
        torchvision.utils.save_image(y, f'{folder}/y_{idx}.png') # almacenamiento de máscaras reales
117
        torchvision.utils.save_image(x, f'{folder}/x_{idx}.png') # almacenamiento de máscaras reales
118
119
        #masked_img= add_mask_to_rgb_image(x,preds)
120
121
        #torchvision.utils.save_image(masked_img, f'{folder}/masked_{idx}.png') # almacenamiento de máscaras reales
122
123
    
124
    model.train()
125
126
def add_mask_to_rgb_image(rgb_image_tensor, mask_tensor):
127
    # Apply the mask to the RGB image tensor
128
    masked_image_tensor = torch.where(mask_tensor > 0, torch.tensor([32,255, 0, 0]), rgb_image_tensor)
129
130
    return masked_image_tensor    #print(dice_score)
131
    #train_loss = [tensor.cpu() for tensor in train_loss]
132
    #train_dice = [tensor.cpu() for tensor in train_dice]
133
134
135
def save_metrics_to_csv(epoch, train_loss, train_dice, train_jaccard, dice_score, jaccard_score, filename):
136
    metrics = {
137
        'epoch': epoch,
138
        'train_loss': train_loss,
139
        'train_dice': train_dice,
140
        'train_jaccard': train_jaccard,
141
        'dice_score': dice_score,
142
        'jaccard_score': jaccard_score
143
    }
144
    
145
    file_exists = os.path.isfile(filename)
146
    
147
    with open(filename, mode='a', newline='') as file:
148
        writer = csv.DictWriter(file, fieldnames=metrics.keys())
149
        
150
        if not file_exists:
151
            writer.writeheader()
152
        
153
        writer.writerow(metrics)
154
155
156
def save_preds_as_imgs2(loader, model, device, folder='/home/danielcrovo/Documents/01.Study/01.MSc/02.MSc AI/Deep Learning/Proyecto/saved_images'):
157
    model.eval()
158
    for idx, (x, y) in enumerate(loader):
159
        x = x.to(device=device)
160
        with torch.no_grad():
161
            preds = torch.sigmoid(model(x))
162
            preds = (preds > 0.5).float()
163
164
        # Convert single-channel mask to RGB format
165
        y_rgb = y.repeat(1, 3, 1, 1)
166
167
        # Merge predicted mask with input image
168
        merged_img = x.clone()
169
        merged_img[:, :3] = torch.where(preds > 0, torch.tensor([1.0, 0.0, 0.0]), merged_img[:, :3])
170
171
        # Save the merged image
172
        vutils.save_image(merged_img, f'{folder}/merged_{idx}.png')
173