
clinical_ts/basic_conv1d.py
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/11_basic_conv1d.ipynb (unless otherwise specified).

__all__ = ['cd_adaptiveconcatpool', 'attrib_adaptiveconcatpool', 'AdaptiveConcatPool1d', 'SqueezeExcite1d',
           'weight_init', 'create_head1d', 'basic_conv1d', 'fcn', 'fcn_wang', 'schirrmeister', 'sen', 'basic1d']

# Cell
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

#from fastai.layers import *
#from fastai.core import *
from typing import Iterable

class Flatten(nn.Module):
    "Flatten `x` to a single dimension, often used at the end of a model. `full` for rank-1 tensor"
    def __init__(self, full:bool=False):
        super().__init__()
        self.full = full
    def forward(self, x): return x.view(-1) if self.full else x.view(x.size(0), -1)


def listify(p=None, q=None):
    "Make `p` listy and the same length as `q`."
    if p is None: p=[]
    elif isinstance(p, str):          p = [p]
    elif not isinstance(p, Iterable): p = [p]
    #Rank 0 tensors in PyTorch are Iterable but don't have a length.
    else:
        try: len(p)
        except TypeError: p = [p]
    n = q if type(q)==int else len(p) if q is None else len(q)
    if len(p)==1: p = p * n
    assert len(p)==n, f'List len mismatch ({len(p)} vs {n})'
    return list(p)

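# Usage sketch for listify (illustrative values, not part of the library API):
#   listify(0.5)            -> [0.5]
#   listify(0.5, 3)         -> [0.5, 0.5, 0.5]
#   listify([8, 5, 3], 3)   -> [8, 5, 3]
# create_head1d below relies on this to broadcast a single dropout value across head layers.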
def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
    "Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`."
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0: layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    if actn is not None: layers.append(actn)
    return layers

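# Usage sketch (hypothetical sizes): one head block with batchnorm, dropout and ReLU.
#   bn_drop_lin(256, 128, bn=True, p=0.25, actn=nn.ReLU(inplace=True))
#   -> [BatchNorm1d(256), Dropout(p=0.25), Linear(256, 128), ReLU(inplace=True)]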
# Cell
def _conv1d(in_planes,out_planes,kernel_size=3, stride=1, dilation=1, act="relu", bn=True, drop_p=0):
    lst=[]
    if(drop_p>0):
        lst.append(nn.Dropout(drop_p))
    lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, dilation=dilation, bias=not(bn)))
    if(bn):
        lst.append(nn.BatchNorm1d(out_planes))
    if(act=="relu"):
        lst.append(nn.ReLU(inplace=True))
    if(act=="elu"):
        lst.append(nn.ELU(inplace=True))
    if(act=="prelu"):
        lst.append(nn.PReLU())
    return nn.Sequential(*lst)

def _fc(in_planes,out_planes, act="relu", bn=True):
    lst = [nn.Linear(in_planes, out_planes, bias=not(bn))]
    if(bn):
        lst.append(nn.BatchNorm1d(out_planes))
    if(act=="relu"):
        lst.append(nn.ReLU(inplace=True))
    if(act=="elu"):
        lst.append(nn.ELU(inplace=True))
    if(act=="prelu"):
        lst.append(nn.PReLU())
    return nn.Sequential(*lst)

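# Shape sketch for the _conv1d helper (hypothetical tensor sizes):
#   block = _conv1d(12, 64, kernel_size=5, stride=2)
#   x = torch.randn(4, 12, 1000)   # (batch, channels, sequence)
#   block(x).shape                 # torch.Size([4, 64, 500])
# _fc is the analogous linear/batchnorm/activation helper for feature vectors.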
class AdaptiveConcatPool1d(nn.Module):
    "Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`."
    def __init__(self, sz=None):
        "Output will be 2*sz or 2 if sz is None"
        super().__init__()
        sz = sz or 1
        self.ap,self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
    def attrib(self,relevant,irrelevant):
        return attrib_adaptiveconcatpool(self,relevant,irrelevant)


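# Shape sketch (illustrative sizes): channel dimension doubles, sequence collapses to sz.
#   pool = AdaptiveConcatPool1d()
#   x = torch.randn(4, 128, 250)
#   pool(x).shape                  # torch.Size([4, 256, 1])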
# Cell
class SqueezeExcite1d(nn.Module):
    '''squeeze excite block as used for example in LSTM FCN'''
    def __init__(self,channels,reduction=16):
        super().__init__()
        channels_reduced = channels//reduction
        self.w1 = torch.nn.Parameter(torch.randn(channels_reduced,channels).unsqueeze(0))
        self.w2 = torch.nn.Parameter(torch.randn(channels, channels_reduced).unsqueeze(0))

    def forward(self, x):
        #input is bs,ch,seq
        z=torch.mean(x,dim=2,keepdim=True)#bs,ch,1
        intermed = F.relu(torch.matmul(self.w1,z))#(1,ch_red,ch) x (bs,ch,1) = (bs,ch_red,1)
        s=torch.sigmoid(torch.matmul(self.w2,intermed))#(1,ch,ch_red) x (bs,ch_red,1) = (bs,ch,1)
        return s*x #(bs,ch,seq) * (bs,ch,1) = (bs,ch,seq)

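# Shape sketch (illustrative sizes): channel-wise recalibration, output shape unchanged.
#   se = SqueezeExcite1d(channels=128, reduction=16)
#   x = torch.randn(4, 128, 250)
#   se(x).shape                    # torch.Size([4, 128, 250])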
# Cell
def weight_init(m):
    '''call weight initialization for model n via n.apply(weight_init)'''
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight,1)
        nn.init.constant_(m.bias,0)
    if isinstance(m,SqueezeExcite1d):
        stdv1=math.sqrt(2./m.w1.size(-1))#fan-in of w1 (channels)
        nn.init.normal_(m.w1,0.,stdv1)
        stdv2=math.sqrt(1./m.w2.size(-1))#fan-in of w2 (channels_reduced)
        nn.init.normal_(m.w2,0.,stdv2)

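# Usage sketch: apply recursively to every submodule of a model.
#   model = fcn_wang(num_classes=5, input_channels=12)
#   model.apply(weight_init)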
# Cell
def create_head1d(nf, nc, lin_ftrs=None, ps=0.5, bn_final:bool=False, bn:bool=True, act="relu", concat_pooling=True):
    "Model head that takes `nf` features, runs through `lin_ftrs`, and ends with `nc` classes; bn and act added here"
    lin_ftrs = [2*nf if concat_pooling else nf, nc] if lin_ftrs is None else [2*nf if concat_pooling else nf] + lin_ftrs + [nc] #was [nf, 512,nc]
    ps = listify(ps)
    if len(ps)==1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    actns = [nn.ReLU(inplace=True) if act=="relu" else nn.ELU(inplace=True)] * (len(lin_ftrs)-2) + [None]
    layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), Flatten()]
    for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):
        layers += bn_drop_lin(ni,no,bn,p,actn)
    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    return nn.Sequential(*layers)

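# Usage sketch (hypothetical sizes): pooling head on top of a 128-channel feature map.
#   head = create_head1d(nf=128, nc=5, lin_ftrs=[256], ps=0.5)
#   feats = torch.randn(4, 128, 30)    # (batch, channels, sequence)
#   head(feats).shape                  # torch.Size([4, 5])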
# Cell
class basic_conv1d(nn.Sequential):
    '''basic conv1d'''
    def __init__(self, filters=[128,128,128,128],kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,split_first_layer=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True):
        layers = []
        if(isinstance(kernel_size,int)):
            kernel_size = [kernel_size]*len(filters)
        for i in range(len(filters)):
            layers_tmp = []

            layers_tmp.append(_conv1d(input_channels if i==0 else filters[i-1],filters[i],kernel_size=kernel_size[i],stride=(1 if (split_first_layer is True and i==0) else stride),dilation=dilation,act="none" if ((headless is True and i==len(filters)-1) or (split_first_layer is True and i==0)) else act, bn=False if (headless is True and i==len(filters)-1) else bn,drop_p=(0. if i==0 else drop_p)))
            if((split_first_layer is True and i==0)):
                layers_tmp.append(_conv1d(filters[0],filters[0],kernel_size=1,stride=1,act=act, bn=bn,drop_p=0.))
                #layers_tmp.append(nn.Linear(filters[0],filters[0],bias=not(bn)))
                #layers_tmp.append(_fc(filters[0],filters[0],act=act,bn=bn))
            if(pool>0 and i<len(filters)-1):
                layers_tmp.append(nn.MaxPool1d(pool,stride=pool_stride,padding=(pool-1)//2))
            if(squeeze_excite_reduction>0):
                layers_tmp.append(SqueezeExcite1d(filters[i],squeeze_excite_reduction))
            layers.append(nn.Sequential(*layers_tmp))

        #head
        #layers.append(nn.AdaptiveAvgPool1d(1))
        #layers.append(nn.Linear(filters[-1],num_classes))
        #head #inplace=True leads to a runtime error see ReLU+ dropout https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5
        self.headless = headless
        if(headless is True):
            head = nn.Sequential(nn.AdaptiveAvgPool1d(1),Flatten())
        else:
            head=create_head1d(filters[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling)
        layers.append(head)

        super().__init__(*layers)

    def get_layer_groups(self):
        return (self[2],self[-1])

    def get_output_layer(self):
        if self.headless is False:
            return self[-1][-1]
        else:
            return None

    def set_output_layer(self,x):
        if self.headless is False:
            self[-1][-1] = x


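# Usage sketch (hypothetical configuration): four strided conv blocks plus pooling head.
#   model = basic_conv1d(filters=[64,128,256,512], kernel_size=3, stride=2,
#                        num_classes=5, input_channels=12)
#   x = torch.randn(4, 12, 1000)
#   model(x).shape                 # torch.Size([4, 5])
#   model.get_output_layer()       # final nn.Linear of the head, e.g. to swap for a new task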
# Cell
def fcn(filters=[128]*5,num_classes=2,input_channels=8,**kwargs):
    filters_in = filters + [num_classes]
    return basic_conv1d(filters=filters_in,kernel_size=3,stride=1,pool=2,pool_stride=2,input_channels=input_channels,act="relu",bn=True,headless=True)

def fcn_wang(num_classes=2,input_channels=8,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    return basic_conv1d(filters=[128,256,128],kernel_size=[8,5,3],stride=1,pool=0,pool_stride=2, num_classes=num_classes,input_channels=input_channels,act="relu",bn=True,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)

def schirrmeister(num_classes=2,input_channels=8,kernel_size=10,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    return basic_conv1d(filters=[25,50,100,200],kernel_size=kernel_size, stride=3, pool=3, pool_stride=1, num_classes=num_classes, input_channels=input_channels, act="relu", bn=True, headless=False,split_first_layer=True,drop_p=0.5,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)

def sen(filters=[128]*5,num_classes=2,input_channels=8,kernel_size=3,squeeze_excite_reduction=16,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    return basic_conv1d(filters=filters,kernel_size=kernel_size,stride=2,pool=0,pool_stride=0,input_channels=input_channels,act="relu",bn=True,num_classes=num_classes,squeeze_excite_reduction=squeeze_excite_reduction,drop_p=drop_p,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)

def basic1d(filters=[128]*5,kernel_size=3, stride=2, dilation=1, pool=0, pool_stride=1, squeeze_excite_reduction=0, num_classes=2, input_channels=8, act="relu", bn=True, headless=False,drop_p=0.,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True, **kwargs):
    return basic_conv1d(filters=filters,kernel_size=kernel_size, stride=stride, dilation=dilation, pool=pool, pool_stride=pool_stride, squeeze_excite_reduction=squeeze_excite_reduction, num_classes=num_classes, input_channels=input_channels, act=act, bn=bn, headless=headless,drop_p=drop_p,lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_final_head=bn_final_head, bn_head=bn_head, act_head=act_head, concat_pooling=concat_pooling)
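
# Usage sketch for the factory functions (hypothetical 12-channel input of length 1000):
#   model = fcn_wang(num_classes=5, input_channels=12)
#   model.apply(weight_init)
#   x = torch.randn(4, 12, 1000)   # (batch, channels, sequence)
#   model(x).shape                 # torch.Size([4, 5])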