b/clinical_ts/xresnet1d.py
|
|
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/13_xresnet1d.ipynb (unless otherwise specified).

__all__ = ['delegates', 'store_attr', 'init_default', 'BatchNorm', 'NormType', 'ConvLayer', 'AdaptiveAvgPool',
           'MaxPool', 'AvgPool', 'ResBlock', 'init_cnn', 'XResNet1d', 'xresnet1d18', 'xresnet1d34', 'xresnet1d50',
           'xresnet1d101', 'xresnet1d152', 'xresnet1d18_deep', 'xresnet1d34_deep', 'xresnet1d50_deep',
           'xresnet1d18_deeper', 'xresnet1d34_deeper', 'xresnet1d50_deeper']

# Cell
import torch
import torch.nn as nn
import torch.nn.functional as F
# weight_norm/spectral_norm are used by ConvLayer for NormType.Weight/NormType.Spectral
# but were missing from the original export's imports.
from torch.nn.utils import weight_norm, spectral_norm

from .basic_conv1d import create_head1d, Flatten

from enum import Enum
import re
|
|
# Cell
import inspect

def delegates(to=None, keep=False):
    "Decorator: replace `**kwargs` in signature with params from `to`"
    def _f(f):
        if to is None: to_f,from_f = f.__base__.__init__,f.__init__
        else: to_f,from_f = to,f
        sig = inspect.signature(from_f)
        sigd = dict(sig.parameters)
        k = sigd.pop('kwargs')
        s2 = {k:v for k,v in inspect.signature(to_f).parameters.items()
              if v.default != inspect.Parameter.empty and k not in sigd}
        sigd.update(s2)
        if keep: sigd['kwargs'] = k
        from_f.__signature__ = sig.replace(parameters=sigd.values())
        return f
    return _f

def store_attr(self, nms):
    "Store params named in comma-separated `nms` from calling context into attrs in `self`"
    mod = inspect.currentframe().f_back.f_locals
    for n in re.split(', *', nms): setattr(self,n,mod[n])
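
# Editor's addition (illustrative only, not part of the notebook export): a minimal
# sketch of how `delegates` and `store_attr` behave, using hypothetical toy classes.
def _demo_delegates_store_attr():
    class Base:
        def __init__(self, a=1, b=2): pass

    @delegates()  # pulls Base.__init__'s keyword defaults into Child's signature
    class Child(Base):
        def __init__(self, c=3, **kwargs):
            store_attr(self, 'c')        # equivalent to self.c = c
            super().__init__(**kwargs)

    # 'a' and 'b' now show up explicitly instead of hiding behind **kwargs
    return list(inspect.signature(Child.__init__).parameters), Child().c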
|
|
# Cell
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero')

def _conv_func(ndim=2, transpose=False):
    "Return the proper conv `ndim` function, potentially `transposed`."
    assert 1 <= ndim <= 3
    return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d')

def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0."
    if func and hasattr(m, 'weight'): func(m.weight)
    with torch.no_grad():
        if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)
    return m

def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
    "Norm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    assert 1 <= ndim <= 3
    bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
    if bn.affine:
        bn.bias.data.fill_(1e-3)
        bn.weight.data.fill_(0. if zero else 1.)
    return bn

def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
    "BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    return _get_norm('BatchNorm', nf, ndim, zero=norm_type==NormType.BatchZero, **kwargs)
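
# Editor's note: `ConvLayer` below calls an `InstanceNorm` helper for
# NormType.Instance/InstanceZero, but this file never defines it. A minimal sketch
# mirroring `BatchNorm` above (an assumption, modeled on the fastai layers this
# export appears to follow) could look like this:
def InstanceNorm(nf, ndim=2, norm_type=NormType.Instance, affine=True, **kwargs):
    "InstanceNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    return _get_norm('InstanceNorm', nf, ndim, zero=norm_type==NormType.InstanceZero, affine=affine, **kwargs)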
|
|
# Cell
class ConvLayer(nn.Sequential):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers."
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
                 act_cls=nn.ReLU, transpose=False, init=nn.init.kaiming_normal_, xtra=None, **kwargs):
        if padding is None: padding = ((ks-1)//2 if not transpose else 0)
        bn = norm_type in (NormType.Batch, NormType.BatchZero)
        inn = norm_type in (NormType.Instance, NormType.InstanceZero)
        if bias is None: bias = not (bn or inn)
        conv_func = _conv_func(ndim, transpose=transpose)
        conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs), init)
        if norm_type==NormType.Weight: conv = weight_norm(conv)
        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
        layers = [conv]
        act_bn = []
        if act_cls is not None: act_bn.append(act_cls())
        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))
        if bn_1st: act_bn.reverse()
        layers += act_bn
        if xtra: layers.append(xtra)
        super().__init__(*layers)
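
# Editor's addition (illustrative only): a quick sanity check of ConvLayer in 1d.
# The input layout is (batch, channels, length); the numbers are arbitrary.
def _demo_convlayer():
    layer = ConvLayer(12, 64, ks=5, stride=2, ndim=1)   # conv -> BatchNorm1d -> ReLU
    x = torch.randn(2, 12, 1000)
    return layer(x).shape                               # torch.Size([2, 64, 500])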
|
|
# Cell
def AdaptiveAvgPool(sz=1, ndim=2):
    "nn.AdaptiveAvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"AdaptiveAvgPool{ndim}d")(sz)

def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.MaxPool layer for `ndim`"
    assert 1 <= ndim <= 3
    # pass ceil_mode through (the original export accepted the argument but silently dropped it)
    return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)

def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.AvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"AvgPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
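
# Editor's addition (illustrative only): the ndim=1 variants of the pooling helpers,
# e.g. global pooling of a (batch, channels, length) feature map.
def _demo_pools():
    x = torch.randn(2, 64, 250)
    gap = AdaptiveAvgPool(sz=1, ndim=1)                 # -> (2, 64, 1)
    mp = MaxPool(ks=3, stride=2, padding=1, ndim=1)     # -> (2, 64, 125)
    return gap(x).shape, mp(x).shape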
|
|
# Cell
class ResBlock(nn.Module):
    "Resnet block from `ni` to `nh` with `stride`"
    @delegates(ConvLayer.__init__)
    def __init__(self, expansion, ni, nf, stride=1, kernel_size=3, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,
                 sa=False, sym=False, norm_type=NormType.Batch, act_cls=nn.ReLU, ndim=2,
                 pool=AvgPool, pool_first=True, **kwargs):
        super().__init__()
        norm2 = (NormType.BatchZero if norm_type==NormType.Batch else
                 NormType.InstanceZero if norm_type==NormType.Instance else norm_type)
        if nh2 is None: nh2 = nf
        if nh1 is None: nh1 = nh2
        nf,ni = nf*expansion,ni*expansion
        k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
        k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
        layers = [ConvLayer(ni, nh2, kernel_size, stride=stride, groups=ni if dw else groups, **k0),
                  ConvLayer(nh2, nf, kernel_size, groups=g2, **k1)
        ] if expansion == 1 else [
                  ConvLayer(ni, nh1, 1, **k0),
                  ConvLayer(nh1, nh2, kernel_size, stride=stride, groups=nh1 if dw else groups, **k0),
                  ConvLayer(nh2, nf, 1, groups=g2, **k1)]
        self.convs = nn.Sequential(*layers)
        convpath = [self.convs]
        # NB: SEModule and SimpleSelfAttention are not defined in this file; they are only
        # needed when `reduction` or `sa` is set, which the exported constructors below never do.
        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))
        self.convpath = nn.Sequential(*convpath)
        idpath = []
        if ni!=nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
        # for strided blocks, downsample the identity path; pool_first decides whether the
        # pooling goes before or after the 1x1 channel-matching conv
        if stride!=1: idpath.insert((1,0)[pool_first], pool(2, ndim=ndim, ceil_mode=True))
        self.idpath = nn.Sequential(*idpath)
        self.act = nn.ReLU(inplace=True) if act_cls is nn.ReLU else act_cls()

    def forward(self, x): return self.act(self.convpath(x) + self.idpath(x))
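
# Editor's addition (illustrative only): a bottleneck ResBlock as used by the
# 1d XResNet below (expansion=4), on a dummy (batch, channels, length) input.
def _demo_resblock():
    blk = ResBlock(4, 16, 16, stride=2, kernel_size=5, ndim=1)  # 16*4 = 64 channels in and out
    x = torch.randn(2, 64, 500)
    return blk(x).shape                                         # torch.Size([2, 64, 250])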
|
|
# Cell
def init_cnn(m):
    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(m.weight)
    for l in m.children(): init_cnn(l)
|
|
# Cell
class XResNet1d(nn.Sequential):
    @delegates(ResBlock)
    def __init__(self, block, expansion, layers, p=0.0, input_channels=3, num_classes=1000, stem_szs=(32,32,64), kernel_size=5, kernel_size_stem=5,
                 widen=1.0, sa=False, act_cls=nn.ReLU, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
        store_attr(self, 'block,expansion,act_cls')
        stem_szs = [input_channels, *stem_szs]
        stem = [ConvLayer(stem_szs[i], stem_szs[i+1], ks=kernel_size_stem, stride=2 if i==0 else 1, act_cls=act_cls, ndim=1)
                for i in range(3)]

        # note: unlike the standard 2d XResNet widths (kept below as a commented-out alternative),
        # this 1d variant keeps 64*widen channels in every stage
        #block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
        block_szs = [int(o*widen) for o in [64,64,64,64] +[32]*(len(layers)-4)]
        block_szs = [64//expansion] + block_szs
        blocks = [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l,
                                   stride=1 if i==0 else 2, kernel_size=kernel_size, sa=sa and i==len(layers)-4, ndim=1, **kwargs)
                  for i,l in enumerate(layers)]

        head = create_head1d(block_szs[-1]*expansion, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)

        super().__init__(
            *stem, nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            *blocks,
            head,
        )
        init_cnn(self)

    def _make_layer(self, ni, nf, blocks, stride, kernel_size, sa, **kwargs):
        return nn.Sequential(
            *[self.block(self.expansion, ni if i==0 else nf, nf, stride=stride if i==0 else 1,
                         kernel_size=kernel_size, sa=sa and i==(blocks-1), act_cls=self.act_cls, **kwargs)
              for i in range(blocks)])

    def get_layer_groups(self):
        return (self[3], self[-1])

    def get_output_layer(self):
        return self[-1][-1]

    def set_output_layer(self, x):
        self[-1][-1] = x
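
# Editor's addition (illustrative only): `get_output_layer`/`set_output_layer` swap the
# final layer of the head, e.g. to fine-tune on a new number of classes. This assumes
# the `create_head1d` head ends in an `nn.Linear` (the default, with bn_final_head=False).
def _demo_swap_output_layer(model, num_classes_new):
    old = model.get_output_layer()                        # last layer of the head
    model.set_output_layer(nn.Linear(old.in_features, num_classes_new))
    return model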
|
|
# Cell
def _xresnet1d(expansion, layers, **kwargs):
    return XResNet1d(ResBlock, expansion, layers, **kwargs)

def xresnet1d18 (**kwargs): return _xresnet1d(1, [2, 2, 2, 2], **kwargs)
def xresnet1d34 (**kwargs): return _xresnet1d(1, [3, 4, 6, 3], **kwargs)
def xresnet1d50 (**kwargs): return _xresnet1d(4, [3, 4, 6, 3], **kwargs)
def xresnet1d101(**kwargs): return _xresnet1d(4, [3, 4, 23, 3], **kwargs)
def xresnet1d152(**kwargs): return _xresnet1d(4, [3, 8, 36, 3], **kwargs)
def xresnet1d18_deep (**kwargs): return _xresnet1d(1, [2,2,2,2,1,1], **kwargs)
def xresnet1d34_deep (**kwargs): return _xresnet1d(1, [3,4,6,3,1,1], **kwargs)
def xresnet1d50_deep (**kwargs): return _xresnet1d(4, [3,4,6,3,1,1], **kwargs)
def xresnet1d18_deeper(**kwargs): return _xresnet1d(1, [2,2,1,1,1,1,1,1], **kwargs)
def xresnet1d34_deeper(**kwargs): return _xresnet1d(1, [3,4,6,3,1,1,1,1], **kwargs)
def xresnet1d50_deeper(**kwargs): return _xresnet1d(4, [3,4,6,3,1,1,1,1], **kwargs)
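
# Editor's addition (illustrative only): constructing one of the 1d models and running a
# dummy forward pass. The input layout is (batch, channels, length); 12 channels and
# 5 classes are arbitrary example values (e.g. a 12-lead ECG setup), not defaults.
# Run with `python -m clinical_ts.xresnet1d` so the relative import above resolves.
if __name__ == '__main__':
    model = xresnet1d50(input_channels=12, num_classes=5)
    x = torch.randn(2, 12, 1000)
    print(model(x).shape)   # torch.Size([2, 5])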