[134fd7]: / clinical_ts / basic_conv1d.py

Download this file

196 lines (166 with data), 10.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/11_basic_conv1d.ipynb (unless otherwise specified).
# Public API for `from basic_conv1d import *`. Every entry must name something
# actually defined in this module: a star-import raises AttributeError for any
# listed name that does not exist. (`cd_adaptiveconcatpool` and
# `attrib_adaptiveconcatpool` used to be listed but are not defined here.)
__all__ = ['AdaptiveConcatPool1d', 'SqueezeExcite1d', 'weight_init', 'create_head1d',
           'basic_conv1d', 'fcn', 'fcn_wang', 'schirrmeister', 'sen', 'basic1d']
# Cell
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
#from fastai.layers import *
#from fastai.core import *
from typing import Iterable
class Flatten(nn.Module):
    "Flatten `x` to shape `(x.size(0), -1)`, or to a rank-1 tensor when `full` is True; often used at the end of a model."

    def __init__(self, full: bool = False):
        super().__init__()
        # full=True  -> collapse every dimension into one rank-1 tensor
        # full=False -> keep dim 0 and collapse all remaining dimensions
        self.full = full

    def forward(self, x):
        # `reshape` instead of `view`: `view` raises on non-contiguous inputs
        # (e.g. transposed/permuted tensors), while `reshape` copies only when
        # it has to and otherwise returns a view exactly like `view` did.
        return x.reshape(-1) if self.full else x.reshape(x.size(0), -1)
def listify(p=None, q=None):
    """Make `p` listy and the same length as `q`.

    - `p` is None -> `[]`; a `str` or a non-iterable value -> `[p]`;
      an iterable without a length (e.g. a rank-0 tensor) -> `[p]`;
      any other iterable -> its items.
    - Target length: `q` itself if it is an int, `len(p)` if `q` is None,
      otherwise `len(q)`.
    - A single-element `p` is repeated up to the target length.

    Returns a new list. Raises AssertionError if the final length does not
    match the target length.
    """
    if p is None:
        p = []
    elif isinstance(p, str) or not isinstance(p, Iterable):
        p = [p]
    else:
        # Rank-0 tensors in PyTorch are Iterable but len() raises TypeError on
        # them; treat any such length-less iterable as a single item.
        try:
            len(p)
        except TypeError:
            p = [p]
    # Convert to a real list *before* broadcasting: for array-likes (tensors,
    # ndarrays) `p * n` is element-wise multiplication, not repetition.
    p = list(p)
    n = q if type(q) is int else len(p) if q is None else len(q)
    if len(p) == 1:
        p = p * n
    assert len(p) == n, f'List len mismatch ({len(p)} vs {n})'
    return p
def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
    "Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`."
    # Candidate stages in their fixed order; `None` marks a stage that is
    # switched off and gets filtered out below.
    stages = [
        nn.BatchNorm1d(n_in) if bn else None,
        nn.Dropout(p) if p != 0 else None,
        nn.Linear(n_in, n_out),
        actn,
    ]
    return [stage for stage in stages if stage is not None]
