--- /dev/null
+++ b/clinical_ts/cpc.py
@@ -0,0 +1,196 @@
+# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/17_cpc.ipynb (unless otherwise specified).
+
+__all__ = ['CPCEncoder', 'CPCModel']
+
+# Cell
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+import numpy as np
+
+from .basic_conv1d import _conv1d, listify, bn_drop_lin
+
+# Cell
+class CPCEncoder(nn.Sequential):
+    'CPC Encoder'
+    def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=[512,512,512,512,512], bn=False):
+        assert(len(strides)==len(kss) and len(strides)==len(features))
+        lst = []
+        for i,(s,k,f) in enumerate(zip(strides,kss,features)):
+            lst.append(_conv1d(input_channels if i==0 else features[i-1],f,kernel_size=k,stride=s,bn=bn))
+        super().__init__(*lst)
+        self.downsampling_factor = np.prod(strides)
+        self.output_dim = features[-1]
+
+    def encode(self, input):
+        # input: bs, ch, seq; output: bs, seq//downsampling_factor, output_dim
+        # (batch_first ordering, ready to be fed into an RNN)
+        input_encoded = self.forward(input).transpose(1,2)
+        return input_encoded
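+# Usage sketch (added for illustration, not part of the original notebook export).
+# It demonstrates the encoder's downsampling with the default strides
+# (5*4*2*2*2 = 160-fold); batch size, channel count and sequence length are
+# arbitrary example values, and the exact output length depends on the padding
+# used inside _conv1d.
+def _demo_cpc_encoder():
+    enc = CPCEncoder(input_channels=12)   # e.g. a 12-lead ECG
+    x = torch.randn(8, 12, 1600)          # bs, channels, seq
+    z = enc.encode(x)                     # bs, ~seq//160, 512
+    print(z.shape)                        # expected: torch.Size([8, 10, 512])
+    return z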
+
+# Cell
+class CPCModel(nn.Module):
+    "CPC model"
+    def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=[512,512,512,512,512], bn_encoder=False, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False, num_classes=None, concat_pooling=True, ps_head=0.5, lin_ftrs_head=[512], bn_head=True, skip_encoder=False):
+        super().__init__()
+        assert(skip_encoder is False or num_classes is not None) # pretraining only with encoder
+        self.encoder = CPCEncoder(input_channels,strides=strides,kss=kss,features=features,bn=bn_encoder) if skip_encoder is False else None
+        self.encoder_output_dim = self.encoder.output_dim if skip_encoder is False else None
+        self.encoder_downsampling_factor = self.encoder.downsampling_factor if skip_encoder is False else None
+        self.n_hidden = n_hidden
+        self.n_layers = n_layers
+        self.mlp = mlp
+
+        self.num_classes = num_classes
+        self.concat_pooling = concat_pooling
+
+        rnn_input_dim = self.encoder_output_dim if skip_encoder is False else input_channels
+        self.rnn = nn.LSTM(rnn_input_dim,n_hidden,num_layers=n_layers,batch_first=True) if lstm is True else nn.GRU(rnn_input_dim,n_hidden,num_layers=n_layers,batch_first=True)
+
+        if(num_classes is None): # pretraining
+            if(mlp): # additional hidden layer as in SimCLR
+                self.proj = nn.Sequential(nn.Linear(n_hidden, n_hidden),nn.ReLU(inplace=True),nn.Linear(n_hidden, self.encoder_output_dim,bias=bias_proj))
+            else:
+                self.proj = nn.Linear(n_hidden, self.encoder_output_dim,bias=bias_proj)
+        else: # classifier head, slightly adapted from RNN1d
+            layers_head = []
+            if(self.concat_pooling):
+                layers_head.append(AdaptiveConcatPoolRNN())
+
+            # classifier
+            nf = 3*n_hidden if concat_pooling else n_hidden
+            lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
+            ps_head = listify(ps_head)
+            if len(ps_head)==1:
+                ps_head = [ps_head[0]/2] * (len(lin_ftrs_head)-2) + ps_head
+            actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs_head)-2) + [None]
+
+            for ni,no,p,actn in zip(lin_ftrs_head[:-1],lin_ftrs_head[1:],ps_head,actns):
+                layers_head += bn_drop_lin(ni,no,bn_head,p,actn)
+            self.head = nn.Sequential(*layers_head)
+
+    def forward(self, input):
+        # input shape: bs, ch, seq
+        if(self.encoder is not None):
+            input_encoded = self.encoder.encode(input)
+        else:
+            input_encoded = input.transpose(1,2) # bs, seq, channels
+        output_rnn, _ = self.rnn(input_encoded) # output_rnn: bs, seq, n_hidden
+        if(self.num_classes is None): # pretraining
+            return input_encoded, self.proj(output_rnn)
+        else: # classifier
+            output = output_rnn.transpose(1,2) # bs, n_hidden, seq (i.e. standard CNN channel ordering)
+            if(self.concat_pooling is False):
+                output = output[:,:,-1]
+            return self.head(output)
+
+    def get_layer_groups(self):
+        return (self.encoder,self.rnn,self.head)
+
+    def get_output_layer(self):
+        return self.head[-1]
+
+    def set_output_layer(self,x):
+        self.head[-1] = x
+
+    def cpc_loss(self, input, target=None, steps_predicted=5, n_false_negatives=9, negatives_from_same_seq_only=False, eval_acc=False):
+        assert(self.num_classes is None)
+
+        input_encoded, output = self.forward(input) # input_encoded: bs, seq, features; output: bs, seq, features
+        input_encoded_flat = input_encoded.reshape(-1,input_encoded.size(2)) # for negatives below: bs*seq, features
+
+        bs = input_encoded.size()[0]
+        seq = input_encoded.size()[1]
+
+        loss = torch.tensor(0,dtype=torch.float32).to(input.device)
+        tp_cnt = torch.tensor(0,dtype=torch.int64).to(input.device)
+
+        for i in range(input_encoded.size()[1]-steps_predicted):
+            positives = input_encoded[:,i+steps_predicted].unsqueeze(1) # bs, 1, encoder_output_dim
+            if(negatives_from_same_seq_only):
+                idxs = torch.randint(0,(seq-1),(bs*n_false_negatives,)).to(input.device)
+            else: # negatives from everywhere
+                idxs = torch.randint(0,bs*(seq-1),(bs*n_false_negatives,)).to(input.device)
+            idxs_seq = torch.remainder(idxs,seq-1) # bs*false_neg
+            idxs_seq2 = idxs_seq * (idxs_seq<(i+steps_predicted)).long() + (idxs_seq+1)*(idxs_seq>=(i+steps_predicted)).long() # bs*false_neg, skips the positive at i+steps_predicted
+            if(negatives_from_same_seq_only):
+                idxs_batch = torch.arange(0,bs).repeat_interleave(n_false_negatives).to(input.device)
+            else:
+                idxs_batch = idxs//(seq-1)
+            idxs2_flat = idxs_batch*seq+idxs_seq2 # for negatives from everywhere: this skips step i+steps_predicted from the other sequences as well, for simplicity
+
+            negatives = input_encoded_flat[idxs2_flat].view(bs,n_false_negatives,-1) # bs, false_neg, encoder_output_dim
+            candidates = torch.cat([positives,negatives],dim=1) # bs, false_neg+1, encoder_output_dim
+            preds = torch.sum(output[:,i].unsqueeze(1)*candidates,dim=-1) # bs, false_neg+1
+            targs = torch.zeros(bs, dtype=torch.int64).to(input.device) # the positive is always at index 0
+
+            if(eval_acc):
+                preds_argmax = torch.argmax(preds,dim=-1)
+                tp_cnt += torch.sum(preds_argmax == targs)
+
+            loss += F.cross_entropy(preds,targs)
+        if(eval_acc):
+            return loss, tp_cnt.float()/bs/(input_encoded.size()[1]-steps_predicted)
+        else:
+            return loss
+
+# copied from RNN1d
+class AdaptiveConcatPoolRNN(nn.Module):
+    def __init__(self, bidirectional=False):
+        super().__init__()
+        self.bidirectional = bidirectional
+    def forward(self,x):
+        # input shape bs, ch, ts
+        t1 = F.adaptive_avg_pool1d(x,1)
+        t2 = F.adaptive_max_pool1d(x,1)
+
+        if(self.bidirectional is False):
+            t3 = x[:,:,-1]
+        else: # first half of the channels is the forward direction, second half the backward direction
+            channels = x.size()[1]
+            t3 = torch.cat([x[:,:channels//2,-1],x[:,channels//2:,0]],1)
+        out = torch.cat([t1.squeeze(-1),t2.squeeze(-1),t3],1) # output shape bs, 3*ch
+        return out
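+
+# Pretraining sketch (added for illustration, not part of the original notebook
+# export): builds a CPC model without a classification head and evaluates the
+# InfoNCE-style cpc_loss on a random batch. Batch size, channel count and
+# sequence length are arbitrary example values.
+def _demo_cpc_pretraining():
+    model = CPCModel(input_channels=12)   # num_classes=None -> pretraining mode
+    x = torch.randn(8, 12, 1600)          # bs, channels, seq
+    loss, acc = model.cpc_loss(x, steps_predicted=5, n_false_negatives=9, eval_acc=True)
+    print(loss.item(), acc.item())        # summed per-step losses, mean top-1 accuracy
+    return loss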
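+
+# Classifier sketch (added for illustration, not part of the original notebook
+# export): the same architecture with num_classes set, as used for downstream
+# fine-tuning; forward then returns logits instead of (encoding, projection).
+# The number of classes and the input shape are arbitrary example values.
+def _demo_cpc_classifier():
+    model = CPCModel(input_channels=12, num_classes=5, concat_pooling=True, lin_ftrs_head=[512])
+    x = torch.randn(8, 12, 1600)          # bs, channels, seq
+    logits = model(x)                     # bs, num_classes
+    print(logits.shape)                   # expected: torch.Size([8, 5])
+    return logits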