--- a/patient/models/adacare.py
+++ b/patient/models/adacare.py
@@ -0,0 +1,207 @@
+import torch
+from torch import nn
+import torch.nn.utils.rnn as rnn_utils
+from torch.utils import data
+from torch.autograd import Variable
+#from baseline import *
+from models.baseline import *
+import torch.nn.functional as F
+
+class Sparsemax(nn.Module):
+    """Sparsemax function."""
+
+    def __init__(self, dim=None):
+        super(Sparsemax, self).__init__()
+
+        self.dim = -1 if dim is None else dim
+
+    def forward(self, input, device='cuda'):
+        original_size = input.size()
+        input = input.view(-1, input.size(self.dim))
+
+        dim = 1
+        number_of_logits = input.size(dim)
+
+        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
+
+        zs = torch.sort(input=input, dim=dim, descending=True)[0]
+        range = torch.arange(start=1, end=number_of_logits + 1, device=device, dtype=torch.float32).view(1, -1)
+        range = range.expand_as(zs)
+
+        bound = 1 + range * zs
+        cumulative_sum_zs = torch.cumsum(zs, dim)
+        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
+        k = torch.max(is_gt * range, dim, keepdim=True)[0]
+
+        zs_sparse = is_gt * zs
+        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
+        taus = taus.expand_as(input)
+
+        self.output = torch.max(torch.zeros_like(input), input - taus)
+
+        output = self.output.view(original_size)
+
+        return output
+
+    def backward(self, grad_output):
+        dim = 1
+
+        nonzeros = torch.ne(self.output, 0)
+        sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
+        self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))
+
+        return self.grad_input
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        self.__padding = (kernel_size - 1) * dilation
+
+        super(CausalConv1d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=self.__padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+
+    def forward(self, input):
+        result = super(CausalConv1d, self).forward(input)
+        if self.__padding != 0:
+            return result[:, :, :-self.__padding]
+        return result
+
+
+class Recalibration(nn.Module):
+    def __init__(self, channel, reduction=9, use_h=True, use_c=True, activation='sigmoid'):
+        super(Recalibration, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool1d(1)
+        self.use_h = use_h
+        self.use_c = use_c
+        scale_dim = 0
+        self.activation = activation
+
+        self.nn_c = nn.Linear(channel, channel // reduction)
+        scale_dim += channel // reduction
+
+        self.nn_rescale = nn.Linear(scale_dim, channel)
+        self.sparsemax = Sparsemax(dim=1)
+
+    def forward(self, x, device='cuda'):
+        b, c, t = x.size()
+
+        y_origin = x[:, :, -1]
+        se_c = self.nn_c(y_origin)
+        se_c = torch.relu(se_c)
+        y = se_c
+
+        y = self.nn_rescale(y).view(b, c, 1)
+        if self.activation == 'sigmoid':
+            y = torch.sigmoid(y)
+        else:
+            y = self.sparsemax(y, device)
+        return x * y.expand_as(x), y
+
+
+class AdaCare(nn.Module):
+    def __init__(self, vocab_size, hidden_dim=128, kernel_size=2, kernel_num=64, input_dim=128, output_dim=1, dropout=0.5, r_v=4,
+                 r_c=4, activation='sigmoid', pretrain=False):
+        super(AdaCare, self).__init__()
+        # self.embedding = nn.Embedding(vocab_size + 1, hidden_dim, padding_idx=-1)
+        # self.embedding = nn.Sequential(nn.Linear(vocab_size + 1, hidden_dim), nn.ReLU())
+        self.embbedding1 = nn.Sequential(nn.Linear(vocab_size, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU())
+        self.hidden_dim = hidden_dim
+        self.kernel_size = kernel_size
+        self.kernel_num = kernel_num
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.dropout = dropout
+
+        self.nn_conv1 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 1)
+        self.nn_conv3 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 3)
+        self.nn_conv5 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 5)
+        torch.nn.init.xavier_uniform_(self.nn_conv1.weight)
+        torch.nn.init.xavier_uniform_(self.nn_conv3.weight)
+        torch.nn.init.xavier_uniform_(self.nn_conv5.weight)
+
+        self.nn_convse = Recalibration(3 * kernel_num, r_c, use_h=False, use_c=True, activation='sigmoid')
+        self.nn_inputse = Recalibration(input_dim, r_v, use_h=False, use_c=True, activation=activation)
+        self.rnn = nn.GRUCell(input_dim + 3 * kernel_num, hidden_dim)
+        self.nn_output = nn.Linear(hidden_dim, 2)
+        self.nn_dropout = nn.Dropout(dropout)
+
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+        self.tanh = nn.Tanh()
+        self.pooler = MaxPoolLayer()
+        self.pretrain = pretrain
+        if self.pretrain == True or self.pretrain == "True":
+            print(self.state_dict()["embbedding1.0.weight"])
+            pt_emb = torch.load("/home/xmw5190/HMMP/model/HADM_AE_200_5_28.p").state_dict()
+
+            state_dict = {k.split('enc.')[-1]: v for k, v in pt_emb.items() if k.split('enc.')[-1] in self.state_dict().keys() and "embbedding1" in k}
+            print(state_dict.keys())
+            self.state_dict().update(state_dict)
+            self.load_state_dict(state_dict, strict=False)
+            print(self.state_dict()["embbedding1.0.weight"])
+        # if self.pretrain == True or self.pretrain == "True":
+        #     print(self.state_dict()["embedding.0.weight"])
+        #     pt_emb = torch.load("/home/xmw5190/HMMP/model/HADM_AE_200_new_mask_lowlr.p").state_dict()
+
+        #     weight, bias = pt_emb['enc.embbedding1.0.weight'], pt_emb["enc.embbedding1.0.bias"]
+        #     self.state_dict().update({"embedding.0.weight": weight})
+        #     self.state_dict().update({"embedding.0.bias": bias})
+        #     self.load_state_dict({"embedding.0.weight": weight, "embedding.0.bias": bias}, strict=False)
+        #     print(self.state_dict()["embedding.0.weight"])
+
+
+    def forward(self, input, device):
+        # input shape [batch_size, timestep, feature_dim]
+        batch_size = input.size(0)
+        time_step = input.size(1)
+        # feature_dim = input.size(2)
+        input = self.embbedding1(input)  # .sum(dim=2)
+
+        cur_h = Variable(torch.zeros(batch_size, self.hidden_dim)).to(device)
+        inputse_att = []
+        convse_att = []
+        h = []
+
+        conv_input = input.permute(0, 2, 1)
+        conv_res1 = self.nn_conv1(conv_input)
+        conv_res3 = self.nn_conv3(conv_input)
+        conv_res5 = self.nn_conv5(conv_input)
+
+        conv_res = torch.cat((conv_res1, conv_res3, conv_res5), dim=1)
+        conv_res = self.relu(conv_res)
+
+        for cur_time in range(time_step):
+            convse_res, cur_convatt = self.nn_convse(conv_res[:, :, :cur_time + 1], device=device)
+            inputse_res, cur_inputatt = self.nn_inputse(input[:, :cur_time + 1, :].permute(0, 2, 1), device=device)
+            cur_input = torch.cat((convse_res[:, :, -1], inputse_res[:, :, -1]), dim=-1)
+
+            cur_h = self.rnn(cur_input, cur_h)
+            h.append(cur_h)
+            convse_att.append(cur_convatt)
+            inputse_att.append(cur_inputatt)
+
+        h = torch.stack(h).permute(1, 0, 2)
+        h_reshape = h.contiguous().view(batch_size, time_step, self.hidden_dim)
+        if self.dropout > 0.0:
+            h_reshape = self.nn_dropout(h_reshape)
+        output = self.pooler(h_reshape)
+        output = self.nn_output(output)
+        return output
+
+if __name__ == '__main__':
+    model = AdaCare(vocab_size=7687, hidden_dim=64, kernel_size=2, kernel_num=64, input_dim=64, output_dim=1, dropout=0.5, r_v=4,
+                    r_c=4, activation='sigmoid', pretrain="False")
\ No newline at end of file
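A minimal smoke test of the forward pass, as a sketch rather than part of the patch. It assumes the code is run from the `patient/` package root so that `models.adacare` is importable, and that `MaxPoolLayer` from `models.baseline` (not included in this diff) pools over the time dimension and returns a `[batch_size, hidden_dim]` tensor. The batch size and sequence length below are made up for illustration; the input follows the `[batch_size, timestep, feature_dim]` shape noted in the comment at the top of `AdaCare.forward`.

    import torch
    from models.adacare import AdaCare

    batch_size, time_step, vocab_size = 4, 10, 7687   # hypothetical sizes for illustration
    model = AdaCare(vocab_size=vocab_size, hidden_dim=64, kernel_size=2,
                    kernel_num=64, input_dim=64, output_dim=1, dropout=0.5,
                    r_v=4, r_c=4, activation='sigmoid', pretrain=False)
    model.eval()

    x = torch.rand(batch_size, time_step, vocab_size)  # dummy visit features
    with torch.no_grad():
        logits = model(x, device='cpu')                # two-class logits from nn_output
    print(logits.shape)                                # expected: torch.Size([4, 2])

Two things the sketch relies on: `input_dim` must equal `hidden_dim`, because the output of `embbedding1` feeds straight into the causal convolutions; and with `activation='sigmoid'` the `Sparsemax` branch (whose default device is `'cuda'`) is never taken, so the example can run on CPU.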