# Cell
def _conv1d(in_planes,out_planes,kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0):
lst=[]
if(drop_p>0):
lst.append(nn.Dropout(drop_p))
lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, dilation=dilation, bias=not(bn)))
if(bn):
lst.append(nn.BatchNorm1d(out_planes))
if(act=="relu"):
lst.append(nn.ReLU(True))
if(act=="elu"):
lst.append(nn.ELU(True))
if(act=="prelu"):
lst.append(nn.PReLU(True))
return nn.Sequential(*lst)
def _fc(in_planes,out_planes, act="relu", bn=True):
lst = [nn.Linear(in_planes, out_planes, bias=not(bn))]
if(bn):
lst.append(nn.BatchNorm1d(out_planes))
if(act=="relu"):
lst.append(nn.ReLU(True))
if(act=="elu"):
lst.append(nn.ELU(True))
if(act=="prelu"):
lst.append(nn.PReLU(True))
return nn.Sequential(*lst)
class AdaptiveConcatPool1d(nn.Module):
    """Concatenate adaptive max pooling and adaptive average pooling along dim 1.

    The max-pooled tensor comes first and the average-pooled tensor second, so
    dim 1 doubles in size and the last dimension becomes `sz`.
    """

    def __init__(self, sz=None):
        "`sz`: output length of both poolings; a falsy value (None, 0) means 1."
        super().__init__()
        out_size = sz if sz else 1
        self.ap = nn.AdaptiveAvgPool1d(out_size)
        self.mp = nn.AdaptiveMaxPool1d(out_size)

    def forward(self, x):
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, 1)

    def attrib(self, relevant, irrelevant):
        # NOTE(review): `attrib_adaptiveconcatpool` is not defined in this module,
        # so this raises NameError unless other code provides that name at module
        # level. Confirm where it is supposed to come from.
        return attrib_adaptiveconcatpool(self, relevant, irrelevant)
# Cell
class SqueezeExcite1d(nn.Module):
    '''squeeze excite block as used for example in LSTM FCN

    Rescales every channel of the input by a gate in (0, 1) computed from the
    per-channel mean over the last axis:
        gate = sigmoid(w2 @ relu(w1 @ mean(x, dim=2)))
    `channels // reduction` is the width of the bottleneck.
    '''
    def __init__(self, channels, reduction=16):
        super().__init__()
        channels_reduced = channels // reduction
        # Leading singleton dim lets torch.matmul broadcast the weights over the batch.
        self.w1 = torch.nn.Parameter(torch.randn(channels_reduced, channels).unsqueeze(0))  # (1, ch_red, ch)
        self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))  # (1, ch, ch_red)

    def forward(self, x):
        # input is (bs, ch, seq)
        z = torch.mean(x, dim=2, keepdim=True)              # (bs, ch, 1)
        intermed = F.relu(torch.matmul(self.w1, z))          # (1, ch_red, ch) @ (bs, ch, 1) -> (bs, ch_red, 1)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        s = torch.sigmoid(torch.matmul(self.w2, intermed))   # (1, ch, ch_red) @ (bs, ch_red, 1) -> (bs, ch, 1)
        return s * x                                         # (bs, ch, 1) * (bs, ch, seq) -> (bs, ch, seq)
# Cell
def weight_init(m):
    '''Initialize the parameters of one module; apply to a model via `model.apply(weight_init)`.

    - nn.Conv1d / nn.Linear: `kaiming_normal_` weights (PyTorch defaults), zero bias if present.
    - nn.BatchNorm1d: weight 1, bias 0 (when the layer has affine parameters).
    - SqueezeExcite1d: normal weights, std sqrt(2 / ch_red) for w1 and
      sqrt(1 / ch_red) for w2.
    Any other module is left untouched.
    '''
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm1d):
        # affine=False BatchNorm has weight/bias set to None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    if isinstance(m, SqueezeExcite1d):
        # `size` is a method: the former `m.w1.size[0]` raised TypeError.
        # w1 is stored as (1, ch_red, ch) and w2 as (1, ch, ch_red) (leading
        # broadcast dim), so the 2-D row index of w1 / column index of w2 are the
        # last two axes; index from the end so those 2-D indices still apply.
        # max(..., 1) avoids ZeroDivisionError when channels < reduction (ch_red == 0).
        stdv1 = math.sqrt(2. / max(m.w1.size(-2), 1))
        nn.init.normal_(m.w1, 0., stdv1)
        stdv2 = math.sqrt(1. / max(m.w2.size(-1), 1))
        nn.init.normal_(m.w2, 0., stdv2)
