Diff of /app/datasets/ml.py [000000] .. [d6904d]

Switch to unified view

a b/app/datasets/ml.py
1
import numpy as np
2
import torch
3
4
5
def flatten_dataset(x, y, indices, visits_length, case="los"):
6
    x_flat = []
7
    y_flat = []
8
    len_list = []
9
    for i in indices:
10
        len_list.append(visits_length[i])
11
        for v in range(visits_length[i]):
12
            x_flat.append(x[i][v])
13
            if case == "los":
14
                y_flat.append(y[i][v][1])
15
            elif case == "outcome":
16
                y_flat.append(y[i][v].tolist())
17
    return np.array(x_flat), np.array(y_flat), np.array(len_list)
18
19
20
def numpy_dataset(x, y, x_lab_length):
21
    x = x.numpy()
22
    y = y.numpy()
23
    x_lab_length = x_lab_length.numpy()
24
    y_los = y[:, :, 1]
25
    y_outcome = y[:, 0, 0]
26
    return x, y_outcome, y_los, x_lab_length