# datasets/tjh/preprocess.py
# %%
# Import necessary packages
import numpy as np
import pandas as pd
import torch

# %%
# Read raw data
df_train: pd.DataFrame = pd.read_excel(
    "./datasets/tongji/raw_data/time_series_375_prerpocess_en.xlsx"
)

# %% [markdown]
# Steps:
#
# - fill `PATIENT_ID`
# - only keep y-m-d for the `RE_DATE` column
# - merge lab tests of the same (patient_id, date)
# - calculate and save feature statistics (demographic and lab test features are computed separately)
# - normalize data
# - feature selection
# - fill missing data (the filling strategy is described below)
# - combine the above into time-series data (one record per patient)
# - export to a python pickle file

# %%
# forward-fill the `PATIENT_ID` column
df_train["PATIENT_ID"] = df_train["PATIENT_ID"].ffill()

# gender transformation: 1--male, 0--female
df_train["gender"] = df_train["gender"].replace(2, 0)

# only keep y-m-d for the `RE_DATE` and `Discharge time` columns
df_train["RE_DATE"] = df_train["RE_DATE"].dt.strftime("%Y-%m-%d")
df_train["Discharge time"] = df_train["Discharge time"].dt.strftime("%Y-%m-%d")


# %%
# drop records missing any key column
df_train = df_train.dropna(
    subset=["PATIENT_ID", "RE_DATE", "Discharge time"], how="any"
)

# %%
# calculate the raw data's los intervals (days between consecutive records)
df_grouped = df_train.groupby("PATIENT_ID")

los_interval_list = []
los_interval_alive_list = []
los_interval_dead_list = []

for name, group in df_grouped:
    sorted_group = group.sort_values(by=["RE_DATE"], ascending=True)
    intervals = sorted_group["RE_DATE"].tolist()
    outcome = sorted_group["outcome"].tolist()[0]
    cur_visits_len = len(intervals)
    if cur_visits_len == 1:
        continue
    for i in range(1, len(intervals)):
        interval_days = (
            pd.to_datetime(intervals[i]) - pd.to_datetime(intervals[i - 1])
        ).days
        los_interval_list.append(interval_days)
        if outcome == 0:
            los_interval_alive_list.append(interval_days)
        else:
            los_interval_dead_list.append(interval_days)

los_interval_list = np.array(los_interval_list)
los_interval_alive_list = np.array(los_interval_alive_list)
los_interval_dead_list = np.array(los_interval_dead_list)

output = {
    "overall": los_interval_list,
    "alive": los_interval_alive_list,
    "dead": los_interval_dead_list,
}
# pd.to_pickle(output, 'raw_tjh_los_interval_list.pkl')


# %%
# we have 2 prediction tasks: 1) predict the mortality outcome, 2) predict the length of stay

# below are all the lab test features
labtest_features_str = """
Hypersensitive cardiac troponinI    hemoglobin  Serum chloride  Prothrombin time    procalcitonin   eosinophils(%)  Interleukin 2 receptor  Alkaline phosphatase    albumin basophil(%) Interleukin 10  Total bilirubin Platelet count  monocytes(%)    antithrombin    Interleukin 8   indirect bilirubin  Red blood cell distribution width   neutrophils(%)  total protein   Quantification of Treponema pallidum antibodies Prothrombin activity    HBsAg   mean corpuscular volume hematocrit  White blood cell count  Tumor necrosis factorα  mean corpuscular hemoglobin concentration   fibrinogen  Interleukin 1β  Urea    lymphocyte count    PH value    Red blood cell count    Eosinophil count    Corrected calcium   Serum potassium glucose neutrophils count   Direct bilirubin    Mean platelet volume    ferritin    RBC distribution width SD   Thrombin time   (%)lymphocyte   HCV antibody quantification D-D dimer   Total cholesterol   aspartate aminotransferase  Uric acid   HCO3-   calcium Amino-terminal brain natriuretic peptide precursor(NT-proBNP)   Lactate dehydrogenase   platelet large cell ratio   Interleukin 6   Fibrin degradation products monocytes count PLT distribution width  globulin    γ-glutamyl transpeptidase   International standard ratio    basophil count(#)   2019-nCoV nucleic acid detection    mean corpuscular hemoglobin     Activation of partial thromboplastin time   Hypersensitive c-reactive protein   HIV antibody quantification serum sodium    thrombocytocrit ESR glutamic-pyruvic transaminase   eGFR    creatinine
"""

# below are the 2 demographic features
demographic_features_str = """
age gender
"""

labtest_features = [f for f in labtest_features_str.strip().split("\t")]
demographic_features = [f for f in demographic_features_str.strip().split("\t")]
target_features = ["outcome", "LOS"]

# from our observation, the `2019-nCoV nucleic acid detection` feature (a lab test) is always -1,
# so we remove it here
labtest_features.remove("2019-nCoV nucleic acid detection")

# %%
# if any feature values are negative, set them to NaN
df_train[df_train[demographic_features + labtest_features] < 0] = np.nan

# %%
# merge lab tests of the same (patient_id, date)
df_train = df_train.groupby(
    ["PATIENT_ID", "RE_DATE", "Discharge time"], dropna=True, as_index=False
).mean()

# %%
# calculate the length-of-stay label
df_train["LOS"] = (
    pd.to_datetime(df_train["Discharge time"]) - pd.to_datetime(df_train["RE_DATE"])
).dt.days

# %%
# if any LOS values are negative, clip them to 0
df_train["LOS"] = df_train["LOS"].clip(lower=0)

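# %%
# A tiny sanity example (added illustration, not part of the original pipeline):
# the LOS label counts whole days between a record date and the discharge date.
assert (pd.to_datetime("2020-02-10") - pd.to_datetime("2020-02-01")).days == 9
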
# %%
# save features' statistics information


def calculate_statistic_info(df, features):
    """statistics over all values"""
    statistic_info = {}
    len_df = len(df)
    for _, e in enumerate(features):
        h = {}
        h["count"] = int(df[e].count())
        h["missing"] = str(round(float((100 - df[e].count() * 100 / len_df)), 3)) + "%"
        h["mean"] = float(df[e].mean())
        h["max"] = float(df[e].max())
        h["min"] = float(df[e].min())
        h["median"] = float(df[e].median())
        h["std"] = float(df[e].std())
        statistic_info[e] = h
    return statistic_info


def calculate_middle_part_statistic_info(df, features):
    """statistics over the 5% ~ 95% percentile range"""
    statistic_info = {}
    len_df = len(df)
    # calculate the 5% and 95% percentiles of the dataframe
    middle_part_df_info = df.quantile([0.05, 0.95])

    for _, e in enumerate(features):
        low_value = middle_part_df_info[e][0.05]
        high_value = middle_part_df_info[e][0.95]
        middle_part_df_element = df.loc[(df[e] >= low_value) & (df[e] <= high_value)][e]
        h = {}
        h["count"] = int(middle_part_df_element.count())
        h["missing"] = (
            str(round(float((100 - middle_part_df_element.count() * 100 / len_df)), 3))
            + "%"
        )
        h["mean"] = float(middle_part_df_element.mean())
        h["max"] = float(middle_part_df_element.max())
        h["min"] = float(middle_part_df_element.min())
        h["median"] = float(middle_part_df_element.median())
        h["std"] = float(middle_part_df_element.std())
        statistic_info[e] = h
    return statistic_info


# labtest_statistic_info = calculate_statistic_info(df_train, labtest_features)


# group by patient_id, then calculate the lab test / demographic features' statistics
groupby_patientid_df = df_train.groupby(
    ["PATIENT_ID"], dropna=True, as_index=False
).mean()


# calculate statistics (all values included)
labtest_patientwise_statistic_info = calculate_statistic_info(
    groupby_patientid_df, labtest_features
)
demographic_statistic_info = calculate_statistic_info(
    groupby_patientid_df, demographic_features
)  # it's also patient-wise

# calculate statistics (5% ~ 95% percentile range only)
demographic_statistic_info_2 = calculate_middle_part_statistic_info(
    groupby_patientid_df, demographic_features
)
labtest_patientwise_statistic_info_2 = calculate_middle_part_statistic_info(
    groupby_patientid_df, labtest_features
)

# merge the two statistics dicts (dict union, requires Python 3.9+)
statistic_info = labtest_patientwise_statistic_info_2 | demographic_statistic_info_2

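# For Python < 3.9, an equivalent merge of the same two dicts (added note) would be:
# statistic_info = {**labtest_patientwise_statistic_info_2, **demographic_statistic_info_2}
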
# %%
# inspect features and export to a csv file [optional]
to_export_dict = {
    "name": [],
    "missing_rate": [],
    "count": [],
    "mean": [],
    "max": [],
    "min": [],
    "median": [],
    "std": [],
}
for key in statistic_info:
    detail = statistic_info[key]
    to_export_dict["name"].append(key)
    to_export_dict["count"].append(detail["count"])
    to_export_dict["missing_rate"].append(detail["missing"])
    to_export_dict["mean"].append(detail["mean"])
    to_export_dict["max"].append(detail["max"])
    to_export_dict["min"].append(detail["min"])
    to_export_dict["median"].append(detail["median"])
    to_export_dict["std"].append(detail["std"])
to_export_df = pd.DataFrame.from_dict(to_export_dict)
# to_export_df.to_csv('statistic_info.csv')

# %%
# normalize data (z-score per feature)
def normalize_data(df, features, statistic_info):
    df_features = df[features]
    df_features = df_features.apply(
        lambda x: (x - statistic_info[x.name]["mean"])
        / (statistic_info[x.name]["std"] + 1e-12)
    )
    df = pd.concat(
        [df[["PATIENT_ID", "gender", "RE_DATE", "outcome", "LOS"]], df_features], axis=1
    )
    return df


df_train = normalize_data(
    df_train, ["age"] + labtest_features, statistic_info
)  # gender doesn't need to be normalized

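# %%
# quick z-score illustration (an added sketch, not part of the original script;
# `_demo_stats` and `_demo_series` are hypothetical names): a value equal to the
# feature mean maps to ~0, and one std above the mean maps to ~1.
_demo_stats = {"demo_feature": {"mean": 10.0, "std": 2.0}}
_demo_series = pd.Series([10.0, 12.0], name="demo_feature")
print(
    (_demo_series - _demo_stats[_demo_series.name]["mean"])
    / (_demo_stats[_demo_series.name]["std"] + 1e-12)
)  # -> 0.0 and ~1.0
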
# %%
# filter outliers: mask normalized values whose absolute value exceeds `bar`
# (the data is z-scored, so bar=3 is roughly 3 standard deviations)
def filter_data(df, features, bar=3):
    for f in features:
        df[f] = df[f].mask(df[f].abs().gt(bar))
    return df


df_train = filter_data(df_train, demographic_features + labtest_features, bar=3)

# %%
# drop rows where all labtest features are NaN
df_train = df_train.dropna(subset=labtest_features, how="all")

# %%
# Calculate data statistics after the preprocessing steps (before imputation)

# Step 1: reverse the z-score normalization
df_reverse = df_train
# reverse normalize data
def reverse_normalize_data(df, features, statistic_info):
    df_features = df[features]
    df_features = df_features.apply(
        lambda x: x * (statistic_info[x.name]["std"] + 1e-12)
        + statistic_info[x.name]["mean"]
    )
    df = pd.concat(
        [df[["PATIENT_ID", "gender", "RE_DATE", "outcome", "LOS"]], df_features], axis=1
    )
    return df


df_reverse = reverse_normalize_data(
    df_reverse, ["age"] + labtest_features, statistic_info
)  # gender wasn't normalized, so it doesn't need reversing

statistics = {}

for f in demographic_features + labtest_features:
    statistics[f] = {}


def calculate_quantile_statistic_info(df, features, case):
    """median [Q1, Q3] per feature; gender is reported as % male"""
    for _, e in enumerate(features):
        if e == "gender":
            unique, count = np.unique(df[e], return_counts=True)
            data_count = dict(zip(unique, count))  # key: 1 male, 0 female
            print(data_count)
            male_percentage = (
                data_count[1.0] * 100 / (data_count[1.0] + data_count[0.0])
            )
            statistics[e][case] = f"{male_percentage:.2f}% Male"
            print(statistics[e][case])
        else:
            lo = round(np.nanpercentile(df[e], 25), 2)
            mi = round(np.nanpercentile(df[e], 50), 2)
            hi = round(np.nanpercentile(df[e], 75), 2)
            statistics[e][case] = f"{mi:.2f} [{lo:.2f}, {hi:.2f}]"


def calculate_missing_rate(df, features, case="missing_rate"):
    for _, e in enumerate(features):
        missing_rate = round(float(df[e].isnull().sum() * 100 / df[e].shape[0]), 2)
        statistics[e][case] = f"{missing_rate:.2f}%"


tmp_groupby_pid = df_reverse.groupby(["PATIENT_ID"], dropna=True, as_index=False).mean()

calculate_quantile_statistic_info(tmp_groupby_pid, demographic_features, "overall")
calculate_quantile_statistic_info(
    tmp_groupby_pid[tmp_groupby_pid["outcome"] == 0], demographic_features, "alive"
)
calculate_quantile_statistic_info(
    tmp_groupby_pid[tmp_groupby_pid["outcome"] == 1], demographic_features, "dead"
)

calculate_quantile_statistic_info(df_reverse, labtest_features, "overall")
calculate_quantile_statistic_info(
    df_reverse[df_reverse["outcome"] == 0], labtest_features, "alive"
)
calculate_quantile_statistic_info(
    df_reverse[df_reverse["outcome"] == 1], labtest_features, "dead"
)

calculate_missing_rate(
    df_reverse, demographic_features + labtest_features, "missing_rate"
)

export_quantile_statistics = {
    "Characteristics": [],
    "Overall": [],
    "Alive": [],
    "Dead": [],
    "Missing Rate": [],
}
for f in demographic_features + labtest_features:
    export_quantile_statistics["Characteristics"].append(f)
    export_quantile_statistics["Overall"].append(statistics[f]["overall"])
    export_quantile_statistics["Alive"].append(statistics[f]["alive"])
    export_quantile_statistics["Dead"].append(statistics[f]["dead"])
    export_quantile_statistics["Missing Rate"].append(statistics[f]["missing_rate"])

# pd.DataFrame.from_dict(export_quantile_statistics).to_csv('statistics.csv')

# %%
def calculate_data_existing_length(data):
    """count the non-NaN entries in `data`"""
    res = 0
    for i in data:
        if not pd.isna(i):
            res += 1
    return res


# elements in `data` are sorted in ascending time order
def fill_missing_value(data, to_fill_value=0):
    data_len = len(data)
    data_exist_len = calculate_data_existing_length(data)
    if data_len == data_exist_len:
        return data
    elif data_exist_len == 0:
        # no observation at all: fill everything with the default value
        for i in range(data_len):
            data[i] = to_fill_value
        return data
    if pd.isna(data[0]):
        # find the first non-nan value's position
        not_na_pos = 0
        for i in range(data_len):
            if not pd.isna(data[i]):
                not_na_pos = i
                break
        # fill elements before the first non-nan value with the default value (the normalized median)
        for i in range(not_na_pos):
            data[i] = to_fill_value
    # forward-fill elements after the first non-nan value
    for i in range(1, data_len):
        if pd.isna(data[i]):
            data[i] = data[i - 1]
    return data


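# %%
# Quick sanity check of the filling strategy (a toy example, not part of the
# original pipeline; `_demo` is a hypothetical input): leading NaNs get the
# default value, and later NaNs are forward-filled from the previous observation.
_demo = np.array([np.nan, np.nan, 1.0, np.nan, 2.0])
fill_missing_value(_demo, to_fill_value=0.0)
print(_demo)  # expected: [0. 0. 1. 1. 2.]
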
# %%
# fill missing data using our strategy and convert to time series records
grouped = df_train.groupby("PATIENT_ID")

all_x_demographic = []
all_x_labtest = []
all_y = []
all_missing_mask = []

for name, group in grouped:
    sorted_group = group.sort_values(by=["RE_DATE"], ascending=True)
    patient_demographic = []
    patient_labtest = []
    patient_y = []

    for f in demographic_features + labtest_features:
        # use the median patient value (re-expressed in normalized units) as the default fill value
        to_fill_value = (statistic_info[f]["median"] - statistic_info[f]["mean"]) / (
            statistic_info[f]["std"] + 1e-12
        )
        # fills the column in place through the underlying numpy array (the return value is unused)
        fill_missing_value(sorted_group[f].values, to_fill_value)
    all_missing_mask.append(
        (
            np.isfinite(
                sorted_group[demographic_features + labtest_features].to_numpy()
            )
        ).astype(int)
    )

    for _, v in sorted_group.iterrows():
        patient_y.append([v["outcome"], v["LOS"]])
        demo = []
        lab = []
        for f in demographic_features:
            demo.append(v[f])
        for f in labtest_features:
            lab.append(v[f])
        patient_labtest.append(lab)
        patient_demographic.append(demo)
    all_y.append(patient_y)
    all_x_demographic.append(patient_demographic[-1])
    all_x_labtest.append(patient_labtest)

# all_x_demographic (2 dims: each patient's demographic features)
# all_x_labtest (3 dims: each patient's lab test features per visit)
# all_y (3 dims: each patient's outcome/los across all visits)

# %%
all_x_labtest = np.array(all_x_labtest, dtype=object)
x_lab_length = [len(_) for _ in all_x_labtest]
x_lab_length = torch.tensor(x_lab_length, dtype=torch.int)
max_length = int(x_lab_length.max())
all_x_labtest = [torch.tensor(_) for _ in all_x_labtest]
# pad lab test sequences to the same length
all_x_labtest = torch.nn.utils.rnn.pad_sequence(all_x_labtest, batch_first=True)

all_x_demographic = torch.tensor(all_x_demographic)
batch_size, demo_dim = all_x_demographic.shape
# repeat the demographic tensor along the time dimension
all_x_demographic = torch.reshape(
    all_x_demographic.repeat(1, max_length), (batch_size, max_length, demo_dim)
)
# concatenate the demographic tensor with the lab test tensor
all_x = torch.cat((all_x_demographic, all_x_labtest), 2)

all_y = np.array(all_y, dtype=object)
all_y = [torch.Tensor(_) for _ in all_y]
# pad the [outcome, los] sequences as well
all_y = torch.nn.utils.rnn.pad_sequence(all_y, batch_first=True)

all_missing_mask = np.array(all_missing_mask, dtype=object)
all_missing_mask = [torch.tensor(_) for _ in all_missing_mask]
all_missing_mask = torch.nn.utils.rnn.pad_sequence(all_missing_mask, batch_first=True)

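# %%
# Optional shape sanity check (an added sketch, not part of the original script):
# after padding, every tensor should share the same (batch, time) leading dims.
assert all_x.shape == (batch_size, max_length, demo_dim + all_x_labtest.shape[2])
assert all_y.shape[:2] == (batch_size, max_length)
assert all_missing_mask.shape[:2] == (batch_size, max_length)
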
# %%
# save the dataset in pickle format (export torch tensors)
pd.to_pickle(all_x, "./datasets/tongji/processed_data/x.pkl")
pd.to_pickle(all_y, "./datasets/tongji/processed_data/y.pkl")
pd.to_pickle(x_lab_length, "./datasets/tongji/processed_data/visits_length.pkl")
pd.to_pickle(all_missing_mask, "./datasets/tongji/processed_data/missing_mask.pkl")

# %%
# Calculate patients' outcome statistics (patient-wise)
outcome_list = []
y_outcome = all_y[:, :, 0]
indices = torch.arange(len(x_lab_length), dtype=torch.int64)
for i in indices:
    outcome_list.append(y_outcome[i][0].item())
outcome_list = np.array(outcome_list)
print(len(outcome_list))
unique, count = np.unique(outcome_list, return_counts=True)
data_count = dict(zip(unique, count))
print(data_count)

# %%
# Calculate patients' outcome statistics (record-wise)
outcome_records_list = []
y_outcome = all_y[:, :, 0]
indices = torch.arange(len(x_lab_length), dtype=torch.int64)
for i in indices:
    outcome_records_list.extend(y_outcome[i][0 : x_lab_length[i]].tolist())
outcome_records_list = np.array(outcome_records_list)
print(len(outcome_records_list))
unique, count = np.unique(outcome_records_list, return_counts=True)
data_count = dict(zip(unique, count))
print(data_count)

# %%
# Calculate patients' mean LOS and the 95th-percentile LOS
los_list = []
y_los = all_y[:, :, 1]
indices = torch.arange(len(x_lab_length), dtype=torch.int64)
for i in indices:
    # los_list.extend(y_los[i][: x_lab_length[i].long()].tolist())
    los_list.append(y_los[i][0].item())
los_list = np.array(los_list)
print(los_list.mean() * 0.5)
print(np.median(los_list) * 0.5)
print(np.percentile(los_list, 95))

print("median:", np.median(los_list))
print("Q1:", np.percentile(los_list, 25))
print("Q3:", np.percentile(los_list, 75))

# %%
los_alive_list = np.array(
    [los_list[i] for i in range(len(los_list)) if outcome_list[i] == 0]
)
los_dead_list = np.array(
    [los_list[i] for i in range(len(los_list)) if outcome_list[i] == 1]
)
print(len(los_alive_list))
print(len(los_dead_list))

print("[Alive]")
print("median:", np.median(los_alive_list))
print("Q1:", np.percentile(los_alive_list, 25))
print("Q3:", np.percentile(los_alive_list, 75))

print("[Dead]")
print("median:", np.median(los_dead_list))
print("Q1:", np.percentile(los_dead_list, 25))
print("Q3:", np.percentile(los_dead_list, 75))

# %%
tjh_los_statistics = {
    "overall": los_list,
    "alive": los_alive_list,
    "dead": los_dead_list,
}
# pd.to_pickle(tjh_los_statistics, 'tjh_los_statistics.pkl')

# %%
# calculate the visit-count Median [Q1, Q3]
visits_list = np.array(x_lab_length)
visits_alive_list = np.array(
    [x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 0]
)
visits_dead_list = np.array(
    [x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 1]
)
print(len(visits_alive_list))
print(len(visits_dead_list))

print("[Total]")
print("median:", np.median(visits_list))
print("Q1:", np.percentile(visits_list, 25))
print("Q3:", np.percentile(visits_list, 75))

print("[Alive]")
print("median:", np.median(visits_alive_list))
print("Q1:", np.percentile(visits_alive_list, 25))
print("Q3:", np.percentile(visits_alive_list, 75))

print("[Dead]")
print("median:", np.median(visits_dead_list))
print("Q1:", np.percentile(visits_dead_list, 25))
print("Q3:", np.percentile(visits_dead_list, 75))

# %%
# Length-of-stay interval (overall/alive/dead)
los_interval_list = []
los_interval_alive_list = []
los_interval_dead_list = []

y_los = all_y[:, :, 1]
indices = torch.arange(len(x_lab_length), dtype=torch.int64)
for i in indices:
    cur_visits_len = x_lab_length[i]
    if cur_visits_len == 1:
        continue
    for j in range(1, cur_visits_len):
        los_interval_list.append(y_los[i][j - 1] - y_los[i][j])
        if outcome_list[i] == 0:
            los_interval_alive_list.append(y_los[i][j - 1] - y_los[i][j])
        else:
            los_interval_dead_list.append(y_los[i][j - 1] - y_los[i][j])

los_interval_list = np.array(los_interval_list)
los_interval_alive_list = np.array(los_interval_alive_list)
los_interval_dead_list = np.array(los_interval_dead_list)

output = {
    "overall": los_interval_list,
    "alive": los_interval_alive_list,
    "dead": los_interval_dead_list,
}
# pd.to_pickle(output, 'tjh_los_interval_list.pkl')

# %%
len(los_interval_list), len(los_interval_alive_list), len(los_interval_dead_list)

# %%
# helper: check whether a tensor contains NaN values
def check_nan(x):
    if np.isnan(np.sum(x.cpu().numpy())):
        print("some values from input are nan")
    else:
        print("no nan")