# Cell
def create_head1d(nf, nc, lin_ftrs=None, ps=0.5, bn_final: bool = False, bn: bool = True, act="relu", concat_pooling=True):
    """Model head: global pooling + flatten, then BN/dropout/linear stages ending in `nc` outputs.

    Args:
        nf: number of channels produced by the model body.
        nc: number of outputs (classes).
        lin_ftrs: optional sizes of hidden linear layers between the pooled features and `nc`.
        ps: dropout probability, or one per linear stage. A single value is used
            for the last stage and halved for every earlier stage.
        bn_final: append BatchNorm1d(nc, momentum=0.01) after the last linear layer.
        bn: put a BatchNorm1d in front of every linear stage.
        act: "relu" -> ReLU between hidden stages; any other value -> ELU.
        concat_pooling: True -> AdaptiveConcatPool1d (max + avg, 2*nf features);
            False -> global max pooling (nf features).

    Returns:
        nn.Sequential implementing the head.
    """
    n_pooled = 2 * nf if concat_pooling else nf
    # list(...) so lin_ftrs may be any sequence (e.g. a tuple), not only a list.
    lin_ftrs = [n_pooled, nc] if lin_ftrs is None else [n_pooled] + list(lin_ftrs) + [nc]
    ps = listify(ps)
    if len(ps) == 1:
        ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps
    # One stateless activation instance is shared by all hidden stages; the last stage has none.
    hidden_act = nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)
    actns = [hidden_act] * (len(lin_ftrs) - 2) + [None]
    # The first stage expects exactly `n_pooled` features, so pooling must reduce
    # the sequence axis to length 1. The former nn.MaxPool1d(2) only did that for
    # sequences of length 2-3 and otherwise produced nf*(L//2) features, crashing
    # in the first BatchNorm1d/Linear.
    pool = AdaptiveConcatPool1d() if concat_pooling else nn.AdaptiveMaxPool1d(1)
    layers = [pool, Flatten()]
    for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
        layers += bn_drop_lin(ni, no, bn, p, actn)
    if bn_final:
        layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    return nn.Sequential(*layers)
