a b/helpers/traintestsplit.py
1
import os
2
import pandas as pd
3
from imblearn.under_sampling import RandomUnderSampler
4
from optparse import OptionParser
5
from sklearn.model_selection import train_test_split
6
7
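# Makes a stratified train/validation/test split. The same ratio is applied twice (test first, then validation from the remaining train), so --ratio 0.2 yields 64/16/20 before the train set is undersampled.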
8
def resample_data(t):
    """Balance the ``readm_30d`` classes by random undersampling.

    Args:
        t: DataFrame containing at least the columns ``HADM_ID``, ``text``
            and ``readm_30d``. The frame passed in is not modified.

    Returns:
        A new DataFrame with a fresh 0..n-1 index and the columns
        ``readm_30d``, ``HADM_ID`` and ``text`` (in that order) holding the
        rows kept by the sampler.
    """
    feature_cols = ['HADM_ID', 'text']
    features = t[feature_cols]
    label = t['readm_30d']

    rus = RandomUnderSampler(random_state=42)
    X, y = rus.fit_resample(features, label.astype('category'))

    # Depending on the installed imbalanced-learn version, fit_resample hands
    # back either NumPy arrays or pandas objects mirroring the inputs.
    # Positional slicing such as X[:, 0] only works on arrays, so normalise
    # both results to pandas objects sharing a fresh positional index before
    # assembling the output frame.
    X = pd.DataFrame(X, columns=feature_cols).reset_index(drop=True)
    y = pd.Series(y).reset_index(drop=True)

    return pd.DataFrame({
        'readm_30d': y,
        'HADM_ID': X['HADM_ID'],
        'text': X['text'],
    })
23
24
def split_data(admissions, ratio):
    """Split admissions into two stratified partitions.

    Args:
        admissions: DataFrame with the columns ``HADM_ID``, ``text`` and
            ``readm_30d``.
        ratio: Value forwarded to ``train_test_split`` as ``test_size``.

    Returns:
        A ``(train, test)`` tuple of DataFrames, each carrying the feature
        columns together with the ``readm_30d`` label.
    """
    features = admissions[['HADM_ID', 'text']]
    target = admissions['readm_30d']

    # Stratify on the label so both partitions keep its class distribution.
    feat_train, feat_test, target_train, target_test = train_test_split(
        features,
        target,
        stratify=target,
        test_size=ratio,
        random_state=42,
    )

    # Re-attach each label series to its features by aligning on the index.
    partitions = tuple(
        pd.merge(feats, labels, left_index=True, right_index=True)
        for feats, labels in ((feat_train, target_train), (feat_test, target_test))
    )
    return partitions
37
38
39
def main(input_data, output_dir, ratio):
    """Read the raw CSV, split it three ways and write the partitions out.

    The data is first split into train/test, then the train part is split
    again (with the same ratio) into train/validation. Only the final train
    partition is undersampled; validation and test are written as split.

    Args:
        input_data: Path of the CSV file to read.
        output_dir: Directory receiving ``train.csv``, ``test.csv`` and
            ``validation.csv``; created if it does not exist.
        ratio: Value passed to ``split_data`` for both splits.
    """
    # read the dataset from file.
    print("Reading raw data")
    data = pd.read_csv(input_data)

    # split into training and testing
    print("Splitting into training and testing")
    train, test = split_data(data, ratio)

    # split into train and validation
    print("spliting train into train and validation")
    train, validation = split_data(train, ratio)

    # undersample the train
    print("Undersampling the train")
    train = resample_data(train)

    # now save the files; exist_ok avoids the check-then-create race of a
    # separate os.path.exists() test.
    os.makedirs(output_dir, exist_ok=True)

    # index=False states explicitly that the row index is not written.
    train.to_csv(os.path.join(output_dir, "train.csv"), index=False)
    test.to_csv(os.path.join(output_dir, "test.csv"), index=False)
    validation.to_csv(os.path.join(output_dir, "validation.csv"), index=False)
63
64
if __name__ == "__main__":

    # Command-line entry point: parse the options and run the split.
    parser = OptionParser()

    parser.add_option("--input", help="specify the input data")

    parser.add_option("--output_dir", help="specify the output location")

    parser.add_option("--ratio", help="specify the proportion to keep for testing", type="float")

    (options, args) = parser.parse_args()

    # Reject unusable arguments up front with a usage message instead of
    # letting them surface later as a traceback from the libraries called
    # in main().
    if not options.input:
        parser.error("--input is required")
    if not options.output_dir:
        parser.error("--output_dir is required")
    # --ratio stays optional: when omitted, None is passed through unchanged.
    if options.ratio is not None and not 0.0 < options.ratio < 1.0:
        parser.error("--ratio must be strictly between 0 and 1")

    # load the data
    main(options.input, options.output_dir, options.ratio)