--- a/stay_admission/operations.py
+++ b/stay_admission/operations.py
@@ -0,0 +1,308 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+import sklearn.metrics
+
+# Candidate operation registries: each entry maps an operation name to a
+# constructor that takes the model width d_model.
+OPS1 = {
+    'identity': lambda d_model: Identity(d_model),
+    'ffn': lambda d_model: FFN(d_model),
+    'interaction_1': lambda d_model: Attention_s1(d_model),
+    'interaction_2': lambda d_model: Attention_s2(d_model),
+}
+
+OPS2 = {
+    'identity': lambda d_model: Identity(d_model),
+    'conv': lambda d_model: Conv(d_model),
+    'attention': lambda d_model: SelfAttention(d_model),
+    'rnn': lambda d_model: RNN(d_model),
+    'ffn': lambda d_model: FFN(d_model),
+    'interaction_1': lambda d_model: CatFC(d_model),
+    'interaction_2': lambda d_model: Attention_x(d_model)
+}
+
+# Edge operations: keep the input unchanged or zero it out.
+OPS3 = {
+    'identity': lambda d_model: Identity(d_model),
+    'zero': lambda d_model: Zero(d_model),
+}
+
+# Aggregation operations that merge a list of equally shaped tensors.
+OPS4 = {
+    'sum': lambda d_model: Sum(d_model),
+    'mul': lambda d_model: Mul(d_model),
+}
+
+
+# NOTE: `masks` and `lengths` are part of the shared operation interface;
+# several operations accept them but do not use them.
+class Zero(nn.Module):
+    def __init__(self, d_model):
+        super(Zero, self).__init__()
+
+    def forward(self, x, masks, lengths):
+        return torch.mul(x, 0)
+
+
+class Sum(nn.Module):
+    def __init__(self, d_model):
+        super(Sum, self).__init__()
+
+    def forward(self, all_x):
+        # Accumulate out of place so the caller's first tensor is not mutated.
+        out = all_x[0]
+        for x in all_x[1:]:
+            out = out + x
+        return out
+
+
+class Mul(nn.Module):
+    def __init__(self, d_model):
+        super(Mul, self).__init__()
+
+    def forward(self, all_x):
+        out = all_x[0]
+        for x in all_x[1:]:
+            out = out * x
+        return out
+
+
+class CatFC(nn.Module):
+    def __init__(self, d_model):
+        super(CatFC, self).__init__()
+        self.ffn = nn.Sequential(nn.Linear(2 * d_model, 4 * d_model), nn.ReLU(),
+                                 nn.Linear(4 * d_model, d_model))
+        self.layer_norm = nn.LayerNorm(d_model)
+
+    def forward(self, current_x, s, other_x):
+        # Broadcast s across the sequence dimension, concatenate it to every
+        # position of current_x, and project back to d_model (other_x is unused).
+        s_ = s.unsqueeze(1).expand_as(current_x)
+        x = torch.cat((current_x, s_), dim=-1)
+        return self.layer_norm(self.ffn(x))
+
+
+class Conv(nn.Module):
+    def __init__(self, d_model):
+        super(Conv, self).__init__()
+        self.op = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv1d(d_model, d_model, 3, padding=1),
+            nn.BatchNorm1d(d_model, affine=True)
+        )
+        # self.batchnm = nn.BatchNorm1d(d_model, affine=True)
+        # self.conv = nn.Conv1d(d_model, d_model, 3, padding=1)
+
+    def forward(self, x, masks, lengths):
+        # Conv1d expects (batch, channels, seq_len), so permute in and out.
+        x = self.op(x.permute(0, 2, 1))
+        return x.permute(0, 2, 1)
+
+
+class FFN(nn.Module):
+    def __init__(self, d_model):
+        super(FFN, self).__init__()
+        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(),
+                                 nn.Linear(4 * d_model, d_model))
+        self.layer_norm = nn.LayerNorm(d_model)
+
+    def forward(self, x, masks, lengths):
+        # Position-wise feed-forward block with a residual connection.
+        x = self.layer_norm(x + self.ffn(x))
+        return x
+
+
+class Identity(nn.Module):
+    def __init__(self, d_model):
+        super(Identity, self).__init__()
+
+    def forward(self, x, masks, lengths):
+        return x
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, in_feature, num_head=4, dropout=0.1):
+        super(SelfAttention, self).__init__()
+        self.in_feature = in_feature
+        self.num_head = num_head
+        self.size_per_head = in_feature // num_head
+        self.out_dim = num_head * self.size_per_head
+        assert self.size_per_head * num_head == in_feature
+        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.fc = nn.Linear(in_feature, in_feature, bias=False)
+        self.dropout = nn.Dropout(dropout)
+        self.layer_norm = nn.LayerNorm(in_feature)
+
+    def forward(self, x, attn_mask, lengths):
+        batch_size = x.size(0)
+        res = x
+        query = self.q_linear(x)
+        key = self.k_linear(x)
+        value = self.v_linear(x)
+
+        # Reshape the projections to (batch, num_head, -1, size_per_head).
+        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
+        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
+        value = value.view(batch_size, self.num_head, -1, self.size_per_head)
+
+        # Scaled dot-product attention.
+        scale = np.sqrt(self.size_per_head)
+        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale
+
+        attention = torch.softmax(energy, dim=-1)
+        x = torch.matmul(attention, value)
+        x = x.permute(0, 2, 1, 3).contiguous()
+        x = x.view(batch_size, -1, self.in_feature)
+        x = self.fc(x)
+        x = self.dropout(x)
+        x += res
+        x = self.layer_norm(x)
+        return x
+
+
+class Attention_s1(nn.Module):
+    def __init__(self, in_feature, num_head=4, dropout=0.1):
+        super(Attention_s1, self).__init__()
+        self.in_feature = in_feature
+        self.num_head = num_head
+        self.size_per_head = in_feature // num_head
+        self.out_dim = num_head * self.size_per_head
+        assert self.size_per_head * num_head == in_feature
+        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.fc = nn.Linear(in_feature, in_feature, bias=False)
+        self.dropout = nn.Dropout(dropout)
+        self.layer_norm = nn.LayerNorm(in_feature)
+
+    def forward(self, s, x1, x2):
+        # s attends over the first sequence x1; x2 is accepted for interface
+        # uniformity but unused here.
+        batch_size = x1.size(0)
+        s = s.unsqueeze(1)
+        res = s
+        query = self.q_linear(s)
+        key = self.k_linear(x1)
+        value = self.v_linear(x1)
+
+        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
+        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
+        value = value.view(batch_size, self.num_head, -1, self.size_per_head)
+
+        scale = np.sqrt(self.size_per_head)
+        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale
+
+        attention = torch.softmax(energy, dim=-1)
+        x = torch.matmul(attention, value)
+        x = x.permute(0, 2, 1, 3).contiguous()
+        x = x.view(batch_size, -1, self.in_feature)
+        x = self.fc(x)
+        x = self.dropout(x)
+        x += res
+        x = self.layer_norm(x)
+        # Drop the singleton query dimension but keep the batch dimension.
+        return x.squeeze(1)
+
+
+class Attention_s2(nn.Module):
+    def __init__(self, in_feature, num_head=4, dropout=0.1):
+        super(Attention_s2, self).__init__()
+        self.in_feature = in_feature
+        self.num_head = num_head
+        self.size_per_head = in_feature // num_head
+        self.out_dim = num_head * self.size_per_head
+        assert self.size_per_head * num_head == in_feature
+        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.fc = nn.Linear(in_feature, in_feature, bias=False)
+        self.dropout = nn.Dropout(dropout)
+        self.layer_norm = nn.LayerNorm(in_feature)
+
+    def forward(self, s, x1, x2):
+        # s attends over the second sequence x2; x1 is accepted for interface
+        # uniformity but unused here.
+        batch_size = x2.size(0)
+        s = s.unsqueeze(1)
+        res = s
+        query = self.q_linear(s)
+        key = self.k_linear(x2)
+        value = self.v_linear(x2)
+
+        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
+        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
+        value = value.view(batch_size, self.num_head, -1, self.size_per_head)
+
+        scale = np.sqrt(self.size_per_head)
+        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale
+
+        attention = torch.softmax(energy, dim=-1)
+        x = torch.matmul(attention, value)
+        x = x.permute(0, 2, 1, 3).contiguous()
+        x = x.view(batch_size, -1, self.in_feature)
+        x = self.fc(x)
+        x = self.dropout(x)
+        x += res
+        x = self.layer_norm(x)
+        # Drop the singleton query dimension but keep the batch dimension.
+        return x.squeeze(1)
+
+
+class Attention_x(nn.Module):
+    def __init__(self, in_feature, num_head=4, dropout=0.1):
+        super(Attention_x, self).__init__()
+        self.in_feature = in_feature
+        self.num_head = num_head
+        self.size_per_head = in_feature // num_head
+        self.out_dim = num_head * self.size_per_head
+        assert self.size_per_head * num_head == in_feature
+        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
+        self.fc = nn.Linear(in_feature, in_feature, bias=False)
+        self.dropout = nn.Dropout(dropout)
+        self.layer_norm = nn.LayerNorm(in_feature)
+
+    def forward(self, current_x, s, other_x):
+        # Cross-attention: current_x attends over other_x; s is accepted for
+        # interface uniformity but unused here.
+        batch_size = current_x.size(0)
+        res = current_x
+        query = self.q_linear(current_x)
+        key = self.k_linear(other_x)
+        value = self.v_linear(other_x)
+
+        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
+        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
+        value = value.view(batch_size, self.num_head, -1, self.size_per_head)
+
+        scale = np.sqrt(self.size_per_head)
+        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale
+
+        attention = torch.softmax(energy, dim=-1)
+        x = torch.matmul(attention, value)
+        x = x.permute(0, 2, 1, 3).contiguous()
+        x = x.view(batch_size, -1, self.in_feature)
+        x = self.fc(x)
+        x = self.dropout(x)
+        x += res
+        x = self.layer_norm(x)
+        return x
+
+
+class RNN(nn.Module):
+    def __init__(self, d_model):
+        super(RNN, self).__init__()
+        self.rnn = nn.GRU(d_model, d_model, num_layers=1, batch_first=True)
+
+    def forward(self, x, masks, lengths):
+        rnn_input = x
+        rnn_output, _ = self.rnn(rnn_input)
+        return rnn_output
+
+
+class MaxPoolLayer(nn.Module):
+    """
+    A layer that performs max pooling along the sequence dimension
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs, mask_or_lengths=None):
+        """
+        inputs: tensor of shape (batch_size, seq_len, hidden_size)
+        mask_or_lengths: tensor of shape (batch_size) or (batch_size, seq_len)
+
+        returns: tensor of shape (batch_size, hidden_size)
+        """
+        bs, sl, _ = inputs.size()
+        if mask_or_lengths is not None:
+            if len(mask_or_lengths.size()) == 1:
+                # Lengths were given: build a mask that is True at padded positions.
+                mask = (torch.arange(sl, device=inputs.device).unsqueeze(0).expand(bs, sl)
+                        >= mask_or_lengths.unsqueeze(1))
+            else:
+                mask = mask_or_lengths
+            # Padded positions are filled with -inf so they never win the max.
+            inputs = inputs.masked_fill(mask.unsqueeze(-1).expand_as(inputs), float('-inf'))
+        max_pooled = inputs.max(1)[0]
+        return max_pooled
+
+
+def prroc(testy, probs):
+    # Area under the precision-recall curve.
+    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(testy, probs)
+    return sklearn.metrics.auc(recall, precision)
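
For reference, a minimal usage sketch of the registries added above. It is not part of the patch and assumes the stay_admission package is importable and that d_model is divisible by the default 4 attention heads; the tensor sizes are arbitrary.

import torch
from stay_admission.operations import OPS2, OPS4

d_model, batch_size, seq_len = 64, 8, 10
x = torch.randn(batch_size, seq_len, d_model)

# Sequence-level ops share the (x, masks, lengths) signature; this op ignores
# masks/lengths, so None placeholders are enough for a quick smoke test.
attn = OPS2['attention'](d_model)
out = attn(x, None, None)        # -> (batch_size, seq_len, d_model)

# Aggregation ops take a list of equally shaped tensors.
merge = OPS4['sum'](d_model)
merged = merge([out, x])         # element-wise sum, same shape
print(out.shape, merged.shape)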