|
a |
|
b/clinical_ts/cpc.py |
|
|
1 |
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/17_cpc.ipynb (unless otherwise specified). |
|
|
2 |
|
|
|
3 |
__all__ = ['CPCEncoder', 'CPCModel'] |
|
|
4 |
|
|
|
5 |
# Cell |
|
|
6 |
import torch |
|
|
7 |
import torch.nn.functional as F |
|
|
8 |
import torch.nn as nn |
|
|
9 |
from clinical_ts.basic_conv1d import _conv1d |
|
|
10 |
import numpy as np |
|
|
11 |
|
|
|
12 |
from .basic_conv1d import listify, bn_drop_lin |
|
|
13 |
|
|
|
14 |
# Cell |
|
|
15 |
class CPCEncoder(nn.Sequential):
    """Stack of strided 1d convolution blocks used as the CPC encoder.

    One `_conv1d` block is created per entry of `strides`/`kss`/`features`;
    block `i` maps `features[i-1]` channels (`input_channels` for the first
    block) to `features[i]` channels.

    Args:
        input_channels: number of channels of the raw input signal.
        strides: stride of each convolution block.
        kss: kernel size of each convolution block.
        features: output channels of each convolution block. `None` (default)
            means 512 channels for every block, i.e. `[512] * len(strides)`.
            The previous literal default had four entries while the default
            `strides`/`kss` have five, so building the encoder with all
            defaults always tripped the length check below.
        bn: forwarded unchanged to `_conv1d` (presumably toggles batch norm;
            confirm in `basic_conv1d`).

    Attributes:
        downsampling_factor: product of all strides.
        output_dim: number of channels produced by the last block.
    """

    def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=None, bn=False):
        if features is None:
            features = [512] * len(strides)
        assert len(strides) == len(kss) and len(strides) == len(features), \
            "strides, kss and features must all have the same length"

        # Input width of every block: the raw channels first, then the
        # output width of the preceding block.
        in_channels = [input_channels] + list(features[:-1])
        layers = [
            _conv1d(n_in, n_out, kernel_size=k, stride=s, bn=bn)
            for n_in, n_out, k, s in zip(in_channels, features, kss, strides)
        ]
        super().__init__(*layers)

        self.downsampling_factor = np.prod(strides)
        self.output_dim = features[-1]

    def encode(self, input):
        """Encode `input` and move the feature axis last.

        The convolution stack yields (bs, output_dim, steps), with steps being
        roughly seq // downsampling_factor according to the original author's
        note. The result is transposed to (bs, steps, output_dim), the layout
        expected by `batch_first` RNNs. The input is passed through as-is; no
        cropping to a multiple of the downsampling factor is applied.
        """
        return self.forward(input).transpose(1, 2)