# Cell
class basic_conv1d(nn.Sequential):
    '''Configurable stack of 1-D conv blocks followed by a head.

    Each element of `filters` yields one nn.Sequential block containing:
      - `_conv1d` with that many output channels;
      - with `split_first_layer`, the first block's conv is stride 1 with no
        activation, followed by an extra kernel-size-1 `_conv1d`;
      - an optional MaxPool1d (`pool > 0`, not after the last block);
      - an optional SqueezeExcite1d (`squeeze_excite_reduction > 0`).
    With `headless`, the last conv has no BN/activation and the head is global
    average pooling + Flatten. Otherwise the head comes from `create_head1d`.
    '''
    def __init__(self, filters=[128,128,128,128],kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,split_first_layer=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
        n_blocks = len(filters)
        kernel_sizes = [kernel_size] * n_blocks if isinstance(kernel_size, int) else kernel_size
        split_first = split_first_layer is True
        strip_last = headless is True

        body = []
        for idx, n_out in enumerate(filters):
            is_first = idx == 0
            is_last = idx == n_blocks - 1
            plain_first = split_first and is_first   # first conv of a split first layer
            bare_last = strip_last and is_last       # last conv of a headless model
            block = [_conv1d(
                input_channels if is_first else filters[idx - 1],
                n_out,
                kernel_size=kernel_sizes[idx],
                stride=1 if plain_first else stride,
                dilation=dilation,
                act="none" if (bare_last or plain_first) else act,
                bn=False if bare_last else bn,
                drop_p=0. if is_first else drop_p,
            )]
            if plain_first:
                block.append(_conv1d(filters[0], filters[0], kernel_size=1, stride=1, act=act, bn=bn, drop_p=0.))
            if pool > 0 and not is_last:
                block.append(nn.MaxPool1d(pool, stride=pool_stride, padding=(pool - 1) // 2))
            if squeeze_excite_reduction > 0:
                block.append(SqueezeExcite1d(n_out, squeeze_excite_reduction))
            body.append(nn.Sequential(*block))

        # NOTE: inplace activations in the head can clash with dropout, see
        # https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5
        if strip_last:
            head = nn.Sequential(nn.AdaptiveAvgPool1d(1), Flatten())
        else:
            head = create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head,
                                 bn_final=bn_final_head, bn=bn_head, act=act_head,
                                 concat_pooling=concat_pooling)
        super().__init__(*body, head)
        self.headless = headless

    def get_layer_groups(self):
        return (self[2], self[-1])

    def get_output_layer(self):
        # Last module of the head, or None for a headless model.
        return self[-1][-1] if self.headless is False else None

    def set_output_layer(self, x):
        # Replaces the head's last module; does nothing for a headless model.
        if self.headless is False:
            self[-1][-1] = x
# Cell
def fcn(filters=[128]*5, num_classes=2, input_channels=8, **kwargs):
    """Fully convolutional model: stride-1 conv blocks with max pooling and no dense head.

    A final conv layer with `num_classes` output channels is appended to `filters`.
    As the headless last layer it gets no BatchNorm and no activation. It is
    followed by global average pooling + Flatten, giving one output per class.
    Any extra **kwargs are accepted and ignored.
    """
    # list(...) accepts any sequence (e.g. a tuple) and never mutates the caller's list.
    filters_in = list(filters) + [num_classes]
    return basic_conv1d(filters=filters_in, kernel_size=3, stride=1, pool=2, pool_stride=2,
                        input_channels=input_channels, act="relu", bn=True, headless=True)
def fcn_wang(num_classes=2,input_channels=8,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    """Three stride-1 conv blocks (128/256/128 channels, kernels 8/5/3, no pooling)
    followed by the standard `create_head1d` head. Extra **kwargs are ignored."""
    head_cfg = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                    bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return basic_conv1d(filters=[128, 256, 128], kernel_size=[8, 5, 3], stride=1,
                        pool=0, pool_stride=2,
                        num_classes=num_classes, input_channels=input_channels,
                        act="relu", bn=True, **head_cfg)
def schirrmeister(num_classes=2,input_channels=8,kernel_size=10,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    """Four conv blocks (25/50/100/200 channels, stride 3, max pool 3) with a split
    first layer and dropout 0.5, followed by the `create_head1d` head.
    Extra **kwargs are ignored."""
    body_cfg = dict(filters=[25, 50, 100, 200], kernel_size=kernel_size, stride=3,
                    pool=3, pool_stride=1, act="relu", bn=True, headless=False,
                    split_first_layer=True, drop_p=0.5)
    head_cfg = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                    bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return basic_conv1d(num_classes=num_classes, input_channels=input_channels,
                        **body_cfg, **head_cfg)
def sen(filters=[128]*5,num_classes=2,input_channels=8,kernel_size=3,squeeze_excite_reduction=16,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    """Stride-2 conv blocks, each followed by a SqueezeExcite1d gate, then the
    `create_head1d` head. Extra **kwargs are ignored."""
    head_cfg = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                    bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return basic_conv1d(filters=filters, kernel_size=kernel_size, stride=2,
                        pool=0, pool_stride=0,
                        input_channels=input_channels, act="relu", bn=True,
                        num_classes=num_classes,
                        squeeze_excite_reduction=squeeze_excite_reduction,
                        drop_p=drop_p, **head_cfg)
def basic1d(filters=[128]*5,kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    """Thin factory exposing most `basic_conv1d` options directly.

    `split_first_layer` is not exposed (so it keeps its default), and any extra
    **kwargs are ignored.
    """
    body_cfg = dict(filters=filters, kernel_size=kernel_size, stride=stride, dilation=dilation,
                    pool=pool, pool_stride=pool_stride,
                    squeeze_excite_reduction=squeeze_excite_reduction,
                    act=act, bn=bn, headless=headless, drop_p=drop_p)
    head_cfg = dict(lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head,
                    bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
    return basic_conv1d(num_classes=num_classes, input_channels=input_channels,
                        **body_cfg, **head_cfg)