Switch to side-by-side view

--- a
+++ b/app/models/backbones/concare.py
@@ -0,0 +1,589 @@
+# import packages
+import copy
+
+# import packages
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Variable
+
+
+class SingleAttention(nn.Module):
+    def __init__(
+        self,
+        attention_input_dim,
+        attention_hidden_dim,
+        attention_type="add",
+        demographic_dim=12,
+        time_aware=False,
+        use_demographic=False,
+    ):
+        super(SingleAttention, self).__init__()
+
+        self.attention_type = attention_type
+        self.attention_hidden_dim = attention_hidden_dim
+        self.attention_input_dim = attention_input_dim
+        self.use_demographic = use_demographic
+        self.demographic_dim = demographic_dim
+        self.time_aware = time_aware
+
+        # batch_time = torch.arange(0, batch_mask.size()[1], dtype=torch.float32).reshape(1, batch_mask.size()[1], 1)
+        # batch_time = batch_time.repeat(batch_mask.size()[0], 1, 1)
+
+        if attention_type == "add":
+            if self.time_aware:
+                # self.Wx = nn.Parameter(torch.randn(attention_input_dim+1, attention_hidden_dim))
+                self.Wx = nn.Parameter(
+                    torch.randn(attention_input_dim, attention_hidden_dim)
+                )
+                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
+                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
+            else:
+                self.Wx = nn.Parameter(
+                    torch.randn(attention_input_dim, attention_hidden_dim)
+                )
+            self.Wt = nn.Parameter(
+                torch.randn(attention_input_dim, attention_hidden_dim)
+            )
+            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
+            self.bh = nn.Parameter(
+                torch.zeros(
+                    attention_hidden_dim,
+                )
+            )
+            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
+            self.ba = nn.Parameter(
+                torch.zeros(
+                    1,
+                )
+            )
+
+            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
+        elif attention_type == "mul":
+            self.Wa = nn.Parameter(
+                torch.randn(attention_input_dim, attention_input_dim)
+            )
+            self.ba = nn.Parameter(
+                torch.zeros(
+                    1,
+                )
+            )
+
+            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
+        elif attention_type == "concat":
+            if self.time_aware:
+                self.Wh = nn.Parameter(
+                    torch.randn(2 * attention_input_dim + 1, attention_hidden_dim)
+                )
+            else:
+                self.Wh = nn.Parameter(
+                    torch.randn(2 * attention_input_dim, attention_hidden_dim)
+                )
+
+            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
+            self.ba = nn.Parameter(
+                torch.zeros(
+                    1,
+                )
+            )
+
+            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
+
+        elif attention_type == "new":
+            self.Wt = nn.Parameter(
+                torch.randn(attention_input_dim, attention_hidden_dim)
+            )
+            self.Wx = nn.Parameter(
+                torch.randn(attention_input_dim, attention_hidden_dim)
+            )
+
+            self.rate = nn.Parameter(torch.zeros(1) + 0.8)
+            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
+
+        else:
+            raise RuntimeError("Wrong attention type.")
+
+        self.tanh = nn.Tanh()
+        self.softmax = nn.Softmax(dim=1)
+        self.sigmoid = nn.Sigmoid()
+        self.relu = nn.ReLU()
+
+    def forward(self, input, device, demo=None):
+
+        (
+            batch_size,
+            time_step,
+            input_dim,
+        ) = input.size()  # batch_size * time_step * hidden_dim(i)
+
+        time_decays = (
+            torch.tensor(range(time_step - 1, -1, -1), dtype=torch.float32)
+            .unsqueeze(-1).unsqueeze(0).to(device=device)
+        )  # 1*t*1
+        b_time_decays = time_decays.repeat(batch_size, 1, 1) + 1  # b t 1
+
+        if self.attention_type == "add":  # B*T*I  @ H*I
+            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
+            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
+            if self.time_aware == True:
+                k = torch.matmul(input, self.Wx)  # b t h
+                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)  # b t h
+            else:
+                k = torch.matmul(input, self.Wx)  # b t h
+            if self.use_demographic:
+                d = torch.matmul(demo, self.Wd)  # B*H
+                d = torch.reshape(
+                    d, (batch_size, 1, self.attention_hidden_dim)
+                )  # b 1 h
+            h = q + k + self.bh  # b t h
+            if self.time_aware:
+                h += time_hidden
+            h = self.tanh(h)  # B*T*H
+            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
+            e = torch.reshape(e, (batch_size, time_step))  # b t
+        elif self.attention_type == "mul":
+            e = torch.matmul(input[:, -1, :], self.Wa)  # b i
+            e = (
+                torch.matmul(e.unsqueeze(1), input.permute(0, 2, 1)).squeeze() + self.ba
+            )  # b t
+        elif self.attention_type == "concat":
+            q = input[:, -1, :].unsqueeze(1).repeat(1, time_step, 1)  # b t i
+            k = input
+            c = torch.cat((q, k), dim=-1)  # B*T*2I
+            if self.time_aware:
+                c = torch.cat((c, b_time_decays), dim=-1)  # B*T*2I+1
+            h = torch.matmul(c, self.Wh)
+            h = self.tanh(h)
+            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
+            e = torch.reshape(e, (batch_size, time_step))  # b t
+
+        elif self.attention_type == "new":
+
+            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
+            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
+            k = torch.matmul(input, self.Wx)  # b t h
+            dot_product = torch.matmul(q, k.transpose(1, 2)).squeeze()  # b t
+            denominator = self.sigmoid(self.rate) * (
+                torch.log(2.72 + (1 - self.sigmoid(dot_product)))
+                * (b_time_decays.squeeze())
+            )
+            e = self.relu(self.sigmoid(dot_product) / (denominator))  # b * t
+
+        a = self.softmax(e)  # B*T
+        v = torch.matmul(a.unsqueeze(1), input).squeeze()  # B*I
+
+        return v, a
+
+
+class FinalAttentionQKV(nn.Module):
+    def __init__(
+        self,
+        attention_input_dim,
+        attention_hidden_dim,
+        attention_type="add",
+        dropout=None,
+    ):
+        super(FinalAttentionQKV, self).__init__()
+
+        self.attention_type = attention_type
+        self.attention_hidden_dim = attention_hidden_dim
+        self.attention_input_dim = attention_input_dim
+
+        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
+        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
+        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)
+
+        self.W_out = nn.Linear(attention_hidden_dim, 1)
+
+        self.b_in = nn.Parameter(
+            torch.zeros(
+                1,
+            )
+        )
+        self.b_out = nn.Parameter(
+            torch.zeros(
+                1,
+            )
+        )
+
+        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))
+
+        self.Wh = nn.Parameter(
+            torch.randn(2 * attention_input_dim, attention_hidden_dim)
+        )
+        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
+        self.ba = nn.Parameter(
+            torch.zeros(
+                1,
+            )
+        )
+
+        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
+
+        self.dropout = nn.Dropout(p=dropout)
+        self.tanh = nn.Tanh()
+        self.softmax = nn.Softmax(dim=1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input):
+
+        (
+            batch_size,
+            time_step,
+            input_dim,
+        ) = input.size()  # batch_size * input_dim + 1 * hidden_dim(i)
+        input_q = self.W_q(input[:, -1, :])  # b h
+        input_k = self.W_k(input)  # b t h
+        input_v = self.W_v(input)  # b t h
+
+        if self.attention_type == "add":  # B*T*I  @ H*I
+
+            q = torch.reshape(
+                input_q, (batch_size, 1, self.attention_hidden_dim)
+            )  # B*1*H
+            h = q + input_k + self.b_in  # b t h
+            h = self.tanh(h)  # B*T*H
+            e = self.W_out(h)  # b t 1
+            e = torch.reshape(e, (batch_size, time_step))  # b t
+
+        elif self.attention_type == "mul":
+            q = torch.reshape(
+                input_q, (batch_size, self.attention_hidden_dim, 1)
+            )  # B*h 1
+            e = torch.matmul(input_k, q).squeeze()  # b t
+
+        elif self.attention_type == "concat":
+            q = input_q.unsqueeze(1).repeat(1, time_step, 1)  # b t h
+            k = input_k
+            c = torch.cat((q, k), dim=-1)  # B*T*2I
+            h = torch.matmul(c, self.Wh)
+            h = self.tanh(h)
+            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
+            e = torch.reshape(e, (batch_size, time_step))  # b t
+
+        a = self.softmax(e)  # B*T
+        if self.dropout is not None:
+            a = self.dropout(a)
+        v = torch.matmul(a.unsqueeze(1), input_v).squeeze()  # B*I
+
+        return v, a
+
+
+class PositionwiseFeedForward(nn.Module):  # new added
+    "Implements FFN equation."
+
+    def __init__(self, d_model, d_ff, dropout=0.1):
+        super(PositionwiseFeedForward, self).__init__()
+        self.w_1 = nn.Linear(d_model, d_ff)
+        self.w_2 = nn.Linear(d_ff, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        return self.w_2(self.dropout(F.relu(self.w_1(x)))), None
+
+
+class PositionalEncoding(nn.Module):  # new added / not use anymore
+    "Implement the PE function."
+
+    def __init__(self, d_model, dropout, max_len=400):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Compute the positional encodings once in log space.
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0.0, max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        x = x + Variable(self.pe[:, : x.size(1)], requires_grad=False)
+        return self.dropout(x)
+
+
+class MultiHeadedAttention(nn.Module):
+    def __init__(self, h, d_model, dropout=0):
+        "Take in model size and number of heads."
+        super(MultiHeadedAttention, self).__init__()
+        assert d_model % h == 0
+        # We assume d_v always equals d_k
+        self.d_k = d_model // h
+        self.h = h
+        self.linears = nn.ModuleList(
+            [nn.Linear(d_model, self.d_k * self.h) for _ in range(3)]
+        )
+        self.final_linear = nn.Linear(d_model, d_model)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout)
+
+    def attention(self, query, key, value, mask=None, dropout=None):
+        "Compute 'Scaled Dot Product Attention'"
+        d_k = query.size(-1)  # b h t d_k
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # b h t t
+        if mask is not None:  # 1 1 t t
+            scores = scores.masked_fill(mask == 0, -1e9)  # b h t t 下三角
+        p_attn = F.softmax(scores, dim=-1)  # b h t t
+        if dropout is not None:
+            p_attn = dropout(p_attn)
+        return torch.matmul(p_attn, value), p_attn  # b h t v (d_k)
+
+    def cov(self, m, y=None):
+        if y is not None:
+            m = torch.cat((m, y), dim=0)
+        m_exp = torch.mean(m, dim=1)
+        x = m - m_exp[:, None]
+        cov = 1 / (x.size(1) - 1) * x.mm(x.t())
+        return cov
+
+    def forward(self, query, key, value, mask=None):
+        if mask is not None:
+            # Same mask applied to all h heads.
+            mask = mask.unsqueeze(1)  # 1 1 t t
+
+        nbatches = query.size(0)  # b
+        input_dim = query.size(1)  # i+1
+        feature_dim = query.size(1)  # i+1
+
+        # input size -> # batch_size * d_input * hidden_dim
+
+        # d_model => h * d_k
+        query, key, value = [
+            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
+            for l, x in zip(self.linears, (query, key, value))
+        ]  # b num_head d_input d_k
+
+        x, self.attn = self.attention(
+            query, key, value, mask=mask, dropout=self.dropout
+        )  # b num_head d_input d_v (d_k)
+
+        x = (
+            x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
+        )  # batch_size * d_input * hidden_dim
+
+        # DeCov
+        DeCov_contexts = x.transpose(0, 1).transpose(1, 2)  # I+1 H B
+        Covs = self.cov(DeCov_contexts[0, :, :])
+        DeCov_loss = 0.5 * (
+            torch.norm(Covs, p="fro") ** 2 - torch.norm(torch.diag(Covs)) ** 2
+        )
+        for i in range(feature_dim - 1):
+            Covs = self.cov(DeCov_contexts[i + 1, :, :])
+            DeCov_loss += 0.5 * (
+                torch.norm(Covs, p="fro") ** 2 - torch.norm(torch.diag(Covs)) ** 2
+            )
+
+        return self.final_linear(x), DeCov_loss
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, features, eps=1e-7):
+        super(LayerNorm, self).__init__()
+        self.a_2 = nn.Parameter(torch.ones(features))
+        self.b_2 = nn.Parameter(torch.zeros(features))
+        self.eps = eps
+
+    def forward(self, x):
+        mean = x.mean(-1, keepdim=True)
+        std = x.std(-1, keepdim=True)
+        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
+
+
+class SublayerConnection(nn.Module):
+    """
+    A residual connection followed by a layer norm.
+    Note for code simplicity the norm is first as opposed to last.
+    """
+
+    def __init__(self, size, dropout):
+        super(SublayerConnection, self).__init__()
+        self.norm = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, sublayer):
+        "Apply residual connection to any sublayer with the same size."
+        returned_value = sublayer(self.norm(x))
+        return x + self.dropout(returned_value[0]), returned_value[1]
+
+
+class ConCare(nn.Module):
+    def __init__(
+        self,
+        lab_dim,  # lab_dim
+        hidden_dim,
+        demo_dim,
+        d_model,
+        MHD_num_head,
+        d_ff,
+        # output_dim,
+        # device,
+        drop=0.5,
+    ):
+        super(ConCare, self).__init__()
+
+        # hyperparameters
+        self.lab_dim = lab_dim
+        self.hidden_dim = hidden_dim  # d_model
+        self.d_model = d_model
+        self.MHD_num_head = MHD_num_head
+        self.d_ff = d_ff
+        # self.output_dim = output_dim
+        self.drop = drop
+        self.demo_dim = demo_dim
+
+        # layers
+        self.PositionalEncoding = PositionalEncoding(
+            self.d_model, dropout=0, max_len=400
+        )
+
+        self.GRUs = nn.ModuleList(
+            [
+                copy.deepcopy(nn.GRU(1, self.hidden_dim, batch_first=True))
+                for _ in range(self.lab_dim)
+            ]
+        )
+        self.LastStepAttentions = nn.ModuleList(
+            [
+                copy.deepcopy(
+                    SingleAttention(
+                        self.hidden_dim,
+                        8,
+                        attention_type="new",
+                        demographic_dim=12,
+                        time_aware=True,
+                        use_demographic=False,
+                    )
+                )
+                for _ in range(self.lab_dim)
+            ]
+        )
+
+        self.FinalAttentionQKV = FinalAttentionQKV(
+            self.hidden_dim,
+            self.hidden_dim,
+            attention_type="mul",
+            dropout=self.drop,
+        )
+
+        self.MultiHeadedAttention = MultiHeadedAttention(
+            self.MHD_num_head, self.d_model, dropout=self.drop
+        )
+        self.SublayerConnection = SublayerConnection(self.d_model, dropout=self.drop)
+
+        self.PositionwiseFeedForward = PositionwiseFeedForward(
+            self.d_model, self.d_ff, dropout=0.1
+        )
+
+        self.demo_lab_proj = nn.Linear(self.demo_dim + self.lab_dim, self.hidden_dim)
+        self.demo_proj_main = nn.Linear(self.demo_dim, self.hidden_dim)
+        self.demo_proj = nn.Linear(self.demo_dim, self.hidden_dim)
+        self.output0 = nn.Linear(self.hidden_dim, self.hidden_dim)
+        # self.output1 = nn.Linear(self.hidden_dim, self.output_dim)
+
+        self.dropout = nn.Dropout(p=self.drop)
+        self.tanh = nn.Tanh()
+        self.softmax = nn.Softmax()
+        self.sigmoid = nn.Sigmoid()
+        self.relu = nn.ReLU()
+
+    def concare_encoder(self, input, demo_input, device):
+
+        # input shape [batch_size, timestep, feature_dim]
+        demo_main = self.tanh(self.demo_proj_main(demo_input)).unsqueeze(
+            1
+        )  # b hidden_dim
+
+        batch_size = input.size(0)
+        time_step = input.size(1)
+        feature_dim = input.size(2)
+        assert feature_dim == self.lab_dim  # input Tensor : 256 * 48 * 76
+        assert self.d_model % self.MHD_num_head == 0
+
+        # forward
+        GRU_embeded_input = self.GRUs[0](
+            input[:, :, 0].unsqueeze(-1).to(device=device),
+            Variable(
+                torch.zeros(batch_size, self.hidden_dim)
+                .to(device=device)
+                .unsqueeze(0)
+            ),
+        )[
+            0
+        ]  # b t h
+        Attention_embeded_input = self.LastStepAttentions[0](GRU_embeded_input, device)[
+            0
+        ].unsqueeze(
+            1
+        )  # b 1 h
+
+        for i in range(feature_dim - 1):
+            embeded_input = self.GRUs[i + 1](
+                input[:, :, i + 1].unsqueeze(-1),
+                Variable(
+                    torch.zeros(batch_size, self.hidden_dim)
+                    .to(device=device)
+                    .unsqueeze(0)
+                ),
+            )[
+                0
+            ]  # b 1 h
+            embeded_input = self.LastStepAttentions[i + 1](embeded_input, device)[0].unsqueeze(
+                1
+            )  # b 1 h
+            Attention_embeded_input = torch.cat(
+                (Attention_embeded_input, embeded_input), 1
+            )  # b i h
+        Attention_embeded_input = torch.cat(
+            (Attention_embeded_input, demo_main), 1
+        )  # b i+1 h
+        posi_input = self.dropout(
+            Attention_embeded_input
+        )  # batch_size * d_input+1 * hidden_dim
+
+        contexts = self.SublayerConnection(
+            posi_input,
+            lambda x: self.MultiHeadedAttention(
+                posi_input, posi_input, posi_input, None
+            ),
+        )  # # batch_size * d_input * hidden_dim
+
+        DeCov_loss = contexts[1]
+        contexts = contexts[0]
+
+        contexts = self.SublayerConnection(
+            contexts, lambda x: self.PositionwiseFeedForward(contexts)
+        )[0]
+
+        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
+        return weighted_contexts
+
+    def forward(self, x, device, info=None):
+        """extra info is not used here"""
+        batch_size, time_steps, _ = x.size()
+        demo_input = x[:, 0, : self.demo_dim]
+        lab_input = x[:, :, self.demo_dim :]
+        out = torch.zeros((batch_size, time_steps, self.hidden_dim))
+        for cur_time in range(time_steps):
+            # print(cur_time, end=" ")
+            cur_lab = lab_input[:, : cur_time + 1, :]
+            # print("cur_lab", cur_lab.shape)
+            if cur_time == 0:
+                out[:, cur_time, :] = self.demo_lab_proj(x[:, 0, :])
+            else:
+                out[:, cur_time, :] = self.concare_encoder(cur_lab, demo_input, device)
+        # print()
+        return out