|
|
34 |
|
|
|
35 |
# Cell |
|
|
36 |
class CPCModel(nn.Module):
    """Contrastive Predictive Coding model: conv encoder -> RNN -> head.

    The model runs in one of two modes, selected by `num_classes`:

    * pretraining (`num_classes is None`): the RNN output is projected back to
      the encoder output dimension (`self.proj`) and `forward` returns both
      the encoded input and the projected predictions, as consumed by
      `cpc_loss`.
    * classifier (`num_classes` given): the RNN output is fed to a
      classification head (`self.head`).

    Args:
        input_channels: channels of the raw input (bs, ch, seq).
        strides, kss, features, bn_encoder: forwarded to `CPCEncoder`.
            `features=None` (default) means 512 channels per encoder layer,
            i.e. `[512] * len(strides)`; the previous literal default had one
            entry fewer than the default `strides`/`kss` and could never pass
            the encoder's length check.
        n_hidden: hidden size of the RNN.
        n_layers: number of stacked RNN layers.
        mlp: pretraining only; use a two-layer projection (Linear-ReLU-Linear)
            instead of a single Linear layer.
        lstm: `True` selects `nn.LSTM`, anything else selects `nn.GRU`.
        bias_proj: whether the final projection layer has a bias.
        num_classes: number of outputs of the classification head, or `None`
            for pretraining.
        concat_pooling: classifier only; prepend `AdaptiveConcatPoolRNN` to
            the head (head input width 3*n_hidden). Otherwise only the last
            time step of the RNN output is used (width n_hidden).
        ps_head: dropout probability/probabilities handed to `bn_drop_lin`.
            A single value is used for the last layer and halved for all
            earlier head layers.
        lin_ftrs_head: hidden layer widths of the head, or `None` for a head
            without hidden layers.
        bn_head: forwarded to `bn_drop_lin`.
        skip_encoder: if not `False`, no encoder is built and the RNN consumes
            the raw input channels directly. Only allowed in classifier mode.
    """

    def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=None, bn_encoder=False, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False, num_classes=None, concat_pooling=True, ps_head=0.5, lin_ftrs_head=[512], bn_head=True, skip_encoder=False):
        super().__init__()
        # Pretraining predicts encoder outputs, so it needs the encoder.
        assert skip_encoder is False or num_classes is not None, \
            "skip_encoder requires num_classes (pretraining needs the encoder)"
        use_encoder = skip_encoder is False
        if features is None:
            features = [512] * len(strides)

        if use_encoder:
            self.encoder = CPCEncoder(input_channels, strides=strides, kss=kss, features=features, bn=bn_encoder)
            self.encoder_output_dim = self.encoder.output_dim
            self.encoder_downsampling_factor = self.encoder.downsampling_factor
        else:
            self.encoder = None
            self.encoder_output_dim = None
            self.encoder_downsampling_factor = None

        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.mlp = mlp

        self.num_classes = num_classes
        self.concat_pooling = concat_pooling

        # Both cell types take the same input width. The GRU branch used to
        # read self.encoder.output_dim unconditionally, which failed with
        # skip_encoder=True because self.encoder is None in that case.
        rnn_input_dim = self.encoder_output_dim if use_encoder else input_channels
        rnn_cls = nn.LSTM if lstm is True else nn.GRU
        self.rnn = rnn_cls(rnn_input_dim, n_hidden, num_layers=n_layers, batch_first=True)

        if num_classes is None:  # pretraining
            if mlp:  # additional hidden layer as in simclr
                self.proj = nn.Sequential(
                    nn.Linear(n_hidden, n_hidden),
                    nn.ReLU(inplace=True),
                    nn.Linear(n_hidden, self.encoder_output_dim, bias=bias_proj),
                )
            else:
                self.proj = nn.Linear(n_hidden, self.encoder_output_dim, bias=bias_proj)
        else:  # classifier (slightly adapted from RNN1d)
            layers_head = []
            if self.concat_pooling:
                layers_head.append(AdaptiveConcatPoolRNN())

            # Concat pooling stacks avg-pool, max-pool and last step.
            nf = 3 * n_hidden if concat_pooling else n_hidden
            if lin_ftrs_head is None:
                lin_ftrs_head = [nf, num_classes]
            else:
                lin_ftrs_head = [nf] + list(lin_ftrs_head) + [num_classes]
            ps_head = listify(ps_head)
            if len(ps_head) == 1:
                # Single value: full dropout before the output layer, half of
                # it before every earlier layer.
                ps_head = [ps_head[0] / 2] * (len(lin_ftrs_head) - 2) + ps_head
            # ReLU after every hidden layer, no activation after the output.
            actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs_head) - 2) + [None]

            for ni, no, p, actn in zip(lin_ftrs_head[:-1], lin_ftrs_head[1:], ps_head, actns):
                layers_head += bn_drop_lin(ni, no, bn_head, p, actn)
            self.head = nn.Sequential(*layers_head)

    def forward(self, input):
        """Run the model on `input` of shape (bs, ch, seq).

        Returns:
            Pretraining: tuple `(input_encoded, predictions)`, where
            `input_encoded` is the RNN input (bs, steps, features) and
            `predictions` is `self.proj` applied to the RNN output.
            Classifier: output of `self.head`.
        """
        if self.encoder is not None:
            input_encoded = self.encoder.encode(input)
        else:
            input_encoded = input.transpose(1, 2)  # bs, seq, channels
        output_rnn, _ = self.rnn(input_encoded)  # bs, seq, n_hidden

        if self.num_classes is None:  # pretraining
            return input_encoded, self.proj(output_rnn)

        # classifier: bs, n_hidden, seq (i.e. standard CNN channel ordering)
        output = output_rnn.transpose(1, 2)
        if self.concat_pooling is False:
            output = output[:, :, -1]  # last time step only
        return self.head(output)

    def get_layer_groups(self):
        """Return `(encoder, rnn, head)`.

        `self.head` is only created in classifier mode, and `self.encoder` is
        `None` when the model was built with `skip_encoder`.
        """
        return (self.encoder, self.rnn, self.head)

    def get_output_layer(self):
        """Return the last module of the classification head."""
        return self.head[-1]

    def set_output_layer(self, x):
        """Replace the last module of the classification head with `x`."""
        self.head[-1] = x

    def cpc_loss(self, input, target=None, steps_predicted=5, n_false_negatives=9, negatives_from_same_seq_only=False, eval_acc=False):
        """Contrastive loss for pretraining mode.

        For every step `i`, the prediction at step `i` is scored (dot product)
        against the true encoding at step `i + steps_predicted` (the positive)
        and against `n_false_negatives` randomly drawn encodings (negatives);
        a cross-entropy with the positive as target class 0 is accumulated.

        Args:
            input: raw input of shape (bs, ch, seq).
            target: unused; kept so the signature stays compatible with
                existing callers.
            steps_predicted: offset between a prediction and its positive.
            n_false_negatives: number of negatives per sample and step.
            negatives_from_same_seq_only: draw negatives only from the
                sample's own sequence instead of from the whole batch.
            eval_acc: additionally return the fraction of (sample, step)
                pairs where the positive obtained the highest score.

        Returns:
            The loss summed (not averaged) over all steps, or the tuple
            `(loss, accuracy)` if `eval_acc` is set.
        """
        assert self.num_classes is None, "cpc_loss is only defined in pretraining mode"

        # input_encoded: bs, seq, features; output: bs, seq, features
        input_encoded, output = self.forward(input)
        # Flattened view used to gather negatives by a single index.
        input_encoded_flat = input_encoded.reshape(-1, input_encoded.size(2))

        bs = input_encoded.size(0)
        seq = input_encoded.size(1)
        n_steps = seq - steps_predicted
        n_draws = bs * n_false_negatives
        device = input.device

        loss = torch.tensor(0, dtype=torch.float32).to(device)
        tp_cnt = torch.tensor(0, dtype=torch.int64).to(device)
        # The positive always sits at candidate index 0.
        targs = torch.zeros(bs, dtype=torch.int64).to(device)
        if negatives_from_same_seq_only:
            # Every sample draws all of its negatives from its own sequence.
            own_batch = torch.arange(0, bs).repeat_interleave(n_false_negatives).to(device)

        for i in range(n_steps):
            pos_step = i + steps_predicted
            positives = input_encoded[:, pos_step].unsqueeze(1)  # bs, 1, features

            if negatives_from_same_seq_only:
                idxs = torch.randint(0, seq - 1, (n_draws,)).to(device)
                idxs_batch = own_batch
            else:  # negatives from everywhere
                idxs = torch.randint(0, bs * (seq - 1), (n_draws,)).to(device)
                idxs_batch = idxs // (seq - 1)

            # Draws cover seq-1 slots; shifting every slot at or beyond the
            # positive step by one skips that step. For negatives from
            # everywhere this skips the step in the other sequences as well,
            # for simplicity.
            idxs_seq = torch.remainder(idxs, seq - 1)
            idxs_seq = idxs_seq + (idxs_seq >= pos_step).long()
            idxs_flat = idxs_batch * seq + idxs_seq

            negatives = input_encoded_flat[idxs_flat].view(bs, n_false_negatives, -1)
            candidates = torch.cat([positives, negatives], dim=1)  # bs, false_neg+1, features
            preds = torch.sum(output[:, i].unsqueeze(1) * candidates, dim=-1)  # bs, false_neg+1

            if eval_acc:
                tp_cnt += torch.sum(torch.argmax(preds, dim=-1) == targs)

            loss += F.cross_entropy(preds, targs)

        if eval_acc:
            return loss, tp_cnt.float() / bs / n_steps
        return loss
