|
a |
|
b/src/Unet.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
""" |
|
|
3 |
Desarrollado por: Wilson Javier Arenas López |
|
|
4 |
Tema: UNET (CNNs & DCNNs) |
|
|
5 |
Objetivo: Segmentar una imagen a través de la codificación (CNN) y decodificación (DCNN) de la misma, |
|
|
6 |
teniendo como referencia la BD "Carvana Image Masking Challenge" y la arquitectura UNET |
|
|
7 |
(https://arxiv.org/abs/1505.04597) |
|
|
8 |
Parte: II (Construcción del modelo UNET) |
|
|
9 |
Fecha: 24/04/2023 |
|
|
10 |
""" |
|
|
11 |
|
|
|
12 |
import torch.nn as nn # para crear, definir y personalizar diferentes tipos de capas, modelos y criterios de pérdida en DL |
|
|
13 |
import torch # para optimizar procesos de DL |
|
|
14 |
import torchvision.transforms.functional as TF # para transformaciones de tensores |
|
15 |
|
|
|
16 |
class DoubleConv(nn.Module):
    """Two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    These are the blue arrows in the UNET diagram
    (https://arxiv.org/abs/1505.04597). "Same" padding (kernel 3,
    stride 1, padding 1) keeps height and width unchanged, and the
    convolutions carry no bias because the following BatchNorm would
    cancel it anyway.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        # Build both identical stages in a loop; only the first stage's
        # input channel count differs from the rest.
        stages = []
        channels = in_channels
        for _ in range(2):
            stages.append(nn.Conv2d(channels, out_channels, kernel_size=3,
                                    stride=1, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
            channels = out_channels
        self.conv = nn.Sequential(*stages)

    def forward(self, X):
        """Apply both stages: (N, C_in, H, W) -> (N, C_out, H, W)."""
        return self.conv(X)
|
|
32 |
|
|
|
33 |
class UNET(nn.Module):
    """UNET encoder-decoder for image segmentation
    (https://arxiv.org/abs/1505.04597).

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the output map (1 for a binary mask;
            the raw logits are returned, no sigmoid is applied).
        feature_maps: channel widths of the successive encoder stages;
            the decoder mirrors them in reverse.
    """

    # FIX: the default is a tuple instead of the original list, avoiding the
    # mutable-default-argument pitfall; it is read-only here, so behavior and
    # the call interface are unchanged.
    def __init__(self, in_channels = 3, out_channels = 1, feature_maps = (64, 128, 256, 512)):
        super(UNET, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)  # red arrows (see UNET diagram)

        # Encoder (CNN, "downs"):
        for feat in feature_maps:
            self.downs.append(DoubleConv(in_channels, feat))
            in_channels = feat  # the next stage consumes this stage's output

        # Bottleneck (deepest stage, widest feature map):
        self.bottleneck = DoubleConv(feature_maps[-1], feature_maps[-1]*2)

        # Decoder (DCNN, "ups") — pairs of (upsampling, double conv):
        for feat in reversed(feature_maps):
            # Green arrows: transposed conv doubles height and width.
            self.ups.append(nn.ConvTranspose2d(feat*2, feat, kernel_size = 2, stride = 2))
            # Blue arrows: input is feat*2 because the skip connection is concatenated.
            self.ups.append(DoubleConv(feat*2, feat))

        # Output (cyan arrow): 1x1 conv down to the requested channel count.
        self.final_conv = nn.Conv2d(feature_maps[0], out_channels, kernel_size = 1)

    def forward(self, X):
        """Run the full encoder/decoder on X (N, C_in, H, W) and return
        segmentation logits of shape (N, out_channels, H, W)."""
        skip_connections = []

        # Encoder: record each stage's output before pooling halves H and W.
        for down in self.downs:
            X = down(X)
            skip_connections.append(X)
            X = self.pool(X)

        # Bottleneck:
        X = self.bottleneck(X)
        skip_connections = skip_connections[::-1]  # deepest skip comes first on the way up

        # Decoder: step of 2 because self.ups alternates (upsample, double conv).
        for idx in range(0, len(self.ups), 2):
            X = self.ups[idx](X)
            skip_connection = skip_connections[idx//2]  # matching encoder stage

            # Odd spatial sizes lose a pixel under 2x2 pooling, so the upsampled
            # map can be smaller than its skip; resize H and W to match.
            if X.shape != skip_connection.shape:
                X = TF.resize(X, size = skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, X), dim = 1)
            X = self.ups[idx + 1](concat_skip)

        return self.final_conv(X)
|
|
81 |
|
|
|
82 |
# def val_size(): |
|
|
83 |
# X = torch.randn((3, 1, 255, 255)) |
|
|
84 |
# model = UNET(in_channels = 1, out_channels = 1) |
|
|
85 |
# preds = model(X) |
|
|
86 |
# print('\nInput size: {}'.format(X.shape)) |
|
|
87 |
# print('Output size: {}'.format(preds.shape)) |
|
|
88 |
|
|
|
89 |
# if __name__ == '__main__': |
|
|
90 |
# val_size() |