Diff of /clinical_ts/xresnet1d.py [000000] .. [134fd7]

b/clinical_ts/xresnet1d.py
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/13_xresnet1d.ipynb (unless otherwise specified).

__all__ = ['delegates', 'store_attr', 'init_default', 'BatchNorm', 'NormType', 'ConvLayer', 'AdaptiveAvgPool',
           'MaxPool', 'AvgPool', 'ResBlock', 'init_cnn', 'XResNet1d', 'xresnet1d18', 'xresnet1d34', 'xresnet1d50',
           'xresnet1d101', 'xresnet1d152', 'xresnet1d18_deep', 'xresnet1d34_deep', 'xresnet1d50_deep',
           'xresnet1d18_deeper', 'xresnet1d34_deeper', 'xresnet1d50_deeper']

# Cell
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm  # used by ConvLayer for NormType.Weight / NormType.Spectral

from .basic_conv1d import create_head1d, Flatten

from enum import Enum
import re

# Cell
import inspect

def delegates(to=None, keep=False):
    "Decorator: replace `**kwargs` in signature with params from `to`"
    def _f(f):
        if to is None: to_f,from_f = f.__base__.__init__,f.__init__
        else:          to_f,from_f = to,f
        sig = inspect.signature(from_f)
        sigd = dict(sig.parameters)
        k = sigd.pop('kwargs')
        s2 = {k:v for k,v in inspect.signature(to_f).parameters.items()
              if v.default != inspect.Parameter.empty and k not in sigd}
        sigd.update(s2)
        if keep: sigd['kwargs'] = k
        from_f.__signature__ = sig.replace(parameters=sigd.values())
        return f
    return _f

def store_attr(self, nms):
    "Store params named in comma-separated `nms` from calling context into attrs in `self`"
    mod = inspect.currentframe().f_back.f_locals
    for n in re.split(', *', nms): setattr(self,n,mod[n])
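
# Illustrative sketch, not part of the original notebook export: how `delegates`
# rewrites a wrapper's signature. The names `base`/`wrapper` below are invented
# purely for demonstration.
def _example_delegates():
    "Show that `**kwargs` in the decorated signature is replaced by `base`'s keyword defaults."
    def base(a, b=1, c=2): pass
    @delegates(base)
    def wrapper(x, **kwargs): pass
    return inspect.signature(wrapper)   # -> (x, b=1, c=2)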

# Cell
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero')

def _conv_func(ndim=2, transpose=False):
    "Return the proper conv `ndim` function, potentially `transposed`."
    assert 1 <= ndim <= 3
    return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d')

def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0."
    if func and hasattr(m, 'weight'): func(m.weight)
    with torch.no_grad():
        if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)
    return m

def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
    "Norm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    assert 1 <= ndim <= 3
    bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
    if bn.affine:
        bn.bias.data.fill_(1e-3)
        bn.weight.data.fill_(0. if zero else 1.)
    return bn

def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
    "BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    return _get_norm('BatchNorm', nf, ndim, zero=norm_type==NormType.BatchZero, **kwargs)
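
# Note / sketch, not part of the original notebook export: `ConvLayer` below also
# references an `InstanceNorm` helper for NormType.Instance/InstanceZero, which this
# module does not define. A minimal definition mirroring `BatchNorm` above would be:
def InstanceNorm(nf, ndim=2, norm_type=NormType.Instance, affine=True, **kwargs):
    "InstanceNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    return _get_norm('InstanceNorm', nf, ndim, zero=norm_type==NormType.InstanceZero, affine=affine, **kwargs)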

# Cell
class ConvLayer(nn.Sequential):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers."
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
                 act_cls=nn.ReLU, transpose=False, init=nn.init.kaiming_normal_, xtra=None, **kwargs):
        if padding is None: padding = ((ks-1)//2 if not transpose else 0)
        bn = norm_type in (NormType.Batch, NormType.BatchZero)
        inn = norm_type in (NormType.Instance, NormType.InstanceZero)
        if bias is None: bias = not (bn or inn)
        conv_func = _conv_func(ndim, transpose=transpose)
        conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs), init)
        if   norm_type==NormType.Weight:   conv = weight_norm(conv)
        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
        layers = [conv]
        act_bn = []
        if act_cls is not None: act_bn.append(act_cls())
        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))
        if bn_1st: act_bn.reverse()
        layers += act_bn
        if xtra: layers.append(xtra)
        super().__init__(*layers)
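
# Illustrative sketch, not part of the original notebook export: a 1d ConvLayer as
# used by the XResNet1d stem below. The batch/channel/length sizes are assumptions
# for demonstration only.
def _example_convlayer():
    "Apply conv -> BatchNorm1d -> ReLU (bn_1st=True) to a dummy batch of 12-channel signals."
    layer = ConvLayer(12, 32, ks=5, stride=2, ndim=1)
    x = torch.randn(8, 12, 1000)          # (batch, channels, timesteps)
    return layer(x).shape                  # -> torch.Size([8, 32, 500])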

# Cell
def AdaptiveAvgPool(sz=1, ndim=2):
    "nn.AdaptiveAvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"AdaptiveAvgPool{ndim}d")(sz)

def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.MaxPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)

def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.AvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"AvgPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)

# Cell
class ResBlock(nn.Module):
    "Resnet block from `ni` to `nh` with `stride`"
    @delegates(ConvLayer.__init__)
    def __init__(self, expansion, ni, nf, stride=1, kernel_size=3, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,
                 sa=False, sym=False, norm_type=NormType.Batch, act_cls=nn.ReLU, ndim=2,
                 pool=AvgPool, pool_first=True, **kwargs):
        super().__init__()
        norm2 = (NormType.BatchZero if norm_type==NormType.Batch else
                 NormType.InstanceZero if norm_type==NormType.Instance else norm_type)
        if nh2 is None: nh2 = nf
        if nh1 is None: nh1 = nh2
        nf,ni = nf*expansion,ni*expansion
        k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
        k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
        layers  = [ConvLayer(ni,  nh2, kernel_size, stride=stride, groups=ni if dw else groups, **k0),
                   ConvLayer(nh2,  nf, kernel_size, groups=g2, **k1)
        ] if expansion == 1 else [
                   ConvLayer(ni,  nh1, 1, **k0),
                   ConvLayer(nh1, nh2, kernel_size, stride=stride, groups=nh1 if dw else groups, **k0),
                   ConvLayer(nh2,  nf, 1, groups=g2, **k1)]
        self.convs = nn.Sequential(*layers)
        convpath = [self.convs]
        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))
        self.convpath = nn.Sequential(*convpath)
        idpath = []
        if ni!=nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
        if stride!=1: idpath.insert((1,0)[pool_first], pool(2, ndim=ndim, ceil_mode=True))
        self.idpath = nn.Sequential(*idpath)
        self.act = nn.ReLU(inplace=True) if act_cls is nn.ReLU else act_cls()

    def forward(self, x): return self.act(self.convpath(x) + self.idpath(x))
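
# Illustrative sketch, not part of the original notebook export: a downsampling
# bottleneck block on 1d data. The `reduction` and `sa` options reference SEModule /
# SimpleSelfAttention, which are not defined in this module, so they are left at
# their defaults here; all sizes are assumptions for demonstration only.
def _example_resblock():
    "Bottleneck ResBlock: 64*4 -> 128*4 channels, stride 2, kernel size 5."
    block = ResBlock(4, 64, 128, stride=2, kernel_size=5, ndim=1)
    x = torch.randn(8, 256, 250)          # (batch, ni*expansion, timesteps)
    return block(x).shape                  # -> torch.Size([8, 512, 125])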

# Cell
def init_cnn(m):
    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv1d, nn.Conv2d,nn.Linear)): nn.init.kaiming_normal_(m.weight)
    for l in m.children(): init_cnn(l)

# Cell
class XResNet1d(nn.Sequential):
    @delegates(ResBlock)
    def __init__(self, block, expansion, layers, p=0.0, input_channels=3, num_classes=1000, stem_szs=(32,32,64),kernel_size=5,kernel_size_stem=5,
                 widen=1.0, sa=False, act_cls=nn.ReLU, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
        store_attr(self, 'block,expansion,act_cls')
        stem_szs = [input_channels, *stem_szs]
        stem = [ConvLayer(stem_szs[i], stem_szs[i+1], ks=kernel_size_stem, stride=2 if i==0 else 1, act_cls=act_cls, ndim=1)
                for i in range(3)]

        #block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
        block_szs = [int(o*widen) for o in [64,64,64,64] +[32]*(len(layers)-4)]
        block_szs = [64//expansion] + block_szs
        blocks = [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l,
                                   stride=1 if i==0 else 2, kernel_size=kernel_size, sa=sa and i==len(layers)-4, ndim=1, **kwargs)
                  for i,l in enumerate(layers)]

        head = create_head1d(block_szs[-1]*expansion, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)

        super().__init__(
            *stem, nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            *blocks,
            head,
        )
        init_cnn(self)

    def _make_layer(self, ni, nf, blocks, stride, kernel_size, sa, **kwargs):
        return nn.Sequential(
            *[self.block(self.expansion, ni if i==0 else nf, nf, stride=stride if i==0 else 1,
                      kernel_size=kernel_size, sa=sa and i==(blocks-1), act_cls=self.act_cls, **kwargs)
              for i in range(blocks)])

    def get_layer_groups(self):
        return (self[3],self[-1])

    def get_output_layer(self):
        return self[-1][-1]

    def set_output_layer(self,x):
        self[-1][-1]=x
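
# Note / sketch, not part of the original notebook export: index layout of the
# resulting nn.Sequential. self[0:3] are the stem ConvLayers, self[3] is the
# MaxPool1d, self[4:-1] are the residual stages and self[-1] is the head, whose
# last module (self[-1][-1]) is the output layer that get_output_layer() and
# set_output_layer() read and replace.
def _example_replace_output_layer(model, num_classes):
    "Swap the classification layer, e.g. for fine-tuning on a new label set (assumes it is an nn.Linear)."
    old = model.get_output_layer()
    model.set_output_layer(nn.Linear(old.in_features, num_classes))
    return model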

# Cell
def _xresnet1d(expansion, layers, **kwargs):
    return XResNet1d(ResBlock, expansion, layers, **kwargs)

def xresnet1d18 (**kwargs): return _xresnet1d(1, [2, 2,  2, 2], **kwargs)
def xresnet1d34 (**kwargs): return _xresnet1d(1, [3, 4,  6, 3], **kwargs)
def xresnet1d50 (**kwargs): return _xresnet1d(4, [3, 4,  6, 3], **kwargs)
def xresnet1d101(**kwargs): return _xresnet1d(4, [3, 4, 23, 3], **kwargs)
def xresnet1d152(**kwargs): return _xresnet1d(4, [3, 8, 36, 3], **kwargs)
def xresnet1d18_deep  (**kwargs): return _xresnet1d(1, [2,2,2,2,1,1], **kwargs)
def xresnet1d34_deep  (**kwargs): return _xresnet1d(1, [3,4,6,3,1,1], **kwargs)
def xresnet1d50_deep  (**kwargs): return _xresnet1d(4, [3,4,6,3,1,1], **kwargs)
def xresnet1d18_deeper(**kwargs): return _xresnet1d(1, [2,2,1,1,1,1,1,1], **kwargs)
def xresnet1d34_deeper(**kwargs): return _xresnet1d(1, [3,4,6,3,1,1,1,1], **kwargs)
def xresnet1d50_deeper(**kwargs): return _xresnet1d(4, [3,4,6,3,1,1,1,1], **kwargs)
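
# Illustrative sketch, not part of the original notebook export: building a 1d
# xresnet for, say, 12-channel signals with 5 target classes. The input sizes are
# assumptions for demonstration only.
def _example_xresnet1d50():
    "Build a 1d xresnet50 for 12-channel inputs with 5 classes and run a dummy batch through it."
    model = xresnet1d50(input_channels=12, num_classes=5)
    x = torch.randn(2, 12, 1000)          # (batch, channels, timesteps)
    return model(x).shape                  # expected: torch.Size([2, 5])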