Diff of /preprocess/behrtFormat.py [000000] .. [bad60c]

Switch to unified view

a b/preprocess/behrtFormat.py
1
from common.spark import spark_init, read_parquet, read_txt
2
from CPRD.tabel import EHR
3
import pyspark.sql.functions as F
4
from pyspark.sql import Window
5
6
# Handle returned by the project's spark_init helper; its `sqlContext`
# attribute is what the read_parquet calls in this script receive.
spark = spark_init()
7
8
# Script settings. Every value is an empty placeholder that must be filled in
# before the script is run.
config = {
    'diagnoses': '',  # data path for diagnoses/medication
    'demographic': '',  # data path for demographic information
    'output': '',  # path to save formatted file
    'col_name': '',  # column name for ICD/Med code
}
14
15
# Load the event table, keep only the three columns this script works with,
# discard rows that have a null in any of them, and remove exact duplicates.
# select / na.drop are idempotent, so one pass of each is sufficient.
event_columns = ['patid', 'eventdate', config['col_name']]
diagnoses = (
    read_parquet(spark.sqlContext, config['diagnoses'])
    .select(event_columns)
    .na.drop()
    .dropDuplicates()
)

# Demographic table, loaded unfiltered; it is narrowed to the columns needed
# right before it is joined.
demographic = read_parquet(spark.sqlContext, config['demographic'])
20
21
# Demographic data: only the patient id and year of birth are needed.
demographic = demographic.select(['patid', 'yob'])

# Inner join on patient id, then drop the demographic copy of `patid` so the
# result carries a single `patid` column.
same_patient = diagnoses.patid == demographic.patid
diagnoses = diagnoses.join(demographic, same_patient, 'inner')
diagnoses = diagnoses.drop(demographic.patid)

# Derive an `age` column from `eventdate` and `yob` via the project's EHR
# helper.
# NOTE(review): year=False presumably selects a finer unit than whole years
# -- confirm against EHR.cal_age.
with_age = EHR(diagnoses).cal_age('eventdate', 'yob', year=False)
kept_columns = ['patid', 'eventdate', 'age', config['col_name'], 'yob']
diagnoses = with_age.select(kept_columns)
diagnoses = diagnoses.dropDuplicates()
26
27
# Convert the age column and the code column to strings with the EHR helper.
diagnoses = (
    EHR(diagnoses)
    .set_col_to_str('age')
    .set_col_to_str(config['col_name'])
)

# One row per (patient, date): gather that date's codes and ages into lists
# and carry one `yob` value along.
per_date_aggregates = [
    F.collect_list(config['col_name']).alias(config['col_name']),
    F.collect_list('age').alias('age'),
    F.first('yob').alias('yob'),
]
diagnoses = diagnoses.groupby(['patid', 'eventdate']).agg(*per_date_aggregates)

# Add the 'SEP' marker to each date's code list via the EHR helper.
diagnoses = EHR(diagnoses).array_add_element(config['col_name'], 'SEP')
33
34
def extract_age(column):
    """Return a column expression for the first element of an array column.

    Built from a native Spark column expression rather than a Python UDF, so
    no Python lambda has to be evaluated per row.

    Args:
        column: Name of an array column, or a ``Column`` holding an array.

    Returns:
        A ``Column`` that evaluates to element 0 of each row's array.
    """
    array_col = F.col(column) if isinstance(column, str) else column
    return array_col.getItem(0)


# add extra age to fill the gap of sep: the code list was given an extra 'SEP'
# element, so the age list gets one extra entry (a copy of its first element),
# presumably to keep the two lists the same length.
diagnoses = diagnoses.withColumn(
    'age',
    F.concat(F.col('age'), F.array(extract_age('age'))),
)
37
38
# Per-patient window, ordered by event date.
w = Window.partitionBy('patid').orderBy('eventdate')

# sort and merge ccs and age
# Each row first receives the list collected over the ordered window; the
# rows are then reduced to one per patient by taking the max of those lists.
# NOTE(review): this presumably keeps the complete, date-ordered list because
# every running list is a prefix of the next one -- confirm against Spark's
# array ordering and default window frame.
running_codes = F.collect_list(config['col_name']).over(w)
running_ages = F.collect_list('age').over(w)
diagnoses = (
    diagnoses
    .withColumn(config['col_name'], running_codes)
    .withColumn('age', running_ages)
    .groupBy('patid')
    .agg(
        F.max(config['col_name']).alias(config['col_name']),
        F.max('age').alias('age'),
    )
)
41
42
# Flatten the per-patient code and age columns with the EHR helper, then save
# the result as parquet at the configured output path.
diagnoses = (
    EHR(diagnoses)
    .array_flatten(config['col_name'])
    .array_flatten('age')
)
output_path = config['output']
diagnoses.write.parquet(output_path)