|
a |
|
b/datasets/utils/tools.py |
|
|
1 |
import pandas as pd |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
def df_column_switch(df: pd.DataFrame, column1, column2):
    """Return a view of *df* with the positions of two columns exchanged.

    Args:
        df: Frame whose column order should be changed.
        column1: Label of the first column to move.
        column2: Label of the second column to move.

    Returns:
        The result of re-indexing *df* with the reordered column list; the
        column order of the frame passed in is left as it was.

    Raises:
        ValueError: If either label is not among ``df.columns``.
    """
    order = list(df.columns)
    first_pos = order.index(column1)
    second_pos = order.index(column2)
    # Exchange the two labels in the ordering, then select in that order.
    order[first_pos], order[second_pos] = order[second_pos], order[first_pos]
    return df[order]
|
|
10 |
|
|
|
11 |
|
|
|
12 |
def calculate_data_existing_length(data):
    """Count the elements of *data* that are not missing.

    An element counts as missing when ``pd.isna`` reports it as such
    (e.g. ``NaN`` or ``None``).

    Args:
        data: Iterable of scalar values.

    Returns:
        The number of non-missing elements as an ``int``.
    """
    return sum(1 for item in data if not pd.isna(item))
|
|
18 |
|
|
|
19 |
|
|
|
20 |
# elements in data are sorted in time ascending order |
|
|
21 |
def fill_missing_value(data, to_fill_value=0):
    """Forward-fill missing entries of *data* in place.

    Every missing element (as judged by ``pd.isna``) is overwritten with the
    closest preceding non-missing element. Missing elements that appear
    before any non-missing element are overwritten with *to_fill_value*; as
    a consequence a sequence with no valid entries is filled entirely with
    *to_fill_value*, and a sequence without gaps is left untouched.

    Args:
        data: Mutable, integer-indexable sequence (e.g. a list or a NumPy
            array) whose elements are in ascending time order.
        to_fill_value: Value used for gaps that have no earlier valid
            element to copy from.

    Returns:
        The same *data* object that was passed in, after filling.
    """
    # A single pass is enough: remember the most recent valid value and use
    # it for each gap. Before any valid value is seen, that "most recent
    # value" is the caller-supplied default.
    last_valid = to_fill_value
    for position, value in enumerate(data):
        if pd.isna(value):
            data[position] = last_valid
        else:
            last_valid = value
    return data
|
|
46 |
|
|
|
47 |
|
|
|
48 |
def forward_fill_pipeline(
    df: pd.DataFrame,
    default_fill: pd.DataFrame,
    demographic_features: list[str],
    labtest_features: list[str],
):
    """Build per-patient, forward-filled feature and label sequences.

    Rows are grouped by ``"PatientID"`` and ordered by ``"RecordTime"``
    within each group. For ``"Age"`` and every lab-test feature, gaps are
    filled with ``fill_missing_value``, using ``default_fill[feature]`` for
    gaps that precede the first recorded value. Demographic features other
    than ``"Age"`` are passed through without filling.

    Args:
        df: Frame containing ``"PatientID"``, ``"RecordTime"``,
            ``"Outcome"``, ``"LOS"``, ``"Age"`` and all listed features.
        default_fill: Object indexed by feature name that yields the
            fallback fill value for that feature.
        demographic_features: Column names emitted first in each feature
            vector.
        labtest_features: Column names emitted after the demographic ones.

    Returns:
        A tuple ``(all_x, all_y, all_pid)`` where, for each patient,
        ``all_x`` holds a list of per-record feature lists (demographic
        then lab-test order), ``all_y`` holds a list of
        ``[Outcome, LOS]`` pairs, and ``all_pid`` holds the group key.
    """
    feature_columns = demographic_features + labtest_features

    all_x = []
    all_y = []
    all_pid = []

    for name, group in df.groupby("PatientID"):
        sorted_group = group.sort_values(by=["RecordTime"], ascending=True)

        for f in ["Age"] + labtest_features:
            # Fill a writable copy and assign it back explicitly. Mutating
            # the array returned by ``.values`` is not guaranteed to reach
            # the frame, and that array is read-only under Copy-on-Write.
            column_values = sorted_group[f].values.copy()
            sorted_group[f] = fill_missing_value(column_values, default_fill[f])

        patient_x = []
        patient_y = []
        for _, v in sorted_group.iterrows():
            patient_y.append([v["Outcome"], v["LOS"]])
            patient_x.append([v[f] for f in feature_columns])

        all_x.append(patient_x)
        all_y.append(patient_y)
        all_pid.append(name)

    return all_x, all_y, all_pid
|
|
80 |
|
|
|
81 |
def normalize_dataframe(train_df, val_df, test_df, normalize_features):
    """Z-score the given columns of three frames using trimmed train statistics.

    Statistics are taken from the training frame only, after discarding, per
    column, values that are not strictly between that column's 5% and 95%
    quantiles. The listed columns of all three frames are then overwritten
    in place with ``(value - mean) / (std + 1e-12)``.

    Args:
        train_df: Training frame; source of all statistics.
        val_df: Validation frame.
        test_df: Test frame.
        normalize_features: Column labels to normalize. Must include
            ``"LOS"``, which is looked up when building the LOS summary.

    Returns:
        A tuple ``(train_df, val_df, test_df, default_fill, LOS_info)``.
        The three frames are the objects passed in, modified in place.
        ``default_fill`` is the trimmed median of each column expressed in
        normalized units. ``LOS_info`` is a dict with the trimmed ``"mean"``,
        ``"std"`` and ``"median"`` of the ``"LOS"`` column.
    """
    train_values = train_df[normalize_features]

    # Per-column trimming bounds.
    lower_bound = train_values.quantile(0.05)
    upper_bound = train_values.quantile(0.95)

    # Values outside the open interval become NaN and are therefore ignored
    # by the statistics below.
    within_bounds = (train_values > lower_bound) & (train_values < upper_bound)
    trimmed = train_values[within_bounds]

    train_mean = trimmed.mean()
    train_std = trimmed.std()
    train_median = trimmed.median()

    # Small epsilon keeps the division finite for zero-variance columns.
    scale = train_std + 1e-12

    # Median in normalized units, used as the default value for missing data.
    default_fill: pd.DataFrame = (train_median - train_mean) / scale

    LOS_info = {
        "mean": train_mean["LOS"],
        "std": train_std["LOS"],
        "median": train_median["LOS"],
    }

    # Apply the train-derived transform to every split, train first.
    for frame in (train_df, val_df, test_df):
        frame[normalize_features] = (frame[normalize_features] - train_mean) / scale

    return train_df, val_df, test_df, default_fill, LOS_info