Diff of /resnet.py [000000] .. [a8f942]

import torch.nn as nn
import numpy as np


def _padding(downsample, kernel_size):
    """Compute the padding required for a convolution with stride ``downsample``
    to map ``n_samples`` inputs to ``n_samples / downsample`` outputs."""
    padding = max(0, int(np.floor((kernel_size - downsample + 1) / 2)))
    return padding

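# Worked example for the helper above (illustrative numbers, not from the
# original file): _padding(4, 17) = floor((17 - 4 + 1) / 2) = 7, and
# nn.Conv1d(..., kernel_size=17, stride=4, padding=7) maps 4096 samples to
# floor((4096 + 2*7 - 17) / 4) + 1 = 1024, i.e. exactly 4096 / 4.
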

def _downsample(n_samples_in, n_samples_out):
    """Compute downsample rate"""
    downsample = int(n_samples_in // n_samples_out)
    if downsample < 1:
        raise ValueError("Number of samples should always decrease")
    if n_samples_in % n_samples_out != 0:
        raise ValueError("Number of samples for two consecutive blocks "
                         "should always decrease by an integer factor.")
    return downsample

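# Illustrative check (values assumed for this example): _downsample(4096, 1024)
# returns 4, while _downsample(4096, 1000) raises ValueError because 4096 is
# not an integer multiple of 1000.
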

class ResBlock1d(nn.Module):
    """Residual network unit for unidimensional signals."""

    def __init__(self, n_filters_in, n_filters_out, downsample, kernel_size, dropout_rate):
        if kernel_size % 2 == 0:
            raise ValueError("The current implementation only supports odd values for `kernel_size`.")
        super(ResBlock1d, self).__init__()
        # Forward path
        padding = _padding(1, kernel_size)
        self.conv1 = nn.Conv1d(n_filters_in, n_filters_out, kernel_size, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm1d(n_filters_out)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        padding = _padding(downsample, kernel_size)
        self.conv2 = nn.Conv1d(n_filters_out, n_filters_out, kernel_size,
                               stride=downsample, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm1d(n_filters_out)
        self.dropout2 = nn.Dropout(dropout_rate)

        # Skip connection
        skip_connection_layers = []
        # Deal with downsampling
        if downsample > 1:
            maxpool = nn.MaxPool1d(downsample, stride=downsample)
            skip_connection_layers += [maxpool]
        # Deal with n_filters dimension increase
        if n_filters_in != n_filters_out:
            conv1x1 = nn.Conv1d(n_filters_in, n_filters_out, 1, bias=False)
            skip_connection_layers += [conv1x1]
        # Build skip connection layer
        if skip_connection_layers:
            self.skip_connection = nn.Sequential(*skip_connection_layers)
        else:
            self.skip_connection = None

    def forward(self, x, y):
        """Residual unit."""
        # Adjust the skip connection to the main path's shape, if needed
        if self.skip_connection is not None:
            y = self.skip_connection(y)
        # 1st layer
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout1(x)

        # 2nd layer
        x = self.conv2(x)
        x += y  # Sum skip connection and main connection
        y = x   # Save the pre-activation sum as the next block's skip input
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout2(x)
        return x, y

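# Shape sketch for a single block (illustrative, batch size B): with
# n_filters_in=64, n_filters_out=128, downsample=4 and 4096 input samples,
# the main path maps (B, 64, 4096) -> (B, 128, 1024), while the skip path
# applies MaxPool1d(4) and a 1x1 convolution so `y` matches before the sum.
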

class ResNet1d(nn.Module):
    """Residual network for unidimensional signals.

    Parameters
    ----------
    input_dim : tuple
        Input dimensions. Tuple containing dimensions for the neural network
        input tensor. Should be like: ``(n_filters, n_samples)``.
    blocks_dim : list of tuples
        Dimensions of residual blocks. The i-th tuple should contain the output
        dimensions of the i-th residual block (which are also the input
        dimensions of the (i+1)-th block). Each tuple should be like:
        ``(n_filters, n_samples)``. ``n_samples`` for two consecutive blocks
        should always decrease by an integer factor.
    n_classes : int
        Number of classes produced by the final linear layer.
    dropout_rate : float [0, 1), optional
        Dropout rate used in all Dropout layers. Default is 0.8.
    kernel_size : int, optional
        Kernel size for convolutional layers. The current implementation
        only supports odd kernel sizes. Default is 17.

    References
    ----------
    .. [1] K. He, X. Zhang, S. Ren, and J. Sun, "Identity Mappings in Deep
           Residual Networks," arXiv:1603.05027, Mar. 2016.
           https://arxiv.org/pdf/1603.05027.pdf
    .. [2] K. He, X. Zhang, S. Ren, and J. Sun, "Deep Residual Learning for
           Image Recognition," in 2016 IEEE Conference on Computer Vision and
           Pattern Recognition (CVPR), 2016, pp. 770-778.
           https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, input_dim, blocks_dim, n_classes, kernel_size=17, dropout_rate=0.8):
        super(ResNet1d, self).__init__()
        # First layers
        n_filters_in, n_filters_out = input_dim[0], blocks_dim[0][0]
        n_samples_in, n_samples_out = input_dim[1], blocks_dim[0][1]
        downsample = _downsample(n_samples_in, n_samples_out)
        padding = _padding(downsample, kernel_size)
        self.conv1 = nn.Conv1d(n_filters_in, n_filters_out, kernel_size, bias=False,
                               stride=downsample, padding=padding)
        self.bn1 = nn.BatchNorm1d(n_filters_out)

        # Residual block layers
        self.res_blocks = []
        for i, (n_filters, n_samples) in enumerate(blocks_dim):
            n_filters_in, n_filters_out = n_filters_out, n_filters
            n_samples_in, n_samples_out = n_samples_out, n_samples
            downsample = _downsample(n_samples_in, n_samples_out)
            resblk1d = ResBlock1d(n_filters_in, n_filters_out, downsample, kernel_size, dropout_rate)
            # Register each block as a named submodule so its parameters are
            # tracked, since `res_blocks` itself is a plain Python list
            self.add_module('resblock1d_{0}'.format(i), resblk1d)
            self.res_blocks += [resblk1d]

        # Linear layer
        n_filters_last, n_samples_last = blocks_dim[-1]
        last_layer_dim = n_filters_last * n_samples_last
        self.lin = nn.Linear(last_layer_dim, n_classes)
        self.n_blk = len(blocks_dim)

    def forward(self, x):
        """Implement ResNet1d forward propagation"""
        # First layers
        x = self.conv1(x)
        x = self.bn1(x)

        # Residual blocks
        y = x
        for blk in self.res_blocks:
            x, y = blk(x, y)

        # Flatten array
        x = x.view(x.size(0), -1)

        # Fully connected layer
        x = self.lin(x)
        return x

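
# Minimal usage sketch (not part of the original file; the dimensions below
# are illustrative assumptions): build the network for a 12-channel signal of
# 4096 samples and run a dummy batch through it.
if __name__ == "__main__":
    import torch

    model = ResNet1d(input_dim=(12, 4096),
                     blocks_dim=[(64, 4096), (128, 1024), (196, 256), (256, 64)],
                     n_classes=6)
    model.eval()  # disable dropout for this smoke test
    with torch.no_grad():
        logits = model(torch.randn(2, 12, 4096))  # batch of 2 dummy signals
    print(logits.shape)  # expected: torch.Size([2, 6])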