|
|
141 |
|
|
|
142 |
#copied from RNN1d |
|
|
143 |
class AdaptiveConcatPoolRNN(nn.Module):
    """Concatenate average pooling, max pooling and the final RNN state.

    Expects RNN outputs in channel-first layout (bs, ch, ts) and returns a
    tensor of shape (bs, 3*ch).

    Args:
        bidirectional: set to `True` if `x` stems from a bidirectional RNN,
            i.e. the first half of the channels is the forward direction and
            the second half the backward direction.
    """

    def __init__(self, bidirectional=False):
        super().__init__()
        self.bidirectional = bidirectional

    def forward(self, x):
        # x: bs, ch, ts
        t1 = F.adaptive_avg_pool1d(x, 1)
        t2 = F.adaptive_max_pool1d(x, 1)

        if self.bidirectional is False:
            t3 = x[:, :, -1]
        else:
            # The forward direction finishes at the last time step, the
            # backward direction at the first one. The split point has to be
            # half the channel count: splitting at the full count left the
            # backward slice empty and ignored the backward state entirely.
            half = x.size(1) // 2
            t3 = torch.cat([x[:, :half, -1], x[:, half:, 0]], 1)

        # output shape bs, 3*ch
        return torch.cat([t1.squeeze(-1), t2.squeeze(-1), t3], 1)
