Diff of /src/Unet.py [000000] .. [3475df]

# -*- coding: utf-8 -*-
"""
Developed by: Wilson Javier Arenas López
Topic: UNET (CNNs & DCNNs)
Goal: Segment an image by encoding (CNN) and decoding (DCNN) it, using the
      "Carvana Image Masking Challenge" dataset and the UNET architecture
      (https://arxiv.org/abs/1505.04597) as references.
Part: II (Building the UNET model)
Date: 24/04/2023
"""

import torch                                    # core tensor operations for deep learning
import torch.nn as nn                           # layers, containers, and loss criteria
import torchvision.transforms.functional as TF  # functional transforms for tensors/images

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            # Conv1 (blue arrows, see the UNET architecture diagram):
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False), # bias is omitted because BatchNorm adds its own shift
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True),
            # Conv2 (blue arrows, see the UNET architecture diagram):
            nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace = True)
            )

    def forward(self, X):
        return self.conv(X)

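# DoubleConv changes only the channel count: every convolution uses kernel_size = 3 with
# padding = 1 and stride = 1, so height and width are preserved. Minimal sanity-check sketch
# (hypothetical shapes):
#     block = DoubleConv(3, 64)
#     out = block(torch.randn(1, 3, 160, 240))   # -> torch.Size([1, 64, 160, 240])
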
class UNET(nn.Module):
    def __init__(self, in_channels = 3, out_channels = 1, feature_maps = [64, 128, 256, 512]): # single-channel output assumes a binary mask
        super(UNET, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2) # red arrows (see the UNET architecture diagram)

        # CNN (downs):
        for feat in feature_maps:
            self.downs.append(DoubleConv(in_channels, feat))
            in_channels = feat # the next stage takes the previous stage's feature maps as input

        # Bottleneck:
        self.bottleneck = DoubleConv(feature_maps[-1], feature_maps[-1]*2)

        # DCNN (ups):
        for feat in reversed(feature_maps):
            self.ups.append(nn.ConvTranspose2d(feat*2, feat, kernel_size = 2, stride = 2)) # green arrows (see the UNET architecture diagram)
            self.ups.append(DoubleConv(feat*2, feat)) # blue arrows (see the UNET architecture diagram)

        # Output:
        self.final_conv = nn.Conv2d(feature_maps[0], out_channels, kernel_size = 1) # cyan arrow (see the UNET architecture diagram)

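    # With the default feature_maps, the channel progression is:
    #   downs:      3 -> 64 -> 128 -> 256 -> 512      (spatial size halved by pooling after each stage)
    #   bottleneck: 512 -> 1024
    #   ups:        1024 -> 512 -> 256 -> 128 -> 64   (each stage: transposed conv, concat with the skip, DoubleConv)
    #   final_conv: 64 -> out_channels (1 by default)
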
    def forward(self, X):
        skip_connections = []

        # CNN (downs):
        for down in self.downs:
            X = down(X)
            skip_connections.append(X)
            X = self.pool(X)

        # Bottleneck:
        X = self.bottleneck(X)
        skip_connections = skip_connections[::-1] # reverse the order of the skip connections

        # DCNN (ups):
        for idx in range(0, len(self.ups), 2): # step of 2: self.ups alternates a transposed conv and a DoubleConv
            X = self.ups[idx](X)
            skip_connection = skip_connections[idx//2] # integer division pairs each up stage with its skip connection

            if X.shape != skip_connection.shape: # shapes can differ when the input height/width are not divisible by 16
                X = TF.resize(X, size = skip_connection.shape[2:]) # only height and width matter here

            concat_skip = torch.cat((skip_connection, X), dim = 1)
            X = self.ups[idx + 1](concat_skip)

        return self.final_conv(X)

# def val_size():
#     X = torch.randn((3, 1, 255, 255))
#     model = UNET(in_channels = 1, out_channels = 1)
#     preds = model(X)
#     print('\nInput size: {}'.format(X.shape))
#     print('Output size: {}'.format(preds.shape))

# if __name__ == '__main__':
#     val_size()
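
# The model returns raw logits from final_conv (there is no activation after it). For a binary
# Carvana-style mask, one common post-processing step at inference time is a sigmoid followed by
# a 0.5 threshold; a sketch reusing model and X from val_size above:
#     probs = torch.sigmoid(model(X))
#     mask = (probs > 0.5).float()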