|
# datasets/tjh/preprocess.py

# %%
# Import necessary packages
import numpy as np
import pandas as pd
import torch

# %%
# Read raw data
df_train: pd.DataFrame = pd.read_excel(
    "./datasets/tongji/raw_data/time_series_375_prerpocess_en.xlsx"
)
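
# NOTE: each row of the raw sheet is one lab-test record; a patient spans
# several consecutive rows, and PATIENT_ID is written only on the patient's
# first row, which is why it is forward-filled below.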
|
|
|
# %% [markdown]
# Steps:
#
# - fill `patient_id`
# - only keep y-m-d for the `RE_DATE` column
# - merge lab tests of the same (patient_id, date)
# - calculate and save features' statistics information (demographic and lab test data are calculated separately)
# - normalize data
# - feature selection
# - fill missing data (our filling strategy is described below)
# - combine the above data into time series data (one record per patient)
# - export to a Python pickle file

# %%
# fill `patient_id` rows
df_train["PATIENT_ID"] = df_train["PATIENT_ID"].ffill()

# gender transformation: 1--male, 0--female
df_train["gender"] = df_train["gender"].replace(2, 0)

# only keep y-m-d for the `RE_DATE` and `Discharge time` columns
df_train["RE_DATE"] = df_train["RE_DATE"].dt.strftime("%Y-%m-%d")
df_train["Discharge time"] = df_train["Discharge time"].dt.strftime("%Y-%m-%d")

# %%
# drop records that lack PATIENT_ID, RE_DATE or Discharge time, since they
# cannot be grouped or given a LOS label
df_train = df_train.dropna(
    subset=["PATIENT_ID", "RE_DATE", "Discharge time"], how="any"
)

# %%
# calculate the raw data's los intervals (days between consecutive records)
df_grouped = df_train.groupby("PATIENT_ID")

los_interval_list = []
los_interval_alive_list = []
los_interval_dead_list = []

for name, group in df_grouped:
    sorted_group = group.sort_values(by=["RE_DATE"], ascending=True)
    # print(sorted_group['outcome'])
    # print('---')
    # print(type(sorted_group))
    intervals = sorted_group["RE_DATE"].tolist()
    outcome = sorted_group["outcome"].tolist()[0]
    cur_visits_len = len(intervals)
    # print(cur_visits_len)
    if cur_visits_len == 1:
        continue
    for i in range(1, len(intervals)):
        los_interval_list.append(
            (pd.to_datetime(intervals[i]) - pd.to_datetime(intervals[i - 1])).days
        )
        if outcome == 0:
            los_interval_alive_list.append(
                (pd.to_datetime(intervals[i]) - pd.to_datetime(intervals[i - 1])).days
            )
        else:
            los_interval_dead_list.append(
                (pd.to_datetime(intervals[i]) - pd.to_datetime(intervals[i - 1])).days
            )

los_interval_list = np.array(los_interval_list)
los_interval_alive_list = np.array(los_interval_alive_list)
los_interval_dead_list = np.array(los_interval_dead_list)

output = {
    "overall": los_interval_list,
    "alive": los_interval_alive_list,
    "dead": los_interval_dead_list,
}
# pd.to_pickle(output, 'raw_tjh_los_interval_list.pkl')
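
# e.g. (illustrative): inspect the raw interval distribution
# print(los_interval_list.mean(), np.median(los_interval_list))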
|
|
|
# %%
# we have 2 types of prediction tasks: 1) predict the mortality outcome, 2) predict the length of stay

# below are all the lab test features (tab-separated)
labtest_features_str = """
Hypersensitive cardiac troponinI	hemoglobin	Serum chloride	Prothrombin time	procalcitonin	eosinophils(%)	Interleukin 2 receptor	Alkaline phosphatase	albumin	basophil(%)	Interleukin 10	Total bilirubin	Platelet count	monocytes(%)	antithrombin	Interleukin 8	indirect bilirubin	Red blood cell distribution width	neutrophils(%)	total protein	Quantification of Treponema pallidum antibodies	Prothrombin activity	HBsAg	mean corpuscular volume	hematocrit	White blood cell count	Tumor necrosis factorα	mean corpuscular hemoglobin concentration	fibrinogen	Interleukin 1β	Urea	lymphocyte count	PH value	Red blood cell count	Eosinophil count	Corrected calcium	Serum potassium	glucose	neutrophils count	Direct bilirubin	Mean platelet volume	ferritin	RBC distribution width SD	Thrombin time	(%)lymphocyte	HCV antibody quantification	D-D dimer	Total cholesterol	aspartate aminotransferase	Uric acid	HCO3-	calcium	Amino-terminal brain natriuretic peptide precursor(NT-proBNP)	Lactate dehydrogenase	platelet large cell ratio	Interleukin 6	Fibrin degradation products	monocytes count	PLT distribution width	globulin	γ-glutamyl transpeptidase	International standard ratio	basophil count(#)	2019-nCoV nucleic acid detection	mean corpuscular hemoglobin	Activation of partial thromboplastin time	Hypersensitive c-reactive protein	HIV antibody quantification	serum sodium	thrombocytocrit	ESR	glutamic-pyruvic transaminase	eGFR	creatinine
"""

# below are the 2 demographic features (tab-separated)
demographic_features_str = """
age	gender
"""

labtest_features = labtest_features_str.strip().split("\t")
demographic_features = demographic_features_str.strip().split("\t")
target_features = ["outcome", "LOS"]

# from our observation, the `2019-nCoV nucleic acid detection` lab test feature
# is always -1, so we remove this feature here
labtest_features.remove("2019-nCoV nucleic acid detection")
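
# after the removal, 73 lab test features remain (plus the 2 demographic
# features and the 2 targets)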
|
|
|
# %%
# negative values are invalid measurements, so set them to NaN
df_train[df_train[demographic_features + labtest_features] < 0] = np.nan

# %%
# merge lab tests of the same (patient_id, date) by taking their mean
df_train = df_train.groupby(
    ["PATIENT_ID", "RE_DATE", "Discharge time"], dropna=True, as_index=False
).mean()

# %%
# calculate the length-of-stay (LOS) label: days from each record to discharge
df_train["LOS"] = (
    pd.to_datetime(df_train["Discharge time"]) - pd.to_datetime(df_train["RE_DATE"])
).dt.days

# %%
# clip negative LOS values to 0
df_train["LOS"] = df_train["LOS"].clip(lower=0)

# %%
# save features' statistics information


def calculate_statistic_info(df, features):
    """statistics over all values"""
    statistic_info = {}
    len_df = len(df)
    for _, e in enumerate(features):
        h = {}
        h["count"] = int(df[e].count())
        h["missing"] = str(round(float(100 - df[e].count() * 100 / len_df), 3)) + "%"
        h["mean"] = float(df[e].mean())
        h["max"] = float(df[e].max())
        h["min"] = float(df[e].min())
        h["median"] = float(df[e].median())
        h["std"] = float(df[e].std())
        statistic_info[e] = h
    return statistic_info


def calculate_middle_part_statistic_info(df, features):
    """statistics over the 5% ~ 95% percentile range only"""
    statistic_info = {}
    len_df = len(df)
    # calculate the 5% and 95% percentiles of the dataframe
    middle_part_df_info = df.quantile([0.05, 0.95])

    for _, e in enumerate(features):
        low_value = middle_part_df_info[e][0.05]
        high_value = middle_part_df_info[e][0.95]
        middle_part_df_element = df.loc[(df[e] >= low_value) & (df[e] <= high_value)][e]
        h = {}
        h["count"] = int(middle_part_df_element.count())
        h["missing"] = (
            str(round(float(100 - middle_part_df_element.count() * 100 / len_df), 3))
            + "%"
        )
        h["mean"] = float(middle_part_df_element.mean())
        h["max"] = float(middle_part_df_element.max())
        h["min"] = float(middle_part_df_element.min())
        h["median"] = float(middle_part_df_element.median())
        h["std"] = float(middle_part_df_element.std())
        statistic_info[e] = h
    return statistic_info


# labtest_statistic_info = calculate_statistic_info(df_train, labtest_features)
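
# each per-feature entry of the returned dicts has the form (illustrative):
#   {"count": ..., "missing": "...%", "mean": ..., "max": ..., "min": ...,
#    "median": ..., "std": ...}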
|
|
|
# group by patient_id, then calculate lab test/demographic features' statistics
groupby_patientid_df = df_train.groupby(
    ["PATIENT_ID"], dropna=True, as_index=False
).mean()


# calculate statistic info (all values)
labtest_patientwise_statistic_info = calculate_statistic_info(
    groupby_patientid_df, labtest_features
)
demographic_statistic_info = calculate_statistic_info(
    groupby_patientid_df, demographic_features
)  # it's also patient-wise

# calculate statistic info (5% ~ 95% percentile range only)
demographic_statistic_info_2 = calculate_middle_part_statistic_info(
    groupby_patientid_df, demographic_features
)
labtest_patientwise_statistic_info_2 = calculate_middle_part_statistic_info(
    groupby_patientid_df, labtest_features
)

# take the union of the 2 statistics dicts (the `|` dict merge requires Python 3.9+)
statistic_info = labtest_patientwise_statistic_info_2 | demographic_statistic_info_2

# %%
# inspect the features, export them to a csv file [optional]
to_export_dict = {
    "name": [],
    "missing_rate": [],
    "count": [],
    "mean": [],
    "max": [],
    "min": [],
    "median": [],
    "std": [],
}
for key in statistic_info:
    detail = statistic_info[key]
    to_export_dict["name"].append(key)
    to_export_dict["count"].append(detail["count"])
    to_export_dict["missing_rate"].append(detail["missing"])
    to_export_dict["mean"].append(detail["mean"])
    to_export_dict["max"].append(detail["max"])
    to_export_dict["min"].append(detail["min"])
    to_export_dict["median"].append(detail["median"])
    to_export_dict["std"].append(detail["std"])
to_export_df = pd.DataFrame.from_dict(to_export_dict)
# to_export_df.to_csv('statistic_info.csv')

# %%
# normalize data: z-score each feature, x' = (x - mean) / (std + 1e-12),
# using the 5% ~ 95% percentile statistics computed above
def normalize_data(df, features, statistic_info):
    df_features = df[features]
    df_features = df_features.apply(
        lambda x: (x - statistic_info[x.name]["mean"])
        / (statistic_info[x.name]["std"] + 1e-12)
    )
    df = pd.concat(
        [df[["PATIENT_ID", "gender", "RE_DATE", "outcome", "LOS"]], df_features], axis=1
    )
    return df


df_train = normalize_data(
    df_train, ["age"] + labtest_features, statistic_info
)  # gender doesn't need to be normalized
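
# a quick sanity check (illustrative, not part of the original pipeline):
# after z-score normalization each feature should be roughly zero-mean and
# unit-variance, e.g. print(df_train["age"].mean(), df_train["age"].std())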
|
|
|
# %%
# filter outliers: since the features were just z-score normalized, mask
# values more than `bar` (3) standard deviations from the mean as NaN
def filter_data(df, features, bar=3):
    for f in features:
        df[f] = df[f].mask(df[f].abs().gt(bar))
    return df


df_train = filter_data(df_train, demographic_features + labtest_features, bar=3)

# %%
# drop rows whose lab test features are all NaN
df_train = df_train.dropna(subset=labtest_features, how="all")

# %%
# Calculate data statistics after the preprocessing steps (before imputation)

# Step 1: reverse the z-score normalization
df_reverse = df_train  # note: an alias; reverse_normalize_data returns a new DataFrame


# reverse normalize data
def reverse_normalize_data(df, features, statistic_info):
    df_features = df[features]
    df_features = df_features.apply(
        lambda x: x * (statistic_info[x.name]["std"] + 1e-12)
        + statistic_info[x.name]["mean"]
    )
    df = pd.concat(
        [df[["PATIENT_ID", "gender", "RE_DATE", "outcome", "LOS"]], df_features], axis=1
    )
    return df


df_reverse = reverse_normalize_data(
    df_reverse, ["age"] + labtest_features, statistic_info
)  # gender wasn't normalized, so it isn't reversed

statistics = {}

for f in demographic_features + labtest_features:
    statistics[f] = {}


def calculate_quantile_statistic_info(df, features, case):
    """median [Q1, Q3] over all values (gender is reported as % male instead)"""
    for _, e in enumerate(features):
        # print(e, lo, mi, hi)
        if e == "gender":
            unique, count = np.unique(df[e], return_counts=True)
            data_count = dict(zip(unique, count))  # key: 1 male, 0 female
            print(data_count)
            male_percentage = (
                data_count[1.0] * 100 / (data_count[1.0] + data_count[0.0])
            )
            statistics[e][case] = f"{male_percentage:.2f}% Male"
            print(statistics[e][case])
        else:
            lo = round(np.nanpercentile(df[e], 25), 2)
            mi = round(np.nanpercentile(df[e], 50), 2)
            hi = round(np.nanpercentile(df[e], 75), 2)
            statistics[e][case] = f"{mi:.2f} [{lo:.2f}, {hi:.2f}]"


def calculate_missing_rate(df, features, case="missing_rate"):
    for _, e in enumerate(features):
        missing_rate = round(float(df[e].isnull().sum() * 100 / df[e].shape[0]), 2)
        statistics[e][case] = f"{missing_rate:.2f}%"

tmp_groupby_pid = df_reverse.groupby(["PATIENT_ID"], dropna=True, as_index=False).mean()

# demographic statistics are patient-wise
calculate_quantile_statistic_info(tmp_groupby_pid, demographic_features, "overall")
calculate_quantile_statistic_info(
    tmp_groupby_pid[tmp_groupby_pid["outcome"] == 0], demographic_features, "alive"
)
calculate_quantile_statistic_info(
    tmp_groupby_pid[tmp_groupby_pid["outcome"] == 1], demographic_features, "dead"
)

# lab test statistics are record-wise
calculate_quantile_statistic_info(df_reverse, labtest_features, "overall")
calculate_quantile_statistic_info(
    df_reverse[df_reverse["outcome"] == 0], labtest_features, "alive"
)
calculate_quantile_statistic_info(
    df_reverse[df_reverse["outcome"] == 1], labtest_features, "dead"
)

calculate_missing_rate(
    df_reverse, demographic_features + labtest_features, "missing_rate"
)

export_quantile_statistics = {
    "Characteristics": [],
    "Overall": [],
    "Alive": [],
    "Dead": [],
    "Missing Rate": [],
}
for f in demographic_features + labtest_features:
    export_quantile_statistics["Characteristics"].append(f)
    export_quantile_statistics["Overall"].append(statistics[f]["overall"])
    export_quantile_statistics["Alive"].append(statistics[f]["alive"])
    export_quantile_statistics["Dead"].append(statistics[f]["dead"])
    export_quantile_statistics["Missing Rate"].append(statistics[f]["missing_rate"])

# pd.DataFrame.from_dict(export_quantile_statistics).to_csv('statistics.csv')

# %%
def calculate_data_existing_length(data):
    """count the non-NaN entries of a 1-d sequence"""
    res = 0
    for i in data:
        if not pd.isna(i):
            res += 1
    return res


# elements in data are sorted in time ascending order
def fill_missing_value(data, to_fill_value=0):
    """fill NaNs in place: leading NaNs get `to_fill_value`, later NaNs are
    forward-filled from the previous observation"""
    data_len = len(data)
    data_exist_len = calculate_data_existing_length(data)
    if data_len == data_exist_len:
        return data
    elif data_exist_len == 0:
        # data = [to_fill_value for _ in range(data_len)]
        for i in range(data_len):
            data[i] = to_fill_value
        return data
    if pd.isna(data[0]):
        # find the first non-nan value's position
        not_na_pos = 0
        for i in range(data_len):
            if not pd.isna(data[i]):
                not_na_pos = i
                break
        # fill elements before the first non-nan value with the default (median) value
        for i in range(not_na_pos):
            data[i] = to_fill_value
    # forward-fill elements after the first non-nan value
    for i in range(1, data_len):
        if pd.isna(data[i]):
            data[i] = data[i - 1]
    return data
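
# e.g. (illustrative): fill_missing_value(np.array([np.nan, 1.0, np.nan, 2.0]), 0.5)
# returns [0.5, 1.0, 1.0, 2.0] -- the leading NaN gets the default value and
# the inner NaN is forward-filled from the previous observation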
|
|
|
# %%
# fill missing data using our strategy and convert to time series records
grouped = df_train.groupby("PATIENT_ID")

all_x_demographic = []
all_x_labtest = []
all_y = []
all_missing_mask = []

for name, group in grouped:
    sorted_group = group.sort_values(by=["RE_DATE"], ascending=True)
    patient_demographic = []
    patient_labtest = []
    patient_y = []

    for f in demographic_features + labtest_features:
        to_fill_value = (statistic_info[f]["median"] - statistic_info[f]["mean"]) / (
            statistic_info[f]["std"] + 1e-12
        )
        # take the median patient as the default to-fill missing value;
        # since the features are z-score normalized, the raw median maps to
        # (median - mean) / std in the normalized space
        # print(sorted_group[f].values)
        fill_missing_value(sorted_group[f].values, to_fill_value)
        # print(sorted_group[f].values)
        # print('-----------')
    all_missing_mask.append(
        (
            np.isfinite(
                sorted_group[demographic_features + labtest_features].to_numpy()
            )
        ).astype(int)
    )

    for _, v in sorted_group.iterrows():
        patient_y.append([v["outcome"], v["LOS"]])
        demo = []
        lab = []
        for f in demographic_features:
            demo.append(v[f])
        for f in labtest_features:
            lab.append(v[f])
        patient_labtest.append(lab)
        patient_demographic.append(demo)
    all_y.append(patient_y)
    all_x_demographic.append(patient_demographic[-1])
    all_x_labtest.append(patient_labtest)

# all_x_demographic (2 dim, each patient's demographic features)
# all_x_labtest (3 dim, each patient's lab test features)
# all_y (3 dim, each patient's outcome/los over all visits)

# %%
all_x_labtest = np.array(all_x_labtest, dtype=object)
x_lab_length = [len(_) for _ in all_x_labtest]
x_lab_length = torch.tensor(x_lab_length, dtype=torch.int)
max_length = int(x_lab_length.max())
all_x_labtest = [torch.tensor(_) for _ in all_x_labtest]
# pad the lab test sequences to the same length
all_x_labtest = torch.nn.utils.rnn.pad_sequence(all_x_labtest, batch_first=True)

all_x_demographic = torch.tensor(all_x_demographic)
batch_size, demo_dim = all_x_demographic.shape
# repeat the (static) demographic tensor along the time dimension
all_x_demographic = torch.reshape(
    all_x_demographic.repeat(1, max_length), (batch_size, max_length, demo_dim)
)
# concat the demographic tensor with the lab test tensor
all_x = torch.cat((all_x_demographic, all_x_labtest), 2)

all_y = np.array(all_y, dtype=object)
all_y = [torch.Tensor(_) for _ in all_y]
# pad the [outcome, LOS] sequences as well
all_y = torch.nn.utils.rnn.pad_sequence(all_y, batch_first=True)

all_missing_mask = np.array(all_missing_mask, dtype=object)
all_missing_mask = [torch.tensor(_) for _ in all_missing_mask]
all_missing_mask = torch.nn.utils.rnn.pad_sequence(all_missing_mask, batch_first=True)
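
# resulting shapes (for reference):
#   all_x:            (num_patients, max_length, demo_dim + num_labtest_features)
#   all_y:            (num_patients, max_length, 2)  # (outcome, LOS)
#   all_missing_mask: (num_patients, max_length, demo_dim + num_labtest_features)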
|
|
|
# %%
# save the dataset in pickle format (export torch tensors)
pd.to_pickle(all_x, "./datasets/tongji/processed_data/x.pkl")
pd.to_pickle(all_y, "./datasets/tongji/processed_data/y.pkl")
pd.to_pickle(x_lab_length, "./datasets/tongji/processed_data/visits_length.pkl")
pd.to_pickle(all_missing_mask, "./datasets/tongji/processed_data/missing_mask.pkl")
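
# a quick load-back check (illustrative, not part of the original pipeline):
# x = pd.read_pickle("./datasets/tongji/processed_data/x.pkl")
# print(x.shape)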
|
|
|
# %%
# Calculate patients' outcome statistics (patient-wise)
outcome_list = []
y_outcome = all_y[:, :, 0]
indices = torch.arange(len(x_lab_length), dtype=torch.int64)
for i in indices:
    outcome_list.append(y_outcome[i][0].item())
outcome_list = np.array(outcome_list)
print(len(outcome_list))
unique, count = np.unique(outcome_list, return_counts=True)
data_count = dict(zip(unique, count))
print(data_count)

# %%
# Calculate patients' outcome statistics (record-wise)
outcome_records_list = []
y_outcome = all_y[:, :, 0]
indices = torch.arange(len(x_lab_length), dtype=torch.int64)
for i in indices:
    outcome_records_list.extend(y_outcome[i][0 : x_lab_length[i]].tolist())
outcome_records_list = np.array(outcome_records_list)
print(len(outcome_records_list))
unique, count = np.unique(outcome_records_list, return_counts=True)
data_count = dict(zip(unique, count))
print(data_count)

# %%
# Calculate patients' mean los and 95% percentile los
los_list = []
y_los = all_y[:, :, 1]
indices = torch.arange(len(x_lab_length), dtype=torch.int64)
for i in indices:
    # los_list.extend(y_los[i][: x_lab_length[i].long()].tolist())
    los_list.append(y_los[i][0].item())
los_list = np.array(los_list)
print(los_list.mean() * 0.5)
print(np.median(los_list) * 0.5)
print(np.percentile(los_list, 95))

print("median:", np.median(los_list))
print("Q1:", np.percentile(los_list, 25))
print("Q3:", np.percentile(los_list, 75))

# %%
los_alive_list = np.array(
    [los_list[i] for i in range(len(los_list)) if outcome_list[i] == 0]
)
los_dead_list = np.array(
    [los_list[i] for i in range(len(los_list)) if outcome_list[i] == 1]
)
print(len(los_alive_list))
print(len(los_dead_list))

print("[Alive]")
print("median:", np.median(los_alive_list))
print("Q1:", np.percentile(los_alive_list, 25))
print("Q3:", np.percentile(los_alive_list, 75))

print("[Dead]")
print("median:", np.median(los_dead_list))
print("Q1:", np.percentile(los_dead_list, 25))
print("Q3:", np.percentile(los_dead_list, 75))

# %%
tjh_los_statistics = {
    "overall": los_list,
    "alive": los_alive_list,
    "dead": los_dead_list,
}
# pd.to_pickle(tjh_los_statistics, 'tjh_los_statistics.pkl')

# %%
# calculate visits length Median [Q1, Q3]
visits_list = np.array(x_lab_length)
visits_alive_list = np.array(
    [x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 0]
)
visits_dead_list = np.array(
    [x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 1]
)
print(len(visits_alive_list))
print(len(visits_dead_list))

print("[Total]")
print("median:", np.median(visits_list))
print("Q1:", np.percentile(visits_list, 25))
print("Q3:", np.percentile(visits_list, 75))

print("[Alive]")
print("median:", np.median(visits_alive_list))
print("Q1:", np.percentile(visits_alive_list, 25))
print("Q3:", np.percentile(visits_alive_list, 75))

print("[Dead]")
print("median:", np.median(visits_dead_list))
print("Q1:", np.percentile(visits_dead_list, 25))
print("Q3:", np.percentile(visits_dead_list, 75))

# %%
# Length-of-stay interval (overall/alive/dead)
los_interval_list = []
los_interval_alive_list = []
los_interval_dead_list = []

y_los = all_y[:, :, 1]
indices = torch.arange(len(x_lab_length), dtype=torch.int64)
for i in indices:
    cur_visits_len = x_lab_length[i]
    if cur_visits_len == 1:
        continue
    for j in range(1, cur_visits_len):
        los_interval_list.append(y_los[i][j - 1] - y_los[i][j])
        if outcome_list[i] == 0:
            los_interval_alive_list.append(y_los[i][j - 1] - y_los[i][j])
        else:
            los_interval_dead_list.append(y_los[i][j - 1] - y_los[i][j])

los_interval_list = np.array(los_interval_list)
los_interval_alive_list = np.array(los_interval_alive_list)
los_interval_dead_list = np.array(los_interval_dead_list)

output = {
    "overall": los_interval_list,
    "alive": los_interval_alive_list,
    "dead": los_interval_dead_list,
}
# pd.to_pickle(output, 'tjh_los_interval_list.pkl')

# %%
len(los_interval_list), len(los_interval_alive_list), len(los_interval_dead_list)


# %%
def check_nan(x):
    """report whether the given tensor contains any NaN values"""
    if np.isnan(np.sum(x.cpu().numpy())):
        print("some values from input are nan")
    else:
        print("no nan")