clinical_ts/basic_conv1d.py

# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/11_basic_conv1d.ipynb (unless otherwise specified). |

__all__ = ['cd_adaptiveconcatpool', 'attrib_adaptiveconcatpool', 'AdaptiveConcatPool1d', 'SqueezeExcite1d',
           'weight_init', 'create_head1d', 'basic_conv1d', 'fcn', 'fcn_wang', 'schirrmeister', 'sen', 'basic1d']

# Cell
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

#from fastai.layers import *
#from fastai.core import *
from typing import Iterable

class Flatten(nn.Module):
    "Flatten `x` to a single dimension, often used at the end of a model. `full` for rank-1 tensor"
    def __init__(self, full:bool=False):
        super().__init__()
        self.full = full
    def forward(self, x): return x.view(-1) if self.full else x.view(x.size(0), -1)


def listify(p=None, q=None):
    "Make `p` listy and the same length as `q`."
    if p is None: p=[]
    elif isinstance(p, str): p = [p]
    elif not isinstance(p, Iterable): p = [p]
    #Rank 0 tensors in PyTorch are Iterable but don't have a length.
    else:
        try: a = len(p)
        except: p = [p]
    n = q if type(q)==int else len(p) if q is None else len(q)
    if len(p)==1: p = p * n
    assert len(p)==n, f'List len mismatch ({len(p)} vs {n})'
    return list(p)


def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
    "Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`."
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0: layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    if actn is not None: layers.append(actn)
    return layers
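
# Usage sketch (illustrative, not part of the autogenerated module): `bn_drop_lin` returns a
# plain list of layers, so the caller decides how to wrap them. The feature sizes and dropout
# probability below are assumptions for illustration.
#
#   layers = bn_drop_lin(256, 64, bn=True, p=0.25, actn=nn.ReLU(inplace=True))
#   # -> [BatchNorm1d(256), Dropout(p=0.25), Linear(256, 64), ReLU(inplace=True)]
#   block = nn.Sequential(*layers)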

# Cell
def _conv1d(in_planes,out_planes,kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0):
    lst=[]
    if(drop_p>0):
        lst.append(nn.Dropout(drop_p))
    lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, dilation=dilation, bias=not(bn)))
    if(bn):
        lst.append(nn.BatchNorm1d(out_planes))
    if(act=="relu"):
        lst.append(nn.ReLU(True))
    if(act=="elu"):
        lst.append(nn.ELU(True))
    if(act=="prelu"):
        lst.append(nn.PReLU())  # PReLU has no inplace argument; passing True would be read as num_parameters
    return nn.Sequential(*lst)
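
# Usage sketch (illustrative, not part of the module): `_conv1d` packs an optional dropout,
# a "same"-padded Conv1d, optional BatchNorm1d and an activation into one nn.Sequential.
# The channel counts and sequence length below are assumptions.
#
#   block = _conv1d(12, 64, kernel_size=5, stride=2, drop_p=0.1)
#   x = torch.randn(8, 12, 1000)        # (batch, channels, sequence length)
#   block(x).shape                      # -> torch.Size([8, 64, 500])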

def _fc(in_planes,out_planes, act="relu", bn=True):
    lst = [nn.Linear(in_planes, out_planes, bias=not(bn))]
    if(bn):
        lst.append(nn.BatchNorm1d(out_planes))
    if(act=="relu"):
        lst.append(nn.ReLU(True))
    if(act=="elu"):
        lst.append(nn.ELU(True))
    if(act=="prelu"):
        lst.append(nn.PReLU())  # PReLU has no inplace argument
    return nn.Sequential(*lst)

class AdaptiveConcatPool1d(nn.Module):
    "Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`."
    def __init__(self, sz=None):
        "Output will be 2*sz or 2 if sz is None"
        super().__init__()
        sz = sz or 1
        self.ap,self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
    def attrib(self,relevant,irrelevant):
        return attrib_adaptiveconcatpool(self,relevant,irrelevant)
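
# Usage sketch (illustrative, shapes assumed): concatenating max- and avg-pooled features
# doubles the channel dimension, which is why `create_head1d` below starts from `2*nf`
# when `concat_pooling=True`.
#
#   pool = AdaptiveConcatPool1d()
#   x = torch.randn(4, 128, 1000)       # (batch, channels, sequence length)
#   pool(x).shape                       # -> torch.Size([4, 256, 1])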


# Cell
class SqueezeExcite1d(nn.Module):
    '''squeeze excite block as used for example in LSTM FCN'''
    def __init__(self,channels,reduction=16):
        super().__init__()
        channels_reduced = channels//reduction
        self.w1 = torch.nn.Parameter(torch.randn(channels_reduced,channels).unsqueeze(0))
        self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))

    def forward(self, x):
        #input is bs,ch,seq
        z=torch.mean(x,dim=2,keepdim=True)#bs,ch,1
        intermed = F.relu(torch.matmul(self.w1,z))#(1,ch_red,ch) * (bs,ch,1) = (bs,ch_red,1)
        s=torch.sigmoid(torch.matmul(self.w2,intermed))#(1,ch,ch_red) * (bs,ch_red,1) = (bs,ch,1)
        return s*x #bs,ch,seq * bs,ch,1 = bs,ch,seq
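
# Usage sketch (illustrative, shapes assumed): the block rescales each channel by a learned
# gate in (0, 1) and leaves the tensor shape unchanged.
#
#   se = SqueezeExcite1d(channels=128, reduction=16)
#   x = torch.randn(4, 128, 1000)
#   se(x).shape                         # -> torch.Size([4, 128, 1000])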

# Cell
def weight_init(m):
    '''call weight initialization for model n via n.apply(weight_init)'''
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight,1)
        nn.init.constant_(m.bias,0)
    if isinstance(m,SqueezeExcite1d):
        # w1/w2 carry a leading singleton dimension from unsqueeze(0) and .size() is a method,
        # so read the reduced-channel dimension via size(1)/size(2) rather than size[...]
        stdv1=math.sqrt(2./m.w1.size(1))
        nn.init.normal_(m.w1,0.,stdv1)
        stdv2=math.sqrt(1./m.w2.size(2))
        nn.init.normal_(m.w2,0.,stdv2)
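
# Usage sketch (illustrative): `weight_init` is meant to be applied recursively to every
# submodule via Module.apply, as the docstring indicates. The model choice and arguments
# below are assumptions.
#
#   model = fcn_wang(num_classes=5, input_channels=12)
#   model.apply(weight_init)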

# Cell
def create_head1d(nf, nc, lin_ftrs=None, ps=0.5, bn_final:bool=False, bn:bool=True, act="relu", concat_pooling=True):
    "Model head that takes `nf` features, runs through `lin_ftrs`, and outputs `nc` classes; added bn and act here"
    lin_ftrs = [2*nf if concat_pooling else nf, nc] if lin_ftrs is None else [2*nf if concat_pooling else nf] + lin_ftrs + [nc] #was [nf, 512,nc]
    ps = listify(ps)
    if len(ps)==1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    actns = [nn.ReLU(inplace=True) if act=="relu" else nn.ELU(inplace=True)] * (len(lin_ftrs)-2) + [None]
    layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), Flatten()]
    for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):
        layers += bn_drop_lin(ni,no,bn,p,actn)
    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    return nn.Sequential(*layers)
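
# Usage sketch (illustrative, values assumed): with the defaults and `concat_pooling=True`
# the head is AdaptiveConcatPool1d -> Flatten -> BatchNorm1d(2*nf) -> Dropout(0.5) -> Linear(2*nf, nc).
#
#   head = create_head1d(nf=128, nc=5)
#   head(torch.randn(4, 128, 250)).shape   # -> torch.Size([4, 5])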

# Cell
class basic_conv1d(nn.Sequential):
    '''basic conv1d'''
    def __init__(self, filters=[128,128,128,128],kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,split_first_layer=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
        layers = []
        if(isinstance(kernel_size,int)):
            kernel_size = [kernel_size]*len(filters)
        for i in range(len(filters)):
            layers_tmp = []

            layers_tmp.append(_conv1d(input_channels if i==0 else filters[i-1],filters[i],kernel_size=kernel_size[i],stride=(1 if (split_first_layer is True and i==0) else stride),dilation=dilation,act="none" if ((headless is True and i==len(filters)-1) or (split_first_layer is True and i==0)) else act, bn=False if (headless is True and i==len(filters)-1) else bn,drop_p=(0. if i==0 else drop_p)))
            if((split_first_layer is True and i==0)):
                layers_tmp.append(_conv1d(filters[0],filters[0],kernel_size=1,stride=1,act=act, bn=bn,drop_p=0.))
                #layers_tmp.append(nn.Linear(filters[0],filters[0],bias=not(bn)))
                #layers_tmp.append(_fc(filters[0],filters[0],act=act,bn=bn))
            if(pool>0 and i<len(filters)-1):
                layers_tmp.append(nn.MaxPool1d(pool,stride=pool_stride,padding=(pool-1)//2))
            if(squeeze_excite_reduction>0):
                layers_tmp.append(SqueezeExcite1d(filters[i],squeeze_excite_reduction))
            layers.append(nn.Sequential(*layers_tmp))

        #head
        #layers.append(nn.AdaptiveAvgPool1d(1))
        #layers.append(nn.Linear(filters[-1],num_classes))
        #head #inplace=True leads to a runtime error see ReLU+ dropout https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5
        self.headless = headless
        if(headless is True):
            head = nn.Sequential(nn.AdaptiveAvgPool1d(1),Flatten())
        else:
            head=create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
        layers.append(head)

        super().__init__(*layers)

    def get_layer_groups(self):
        return (self[2],self[-1])

    def get_output_layer(self):
        if self.headless is False:
            return self[-1][-1]
        else:
            return None

    def set_output_layer(self,x):
        if self.headless is False:
            self[-1][-1] = x

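
# Usage sketch (illustrative, values assumed): each entry in `filters` becomes one conv
# "stage"; the classification head from `create_head1d` is appended as the last stage.
#
#   model = basic_conv1d(filters=[64,128,256], kernel_size=5, stride=2,
#                        num_classes=5, input_channels=12)
#   y = model(torch.randn(8, 12, 1000))   # -> shape (8, 5)
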
# Cell
def fcn(filters=[128]*5,num_classes=2,input_channels=8,**kwargs):
    filters_in = filters + [num_classes]
    return basic_conv1d(filters=filters_in,kernel_size=3,stride=1,pool=2,pool_stride=2,input_channels=input_channels,act="relu",bn=True,headless=True)

def fcn_wang(num_classes=2,input_channels=8,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    return basic_conv1d(filters=[128,256,128],kernel_size=[8,5,3],stride=1,pool=0,pool_stride=2, num_classes=num_classes,input_channels=input_channels,act="relu",bn=True,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)

def schirrmeister(num_classes=2,input_channels=8,kernel_size=10,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    return basic_conv1d(filters=[25,50,100,200],kernel_size=kernel_size, stride=3, pool=3, pool_stride=1, num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, headless=False,split_first_layer=True,drop_p=0.5,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)

def sen(filters=[128]*5,num_classes=2,input_channels=8,kernel_size=3,squeeze_excite_reduction=16,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    return basic_conv1d(filters=filters,kernel_size=kernel_size,stride=2,pool=0,pool_stride=0,input_channels=input_channels,act="relu",bn=True,num_classes=num_classes,squeeze_excite_reduction=squeeze_excite_reduction,drop_p=drop_p,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)

def basic1d(filters=[128]*5,kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    return basic_conv1d(filters=filters,kernel_size=kernel_size, stride=stride, dilation=dilation, pool=pool, pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction, num_classes=num_classes, input_channels=input_channels, act=act, bn=bn, headless=headless,drop_p=drop_p,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
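
# Usage sketch (illustrative, values assumed): `fcn` builds a headless model whose last conv
# stage already has `num_classes` channels, averaged to one value per class, while the other
# factory functions delegate the classifier to `create_head1d`.
#
#   model = fcn(filters=[128]*5, num_classes=5, input_channels=12)
#   y = model(torch.randn(8, 12, 1000))   # -> shape (8, 5)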