"""Build per-patient code and age sequences from event-level records.

Source path in the originating diff: preprocess/behrtFormat.py

Steps performed by this script:

1. Read the event table (``patid``, ``eventdate`` and the configured code
   column) and the demographic table (``patid``, ``yob``) from parquet.
2. Drop incomplete and duplicated events, inner-join both tables on ``patid``
   and derive an ``age`` column through ``EHR.cal_age``.
3. Collect codes and ages per (patient, date), add a ``'SEP'`` element to each
   per-date code list and one extra entry to the matching age list.
4. Gather the per-date lists of every patient in ``eventdate`` order, flatten
   them and write the result to parquet.
"""
from common.spark import spark_init, read_parquet, read_txt
from CPRD.tabel import EHR
import pyspark.sql.functions as F
from pyspark.sql import Window

spark = spark_init()

config = {
    'diagnoses': '',  # data path for diagnoses/medication
    'demographic': '',  # data path for demographic information
    'output': '',  # path to save formatted file
    'col_name': '',  # column name for ICD/Med code
}

# Name of the code column; looked up once because nearly every step needs it.
code_col = config['col_name']

# Event records: keep only the needed columns, then drop rows with a null in
# any of them and exact duplicates.
diagnoses = (
    read_parquet(spark.sqlContext, config['diagnoses'])
    .select(['patid', 'eventdate', code_col])
    .na.drop()
    .dropDuplicates()
)

# demographic data
demographic = read_parquet(spark.sqlContext, config['demographic']).select(['patid', 'yob'])

# Attach year of birth to every event; the duplicated join key is dropped.
diagnoses = diagnoses.join(
    demographic, diagnoses.patid == demographic.patid, 'inner'
).drop(demographic.patid)

# Derive 'age' from 'eventdate' and 'yob'.
# NOTE(review): year=False presumably selects a finer unit than years --
# confirm against EHR.cal_age.
diagnoses = (
    EHR(diagnoses)
    .cal_age('eventdate', 'yob', year=False)
    .select(['patid', 'eventdate', 'age', code_col, 'yob'])
    .dropDuplicates()
)

# set age and code to string
diagnoses = EHR(diagnoses).set_col_to_str('age').set_col_to_str(code_col)

# group by date: one row per (patient, date) holding the lists of codes and
# ages recorded on that date
diagnoses = diagnoses.groupby(['patid', 'eventdate']).agg(
    F.collect_list(code_col).alias(code_col),
    F.collect_list('age').alias('age'),
    F.first('yob').alias('yob'),
)
diagnoses = EHR(diagnoses).array_add_element(code_col, 'SEP')

# add extra age to fill the gap of sep: repeat the first age of the date so
# the age list grows by one entry as well. A native column expression is used
# instead of a Python UDF, which avoids the temporary column and the
# round-trip through Python workers.
# NOTE(review): assumes every per-date age list is non-empty (no null ages
# reach collect_list) -- confirm against EHR.cal_age.
diagnoses = diagnoses.withColumn(
    'age', F.concat(F.col('age'), F.array(F.col('age')[0]))
)

w = Window.partitionBy('patid').orderBy('eventdate')

# sort and merge ccs and age: the ordered window produces, for each row, the
# list of per-date lists seen so far for that patient; the aggregation then
# keeps a single value per patient.
# NOTE(review): F.max presumably selects the complete history because the full
# list compares greater than each of its own prefixes -- confirm.
diagnoses = (
    diagnoses
    .withColumn(code_col, F.collect_list(code_col).over(w))
    .withColumn('age', F.collect_list('age').over(w))
    .groupBy('patid')
    .agg(F.max(code_col).alias(code_col), F.max('age').alias('age'))
)

# Collapse the list of per-date lists into one flat sequence per patient.
diagnoses = EHR(diagnoses).array_flatten(code_col).array_flatten('age')
diagnoses.write.parquet(config['output'])