--- a
+++ b/app/models/backbones/concare.py
@@ -0,0 +1,589 @@
+# import packages
+import copy
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class SingleAttention(nn.Module):
+    """Per-feature attention over the time axis; the 'new' type is the time-aware attention used by ConCare."""
+
+    def __init__(
+        self,
+        attention_input_dim,
+        attention_hidden_dim,
+        attention_type="add",
+        demographic_dim=12,
+        time_aware=False,
+        use_demographic=False,
+    ):
+        super(SingleAttention, self).__init__()
+
+        self.attention_type = attention_type
+        self.attention_hidden_dim = attention_hidden_dim
+        self.attention_input_dim = attention_input_dim
+        self.use_demographic = use_demographic
+        self.demographic_dim = demographic_dim
+        self.time_aware = time_aware
+
+        # batch_time = torch.arange(0, batch_mask.size()[1], dtype=torch.float32).reshape(1, batch_mask.size()[1], 1)
+        # batch_time = batch_time.repeat(batch_mask.size()[0], 1, 1)
+
+        if attention_type == "add":
+            if self.time_aware:
+                # self.Wx = nn.Parameter(torch.randn(attention_input_dim+1, attention_hidden_dim))
+                self.Wx = nn.Parameter(
+                    torch.randn(attention_input_dim, attention_hidden_dim)
+                )
+                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
+                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
+            else:
+                self.Wx = nn.Parameter(
+                    torch.randn(attention_input_dim, attention_hidden_dim)
+                )
+            self.Wt = nn.Parameter(
+                torch.randn(attention_input_dim, attention_hidden_dim)
+            )
+            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
+            self.bh = nn.Parameter(torch.zeros(attention_hidden_dim))
+            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
+            self.ba = nn.Parameter(torch.zeros(1))
+
+            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
+        elif attention_type == "mul":
+            self.Wa = nn.Parameter(
+                torch.randn(attention_input_dim, attention_input_dim)
+            )
+            self.ba = nn.Parameter(torch.zeros(1))
+
+            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
+        elif attention_type == "concat":
+            if self.time_aware:
+                self.Wh = nn.Parameter(
+                    torch.randn(2 * attention_input_dim + 1, attention_hidden_dim)
+                )
+            else:
+                self.Wh = nn.Parameter(
+                    torch.randn(2 * attention_input_dim, attention_hidden_dim)
+                )
+
+            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
+            self.ba = nn.Parameter(torch.zeros(1))
+
+            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
+        elif attention_type == "new":
+            self.Wt = nn.Parameter(
+                torch.randn(attention_input_dim, attention_hidden_dim)
+            )
+            self.Wx = nn.Parameter(
+                torch.randn(attention_input_dim, attention_hidden_dim)
+            )
+
+            self.rate = nn.Parameter(torch.zeros(1) + 0.8)
+            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
+        else:
+            raise RuntimeError("Wrong attention type.")
+
+        self.tanh = nn.Tanh()
+        self.softmax = nn.Softmax(dim=1)
+        self.sigmoid = nn.Sigmoid()
+        self.relu = nn.ReLU()
+
+    def forward(self, input, device, demo=None):
+        (
+            batch_size,
+            time_step,
+            input_dim,
+        ) = input.size()  # batch_size * time_step * hidden_dim(i)
+
+        # time decays: steps further in the past get larger decay values
+        time_decays = (
+            torch.arange(time_step - 1, -1, -1, dtype=torch.float32, device=device)
+            .unsqueeze(-1)
+            .unsqueeze(0)
+        )  # 1*t*1
+        b_time_decays = time_decays.repeat(batch_size, 1, 1) + 1  # b t 1
+
+        if self.attention_type == "add":  # B*T*I @ H*I
+            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
+            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
+            if self.time_aware:
+                k = torch.matmul(input, self.Wx)  # b t h
+                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)  # b t h
+            else:
+                k = torch.matmul(input, self.Wx)  # b t h
+            if self.use_demographic:
+                d = torch.matmul(demo, self.Wd)  # B*H
+                d = torch.reshape(
+                    d, (batch_size, 1, self.attention_hidden_dim)
+                )  # b 1 h
+            h = q + k + self.bh  # b t h
+            if self.time_aware:
+                h += time_hidden
+            h = self.tanh(h)  # B*T*H
+            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
+            e = torch.reshape(e, (batch_size, time_step))  # b t
+        elif self.attention_type == "mul":
+            e = torch.matmul(input[:, -1, :], self.Wa)  # b i
+            e = (
+                torch.matmul(e.unsqueeze(1), input.permute(0, 2, 1)).squeeze(1)
+                + self.ba
+            )  # b t
+        elif self.attention_type == "concat":
+            q = input[:, -1, :].unsqueeze(1).repeat(1, time_step, 1)  # b t i
+            k = input
+            c = torch.cat((q, k), dim=-1)  # B*T*2I
+            if self.time_aware:
+                c = torch.cat((c, b_time_decays), dim=-1)  # B*T*2I+1
+            h = torch.matmul(c, self.Wh)
+            h = self.tanh(h)
+            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
+            e = torch.reshape(e, (batch_size, time_step))  # b t
+        elif self.attention_type == "new":
+            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
+            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
+            k = torch.matmul(input, self.Wx)  # b t h
+            dot_product = torch.matmul(q, k.transpose(1, 2)).squeeze(1)  # b t
+            # 2.72 approximates Euler's number e inside the log
+            denominator = self.sigmoid(self.rate) * (
+                torch.log(2.72 + (1 - self.sigmoid(dot_product)))
+                * (b_time_decays.squeeze(-1))
+            )
+            e = self.relu(self.sigmoid(dot_product) / denominator)  # b * t
+
+        a = self.softmax(e)  # B*T
+        v = torch.matmul(a.unsqueeze(1), input).squeeze(1)  # B*I
+
+        return v, a
+
+
+class FinalAttentionQKV(nn.Module):
+    """QKV attention that pools a sequence of context vectors into a single vector (used over the feature axis in ConCare)."""
+
+    def __init__(
+        self,
+        attention_input_dim,
+        attention_hidden_dim,
+        attention_type="add",
+        dropout=None,
+    ):
+        super(FinalAttentionQKV, self).__init__()
+
+        self.attention_type = attention_type
+        self.attention_hidden_dim = attention_hidden_dim
+        self.attention_input_dim = attention_input_dim
+
+        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
+        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
+        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)
+
+        self.W_out = nn.Linear(attention_hidden_dim, 1)
+
+        self.b_in = nn.Parameter(torch.zeros(1))
+        self.b_out = nn.Parameter(torch.zeros(1))
+
+        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))
+
+        self.Wh = nn.Parameter(
+            torch.randn(2 * attention_input_dim, attention_hidden_dim)
+        )
+        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
+        self.ba = nn.Parameter(torch.zeros(1))
+
+        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
+
+        # nn.Dropout(p=None) raises, so only build the layer when a rate is given
+        self.dropout = nn.Dropout(p=dropout) if dropout is not None else None
+        self.tanh = nn.Tanh()
+        self.softmax = nn.Softmax(dim=1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input):
+        (
+            batch_size,
+            time_step,
+            input_dim,
+        ) = input.size()  # batch_size * (input_dim + 1) * hidden_dim(i)
+        input_q = self.W_q(input[:, -1, :])  # b h
+        input_k = self.W_k(input)  # b t h
+        input_v = self.W_v(input)  # b t h
+
+        if self.attention_type == "add":  # B*T*I @ H*I
+            q = torch.reshape(
+                input_q, (batch_size, 1, self.attention_hidden_dim)
+            )  # B*1*H
+            h = q + input_k + self.b_in  # b t h
+            h = self.tanh(h)  # B*T*H
+            e = self.W_out(h)  # b t 1
+            e = torch.reshape(e, (batch_size, time_step))  # b t
+        elif self.attention_type == "mul":
+            q = torch.reshape(
+                input_q, (batch_size, self.attention_hidden_dim, 1)
+            )  # B*h 1
+            e = torch.matmul(input_k, q).squeeze(-1)  # b t
+        elif self.attention_type == "concat":
+            q = input_q.unsqueeze(1).repeat(1, time_step, 1)  # b t h
+            k = input_k
+            c = torch.cat((q, k), dim=-1)  # B*T*2I
+            h = torch.matmul(c, self.Wh)
+            h = self.tanh(h)
+            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
+            e = torch.reshape(e, (batch_size, time_step))  # b t
+
+        a = self.softmax(e)  # B*T
+        if self.dropout is not None:
+            a = self.dropout(a)
+        v = torch.matmul(a.unsqueeze(1), input_v).squeeze(1)  # B*I
+
+        return v, a
+
+
+class PositionwiseFeedForward(nn.Module):  # newly added
+    "Implements the FFN equation."
+
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(PositionwiseFeedForward, self).__init__()
+        self.w_1 = nn.Linear(d_model, d_ff)
+        self.w_2 = nn.Linear(d_ff, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        # returns an (output, extra) pair to match the interface expected by SublayerConnection
+        return self.w_2(self.dropout(F.relu(self.w_1(x)))), None
+
+
+class PositionalEncoding(nn.Module):  # newly added / no longer used
+    "Implements the PE function."
+
+    def __init__(self, d_model, dropout, max_len=400):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Compute the positional encodings once in log space.
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0.0, max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        x = x + self.pe[:, : x.size(1)]
+        return self.dropout(x)
+
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-head self-attention across features; also returns the DeCov loss of the per-feature contexts."""
+
+    def __init__(self, h, d_model, dropout=0):
+        "Take in model size and number of heads."
+        super(MultiHeadedAttention, self).__init__()
+        assert d_model % h == 0
+        # We assume d_v always equals d_k
+        self.d_k = d_model // h
+        self.h = h
+        self.linears = nn.ModuleList(
+            [nn.Linear(d_model, self.d_k * self.h) for _ in range(3)]
+        )
+        self.final_linear = nn.Linear(d_model, d_model)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout)
+
+    def attention(self, query, key, value, mask=None, dropout=None):
+        "Compute 'Scaled Dot Product Attention'"
+        d_k = query.size(-1)  # b h t d_k
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # b h t t
+        if mask is not None:  # 1 1 t t
+            scores = scores.masked_fill(mask == 0, -1e9)  # b h t t, lower-triangular mask
+        p_attn = F.softmax(scores, dim=-1)  # b h t t
+        if dropout is not None:
+            p_attn = dropout(p_attn)
+        return torch.matmul(p_attn, value), p_attn  # b h t v (d_k)
+
+    def cov(self, m, y=None):
+        # sample covariance of m (rows = variables, columns = observations)
+        if y is not None:
+            m = torch.cat((m, y), dim=0)
+        m_exp = torch.mean(m, dim=1)
+        x = m - m_exp[:, None]
+        cov = 1 / (x.size(1) - 1) * x.mm(x.t())
+        return cov
+
+    def forward(self, query, key, value, mask=None):
+        if mask is not None:
+            # Same mask applied to all h heads.
+            mask = mask.unsqueeze(1)  # 1 1 t t
+
+        nbatches = query.size(0)  # b
+        feature_dim = query.size(1)  # i+1
+
+        # input size: batch_size * d_input * hidden_dim
+        # d_model => h * d_k
+        query, key, value = [
+            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
+            for l, x in zip(self.linears, (query, key, value))
+        ]  # b num_head d_input d_k
+
+        x, self.attn = self.attention(
+            query, key, value, mask=mask, dropout=self.dropout
+        )  # b num_head d_input d_v (d_k)
+
+        x = (
+            x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
+        )  # batch_size * d_input * hidden_dim
+
+        # DeCov regularization: penalize off-diagonal covariance of every feature's context
+        DeCov_contexts = x.transpose(0, 1).transpose(1, 2)  # I+1 H B
+        Covs = self.cov(DeCov_contexts[0, :, :])
+        DeCov_loss = 0.5 * (
+            torch.norm(Covs, p="fro") ** 2 - torch.norm(torch.diag(Covs)) ** 2
+        )
+        for i in range(feature_dim - 1):
+            Covs = self.cov(DeCov_contexts[i + 1, :, :])
+            DeCov_loss += 0.5 * (
+                torch.norm(Covs, p="fro") ** 2 - torch.norm(torch.diag(Covs)) ** 2
+            )
+
+        return self.final_linear(x), DeCov_loss
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, features, eps=1e-7):
+        super(LayerNorm, self).__init__()
+        self.a_2 = nn.Parameter(torch.ones(features))
+        self.b_2 = nn.Parameter(torch.zeros(features))
+        self.eps = eps
+
+    def forward(self, x):
+        mean = x.mean(-1, keepdim=True)
+        std = x.std(-1, keepdim=True)
+        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
+
+
+class SublayerConnection(nn.Module):
+    """
+    A residual connection followed by a layer norm.
+    Note for code simplicity the norm is first as opposed to last.
+    """
+
+    def __init__(self, size, dropout):
+        super(SublayerConnection, self).__init__()
+        self.norm = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, sublayer):
+        "Apply residual connection to any sublayer with the same size."
+        returned_value = sublayer(self.norm(x))
+        return x + self.dropout(returned_value[0]), returned_value[1]
+
+
+class ConCare(nn.Module):
+    """ConCare backbone (Ma et al., AAAI 2020): one GRU per lab feature, time-aware attention,
+    cross-feature multi-head attention with DeCov regularization, and a final attention pooling."""
+
+    def __init__(
+        self,
+        lab_dim,
+        hidden_dim,
+        demo_dim,
+        d_model,
+        MHD_num_head,
+        d_ff,
+        # output_dim,
+        # device,
+        drop=0.5,
+    ):
+        super(ConCare, self).__init__()
+
+        # hyperparameters
+        self.lab_dim = lab_dim
+        self.hidden_dim = hidden_dim  # d_model
+        self.d_model = d_model
+        self.MHD_num_head = MHD_num_head
+        self.d_ff = d_ff
+        # self.output_dim = output_dim
+        self.drop = drop
+        self.demo_dim = demo_dim
+
+        # layers
+        self.PositionalEncoding = PositionalEncoding(
+            self.d_model, dropout=0, max_len=400
+        )
+
+        self.GRUs = nn.ModuleList(
+            [
+                copy.deepcopy(nn.GRU(1, self.hidden_dim, batch_first=True))
+                for _ in range(self.lab_dim)
+            ]
+        )
+        self.LastStepAttentions = nn.ModuleList(
+            [
+                copy.deepcopy(
+                    SingleAttention(
+                        self.hidden_dim,
+                        8,
+                        attention_type="new",
+                        demographic_dim=12,
+                        time_aware=True,
+                        use_demographic=False,
+                    )
+                )
+                for _ in range(self.lab_dim)
+            ]
+        )
+
+        self.FinalAttentionQKV = FinalAttentionQKV(
+            self.hidden_dim,
+            self.hidden_dim,
+            attention_type="mul",
+            dropout=self.drop,
+        )
+
+        self.MultiHeadedAttention = MultiHeadedAttention(
+            self.MHD_num_head, self.d_model, dropout=self.drop
+        )
+        self.SublayerConnection = SublayerConnection(self.d_model, dropout=self.drop)
+
+        self.PositionwiseFeedForward = PositionwiseFeedForward(
+            self.d_model, self.d_ff, dropout=0.1
+        )
+
+        self.demo_lab_proj = nn.Linear(self.demo_dim + self.lab_dim, self.hidden_dim)
+        self.demo_proj_main = nn.Linear(self.demo_dim, self.hidden_dim)
+        self.demo_proj = nn.Linear(self.demo_dim, self.hidden_dim)
+        self.output0 = nn.Linear(self.hidden_dim, self.hidden_dim)
+        # self.output1 = nn.Linear(self.hidden_dim, self.output_dim)
+
+        self.dropout = nn.Dropout(p=self.drop)
+        self.tanh = nn.Tanh()
+        self.softmax = nn.Softmax(dim=-1)
+        self.sigmoid = nn.Sigmoid()
+        self.relu = nn.ReLU()
+
+    def concare_encoder(self, input, demo_input, device):
+        # input shape [batch_size, timestep, feature_dim]
+        demo_main = self.tanh(self.demo_proj_main(demo_input)).unsqueeze(
+            1
+        )  # b 1 hidden_dim
+
+        batch_size = input.size(0)
+        time_step = input.size(1)
+        feature_dim = input.size(2)
+        assert feature_dim == self.lab_dim  # input Tensor : 256 * 48 * 76
+        assert self.d_model % self.MHD_num_head == 0
+
+        # one GRU + time-aware attention per lab feature
+        GRU_embeded_input = self.GRUs[0](
+            input[:, :, 0].unsqueeze(-1).to(device=device),
+            torch.zeros(1, batch_size, self.hidden_dim, device=device),
+        )[0]  # b t h
+        Attention_embeded_input = self.LastStepAttentions[0](GRU_embeded_input, device)[
+            0
+        ].unsqueeze(1)  # b 1 h
+
+        for i in range(feature_dim - 1):
+            embeded_input = self.GRUs[i + 1](
+                input[:, :, i + 1].unsqueeze(-1),
+                torch.zeros(1, batch_size, self.hidden_dim, device=device),
+            )[0]  # b t h
+            embeded_input = self.LastStepAttentions[i + 1](embeded_input, device)[
+                0
+            ].unsqueeze(1)  # b 1 h
+            Attention_embeded_input = torch.cat(
+                (Attention_embeded_input, embeded_input), 1
+            )  # b i h
+        Attention_embeded_input = torch.cat(
+            (Attention_embeded_input, demo_main), 1
+        )  # b i+1 h
+        posi_input = self.dropout(
+            Attention_embeded_input
+        )  # batch_size * (d_input+1) * hidden_dim
+
+        # self-attention over the feature dimension (queries = keys = values = posi_input)
+        contexts = self.SublayerConnection(
+            posi_input,
+            lambda x: self.MultiHeadedAttention(
+                posi_input, posi_input, posi_input, None
+            ),
+        )  # batch_size * (d_input+1) * hidden_dim
+
+        DeCov_loss = contexts[1]  # note: not returned by this encoder
+        contexts = contexts[0]
+
+        contexts = self.SublayerConnection(
+            contexts, lambda x: self.PositionwiseFeedForward(contexts)
+        )[0]
+
+        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
+        return weighted_contexts
+
+    def forward(self, x, device, info=None):
+        """extra info is not used here"""
+        batch_size, time_steps, _ = x.size()
+        demo_input = x[:, 0, : self.demo_dim]
+        lab_input = x[:, :, self.demo_dim :]
+        out = torch.zeros((batch_size, time_steps, self.hidden_dim), device=device)
+        for cur_time in range(time_steps):
+            # encode the history up to and including the current step
+            cur_lab = lab_input[:, : cur_time + 1, :]
+            if cur_time == 0:
+                out[:, cur_time, :] = self.demo_lab_proj(x[:, 0, :])
+            else:
+                out[:, cur_time, :] = self.concare_encoder(cur_lab, demo_input, device)
+        return out
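
As a quick smoke test for the new backbone (a sketch, not part of the diff itself), a forward pass along these lines should run on CPU. The dimension values are illustrative placeholders rather than values from any project config, and the import assumes the repository root is on the Python path; the constraints that do come from the code are that hidden_dim must equal d_model, d_model must be divisible by MHD_num_head, and x must be laid out as [demo features | lab features] along its last axis, which is how ConCare.forward slices it.

import torch

from app.models.backbones.concare import ConCare

demo_dim, lab_dim = 2, 10  # placeholder feature split, not from the repo config
model = ConCare(
    lab_dim=lab_dim,
    hidden_dim=32,  # reused as the width of the attention blocks, so it must equal d_model
    demo_dim=demo_dim,
    d_model=32,
    MHD_num_head=4,
    d_ff=64,
    drop=0.1,
)

x = torch.randn(2, 4, demo_dim + lab_dim)  # [batch, time, demo + lab]
out = model(x, torch.device("cpu"))
print(out.shape)  # torch.Size([2, 4, 32])

Note that forward re-encodes every prefix of the sequence (one concare_encoder call per time step), so the cost of a single pass grows roughly quadratically with the number of time steps.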