|
|
159 |
|
|
|
160 |
#class CPCClassifier(nn.Module): |
|
|
161 |
# def __init__(self, cpcmodel, num_classes, concat_pooling=True, ps_head=0.5,lin_ftrs_head=None,bn_head=True): |
|
|
162 |
# super().__init__() |
|
|
163 |
# self.cpcmodel = cpcmodel |
|
|
164 |
# self.concat_pooling = concat_pooling |
|
|
165 |
# |
|
|
166 |
# #slightly adapted from RNN1d |
|
|
167 |
# layers_head =[] |
|
|
168 |
# if(self.concat_pooling): |
|
|
169 |
# layers_head.append(AdaptiveConcatPoolRNN()) |
|
|
170 |
# |
|
|
171 |
# #classifier |
|
|
172 |
# nf = 3*self.cpcmodel.encoder_output_dim if concat_pooling else self.cpcmodel.encoder_output_dim |
|
|
173 |
# lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes] |
|
|
174 |
# ps_head = listify(ps_head) |
|
|
175 |
# if len(ps_head)==1: |
|
|
176 |
# ps_head = [ps_head[0]/2] * (len(lin_ftrs_head)-2) + ps_head |
|
|
177 |
# actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs_head)-2) + [None] |
|
|
178 |
# |
|
|
179 |
# for ni,no,p,actn in zip(lin_ftrs_head[:-1],lin_ftrs_head[1:],ps_head,actns): |
|
|
180 |
# layers_head+=bn_drop_lin(ni,no,bn_head,p,actn) |
|
|
181 |
# self.head=nn.Sequential(*layers_head) |
|
|
182 |
# |
|
|
183 |
# def forward(self,input): |
|
|
184 |
# output = self.cpcmodel(input) |
|
|
185 |
# if(self.concat_pooling is False): |
|
|
186 |
# output = output[:,:,-1] |
|
|
187 |
# return self.head(output) |
|
|
188 |
# |
|
|
189 |
# def get_layer_groups(self): |
|
|
190 |
# return (self.cpcmodel.encoder,self.cpcmodel.rnn,self.head) |
|
|
191 |
# |
|
|
192 |
# def get_output_layer(self): |
|
|
193 |
# return self.head[-1] |
|
|
194 |
# |
|
|
195 |
# def set_output_layer(self,x): |
|
|
196 |
# self.head[-1] = x |