--- a
+++ b/utils.py
@@ -0,0 +1,68 @@
+import torch
+import numpy as np
+from sklearn.metrics import precision_recall_curve, auc
+from torch.utils.data import Dataset
+
+if torch.cuda.is_available():
+    device = 'cuda'
+else:
+    device = 'cpu'
+print(device)
+
+
+def train(data, model, optim, criterion, lbd, max_clip_norm=5):
+    model.train()
+    input = data[:, :-1].to(device)
+    label = data[:, -1].float().to(device)
+    model.train()
+    optim.zero_grad()
+    logits, kld = model(input)
+    logits = logits.squeeze(-1)
+    kld = kld.sum()
+    bce = criterion(logits, label)
+    loss = bce + lbd * kld
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
+    loss.backward()
+    optim.step()
+    return loss.item(), kld.item(), bce.item()
+
+
+def evaluate(model, data_iter, length):
+    model.eval()
+    y_pred = np.zeros(length)
+    y_true = np.zeros(length)
+    y_prob = np.zeros(length)
+    pointer = 0
+    for data in data_iter:
+        input = data[:, :-1].to(device)
+        label = data[:, -1]
+        batch_size = len(label)
+        probability, _ = model(input)
+        probability = torch.sigmoid(probability.squeeze(-1).detach())
+        predicted = probability > 0.5
+        y_true[pointer: pointer + batch_size] = label.numpy()
+        y_pred[pointer: pointer + batch_size] = predicted.cpu().numpy()
+        y_prob[pointer: pointer + batch_size] = probability.cpu().numpy()
+        pointer += batch_size
+    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
+    return auc(recall, precision), (y_pred, y_prob, y_true)
+
+
+class EHRData(Dataset):
+    def __init__(self, data, cla):
+        self.data = data
+        self.cla = cla
+
+    def __len__(self):
+        return len(self.cla)
+
+    def __getitem__(self, idx):
+        return self.data[idx], self.cla[idx]
+
+
+def collate_fn(data):
+    # padding
+    data_list = []
+    for datum in data:
+        data_list.append(np.hstack((datum[0].toarray().ravel(), datum[1])))
+    return torch.from_numpy(np.array(data_list)).long()