Diff of /clinical_ts/cpc.py [000000] .. [134fd7]

# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/17_cpc.ipynb (unless otherwise specified).

__all__ = ['CPCEncoder', 'CPCModel']

# Cell
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

from .basic_conv1d import _conv1d, listify, bn_drop_lin

# Cell
class CPCEncoder(nn.Sequential):
    'CPC Encoder'
    def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=[512,512,512,512,512], bn=False):
        # note: the default features list originally had four entries, which trips the
        # assert below for the five default strides/kss; extended to five entries
        assert(len(strides)==len(kss) and len(strides)==len(features))
        lst = []
        for i,(s,k,f) in enumerate(zip(strides,kss,features)):
            lst.append(_conv1d(input_channels if i==0 else features[i-1],f,kernel_size=k,stride=s,bn=bn))
        super().__init__(*lst)
        self.downsampling_factor = np.prod(strides)
        self.output_dim = features[-1]
        # output shape: bs, output_dim, seq//downsampling_factor
    def encode(self, input):
        # alternatively, truncate the input to a multiple of the downsampling factor first:
        # input = input[:,:,:(input.size(2)//self.downsampling_factor)*self.downsampling_factor]
        input_encoded = self.forward(input).transpose(1,2) # bs, seq//downsampling_factor, output_dim (standard ordering for batch_first RNNs)
        return input_encoded

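# Editor's usage sketch (not part of the original notebook): a quick shape check
# for CPCEncoder on dummy data. With the default strides [5,4,2,2,2] the
# downsampling factor is 5*4*2*2*2 = 160, so a length-4000 input should yield
# 4000//160 = 25 encoded steps -- assuming _conv1d pads so that each stride
# divides the length exactly, as the output-shape comment above indicates.
def _demo_cpc_encoder_shapes():
    encoder = CPCEncoder(input_channels=12)
    x = torch.randn(8, 12, 4000)   # bs, ch, seq
    z = encoder.encode(x)          # bs, seq//downsampling_factor, output_dim
    assert z.shape == (8, 25, 512)
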
# Cell
class CPCModel(nn.Module):
    "CPC model"
    def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=[512,512,512,512,512], bn_encoder=False, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False, num_classes=None, concat_pooling=True, ps_head=0.5, lin_ftrs_head=[512], bn_head=True, skip_encoder=False):
        super().__init__()
        assert(skip_encoder is False or num_classes is not None) # pretraining is only possible with an encoder
        self.encoder = CPCEncoder(input_channels,strides=strides,kss=kss,features=features,bn=bn_encoder) if skip_encoder is False else None
        self.encoder_output_dim = self.encoder.output_dim if skip_encoder is False else None
        self.encoder_downsampling_factor = self.encoder.downsampling_factor if skip_encoder is False else None
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.mlp = mlp

        self.num_classes = num_classes
        self.concat_pooling = concat_pooling

        # the original one-liner instantiated nn.GRU with self.encoder.output_dim,
        # which fails when skip_encoder is True; use the precomputed input dim for both
        rnn_input_dim = self.encoder_output_dim if skip_encoder is False else input_channels
        rnn_arch = nn.LSTM if lstm is True else nn.GRU
        self.rnn = rnn_arch(rnn_input_dim, n_hidden, num_layers=n_layers, batch_first=True)

        if(num_classes is None): # pretraining
            if(mlp): # additional hidden layer as in SimCLR
                self.proj = nn.Sequential(nn.Linear(n_hidden, n_hidden), nn.ReLU(inplace=True), nn.Linear(n_hidden, self.encoder_output_dim, bias=bias_proj))
            else:
                self.proj = nn.Linear(n_hidden, self.encoder_output_dim, bias=bias_proj)
        else: # classifier
            # slightly adapted from RNN1d
            layers_head = []
            if(self.concat_pooling):
                layers_head.append(AdaptiveConcatPoolRNN())

            # classifier
            nf = 3*n_hidden if concat_pooling else n_hidden
            lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
            ps_head = listify(ps_head)
            if len(ps_head)==1:
                ps_head = [ps_head[0]/2] * (len(lin_ftrs_head)-2) + ps_head
            actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs_head)-2) + [None]

            for ni,no,p,actn in zip(lin_ftrs_head[:-1],lin_ftrs_head[1:],ps_head,actns):
                layers_head += bn_drop_lin(ni,no,bn_head,p,actn)
            self.head = nn.Sequential(*layers_head)

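    # Worked example (editor's note): with the defaults n_hidden=512,
    # concat_pooling=True and lin_ftrs_head=[512], the pooled feature size is
    # nf = 3*512 = 1536 (avg pool + max pool + last step), so the head stacks
    # linear layers 1536 -> 512 -> num_classes with dropouts [0.25, 0.5].
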
    def forward(self, input):
        # input shape: bs, ch, seq
        if(self.encoder is not None):
            input_encoded = self.encoder.encode(input)
        else:
            input_encoded = input.transpose(1,2) # bs, seq, channels
        output_rnn, _ = self.rnn(input_encoded) # output_rnn: bs, seq, n_hidden
        if(self.num_classes is None): # pretraining
            return input_encoded, self.proj(output_rnn)
        else: # classifier
            output = output_rnn.transpose(1,2) # bs, n_hidden, seq (i.e. standard CNN channel ordering)
            if(self.concat_pooling is False):
                output = output[:,:,-1] # keep only the last hidden state
            return self.head(output)

    def get_layer_groups(self):
        return (self.encoder, self.rnn, self.head)

    def get_output_layer(self):
        return self.head[-1]

    def set_output_layer(self, x):
        self.head[-1] = x

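    # Editor's note on the loss below: for each position i, the RNN output c_i is
    # scored against the true encoder output z_{i+steps_predicted} plus
    # n_false_negatives randomly drawn distractors via dot products. With the
    # positive always placed at index 0, cross_entropy implements the InfoNCE
    # objective -log( exp(c_i.z_pos) / sum_j exp(c_i.z_j) ).
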
    def cpc_loss(self, input, target=None, steps_predicted=5, n_false_negatives=9, negatives_from_same_seq_only=False, eval_acc=False):
        assert(self.num_classes is None)

        input_encoded, output = self.forward(input) # input_encoded: bs, seq, features; output: bs, seq, features
        input_encoded_flat = input_encoded.reshape(-1, input_encoded.size(2)) # for negatives below: -1, features

        bs = input_encoded.size()[0]
        seq = input_encoded.size()[1]

        loss = torch.tensor(0, dtype=torch.float32).to(input.device)
        tp_cnt = torch.tensor(0, dtype=torch.int64).to(input.device)

        for i in range(input_encoded.size()[1]-steps_predicted):
            positives = input_encoded[:,i+steps_predicted].unsqueeze(1) # bs, 1, encoder_output_dim
            if(negatives_from_same_seq_only):
                idxs = torch.randint(0, seq-1, (bs*n_false_negatives,)).to(input.device)
            else: # negatives from everywhere
                idxs = torch.randint(0, bs*(seq-1), (bs*n_false_negatives,)).to(input.device)
            idxs_seq = torch.remainder(idxs, seq-1) # bs*false_neg
            idxs_seq2 = idxs_seq * (idxs_seq<(i+steps_predicted)).long() + (idxs_seq+1)*(idxs_seq>=(i+steps_predicted)).long() # bs*false_neg; skip the positive position
            if(negatives_from_same_seq_only):
                idxs_batch = torch.arange(0, bs).repeat_interleave(n_false_negatives).to(input.device)
            else:
                idxs_batch = idxs//(seq-1)
            idxs2_flat = idxs_batch*seq + idxs_seq2 # for negatives from everywhere: this skips step i+steps_predicted in the other sequences as well, for simplicity

            negatives = input_encoded_flat[idxs2_flat].view(bs, n_false_negatives, -1) # bs, n_false_negatives, encoder_output_dim
            candidates = torch.cat([positives, negatives], dim=1) # bs, n_false_negatives+1, encoder_output_dim
            preds = torch.sum(output[:,i].unsqueeze(1)*candidates, dim=-1) # bs, n_false_negatives+1
            targs = torch.zeros(bs, dtype=torch.int64).to(input.device) # the positive is always at index 0

            if(eval_acc):
                preds_argmax = torch.argmax(preds, dim=-1)
                tp_cnt += torch.sum(preds_argmax == targs)

            loss += F.cross_entropy(preds, targs)
        if(eval_acc):
            return loss, tp_cnt.float()/bs/(input_encoded.size()[1]-steps_predicted)
        else:
            return loss

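# Editor's usage sketch (not part of the original notebook): a pretraining forward
# pass and InfoNCE loss on dummy data. All sizes are arbitrary; with the default
# encoder a length-4000 input is encoded into 25 steps (cf. the encoder sketch
# above), of which 25-5 = 20 contribute one cross-entropy term each to the loss.
def _demo_cpc_pretraining_loss():
    model = CPCModel(input_channels=12)   # num_classes=None -> pretraining mode
    x = torch.randn(4, 12, 4000)          # bs, ch, seq
    input_encoded, preds = model(x)       # each: bs, 25, 512
    loss, acc = model.cpc_loss(x, steps_predicted=5, n_false_negatives=9, eval_acc=True)
    print(loss.item(), acc.item())
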
# copied from RNN1d
class AdaptiveConcatPoolRNN(nn.Module):
    def __init__(self, bidirectional=False):
        super().__init__()
        self.bidirectional = bidirectional
    def forward(self, x):
        # input shape: bs, ch, ts
        t1 = nn.AdaptiveAvgPool1d(1)(x)
        t2 = nn.AdaptiveMaxPool1d(1)(x)

        if(self.bidirectional is False):
            t3 = x[:,:,-1]
        else:
            # for a bidirectional RNN the channel dim holds both directions: take the
            # last step of the forward half and the first step of the backward half
            # (the original used channels = x.size()[1], which left the second slice empty)
            channels = x.size()[1]//2
            t3 = torch.cat([x[:,:channels,-1], x[:,channels:,0]], 1)
        out = torch.cat([t1.squeeze(-1), t2.squeeze(-1), t3], 1) # output shape: bs, 3*ch
        return out

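# Editor's usage sketch (not part of the original notebook): the pooling layer
# triples the channel dimension by concatenating average pooling, max pooling
# and the last time step. Shapes are arbitrary dummy values.
def _demo_adaptive_concat_pool():
    pool = AdaptiveConcatPoolRNN()
    x = torch.randn(8, 512, 25)    # bs, ch, ts
    y = pool(x)
    assert y.shape == (8, 3*512)
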
    
#class CPCClassifier(nn.Module): # legacy draft, superseded by the classifier head in CPCModel above
#    def __init__(self, cpcmodel, num_classes, concat_pooling=True, ps_head=0.5, lin_ftrs_head=None, bn_head=True):
#        super().__init__()
#        self.cpcmodel = cpcmodel
#        self.concat_pooling = concat_pooling
#
#        # slightly adapted from RNN1d
#        layers_head = []
#        if(self.concat_pooling):
#            layers_head.append(AdaptiveConcatPoolRNN())
#
#        # classifier
#        nf = 3*self.cpcmodel.encoder_output_dim if concat_pooling else self.cpcmodel.encoder_output_dim
#        lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
#        ps_head = listify(ps_head)
#        if len(ps_head)==1:
#            ps_head = [ps_head[0]/2] * (len(lin_ftrs_head)-2) + ps_head
#        actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs_head)-2) + [None]
#
#        for ni,no,p,actn in zip(lin_ftrs_head[:-1],lin_ftrs_head[1:],ps_head,actns):
#            layers_head += bn_drop_lin(ni,no,bn_head,p,actn)
#        self.head = nn.Sequential(*layers_head)
#
#    def forward(self, input):
#        output = self.cpcmodel(input)
#        if(self.concat_pooling is False):
#            output = output[:,:,-1]
#        return self.head(output)
#
#    def get_layer_groups(self):
#        return (self.cpcmodel.encoder, self.cpcmodel.rnn, self.head)
#
#    def get_output_layer(self):
#        return self.head[-1]
#
#    def set_output_layer(self, x):
#        self.head[-1] = x
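
# Editor's usage sketch (not part of the original notebook): fine-tuning as a
# classifier. Passing num_classes switches CPCModel into classification mode;
# get_output_layer/set_output_layer allow swapping the final linear layer, e.g.
# to transfer a trained head to a hypothetical task with a different label set.
def _demo_cpc_classifier():
    model = CPCModel(input_channels=12, num_classes=5)
    x = torch.randn(4, 12, 4000)    # bs, ch, seq
    logits = model(x)               # bs, num_classes
    assert logits.shape == (4, 5)
    # swap the output layer for a hypothetical 10-class target task
    old = model.get_output_layer()
    model.set_output_layer(nn.Linear(old.in_features, 10))
    assert model(x).shape == (4, 10)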