--- a +++ b/benchmark/data_split_ongoing.py @@ -0,0 +1,379 @@ +# -*- coding: utf-8 -*- + +''' + +input: 9k data ? + 1. ctgov_data/raw_data.csv + +nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria + +processing: + 1. phase I + 2. phase II + 3. phase III + 4. indication + 5. train/test split + + +output: + 1. ctgov_data/phase_I.csv + 2. ctgov_data/phase_II.csv + 3. ctgov_data/phase_III.csv + 4. ctgov_data/trial.csv + +requires ~10 minutes. + +''' + +import csv +from random import shuffle +## no shuffle +from functools import reduce + + +from ccs_utils import file2_icd2ccs_and_ccs2description, file2_icd2ccsr +# icd2ccs, ccscode2description = file2_icd2ccs_and_ccs2description() +icd2ccsr = file2_icd2ccsr() + + +def csvfile2rows(input_file): + with open(input_file, 'r') as csvfile: + rows = list(csv.reader(csvfile, delimiter = ','))[1:] + return rows + +def filter_phase_I(row): + if "phase 1" in row[4]: + return True + ## enhance + # if int(row[3])==1 and row[4]=='phase 4': + # return True + + return False + +def filter_phase_II(row): + phase = row[4] + label = int(row[3]) + + if "phase 2" in row[4]: + return True + ## enhance + # if int(row[3])==1 and 'phase 4' in row[4]: + # return True + + return False + +def filter_phase_III(row): + if "phase 3" in row[4]: + return True + ### enhance + # if "phase 4" in row[4] and int(row[3])==1: + # return True + # if int(row[3])==0 and row[4] =='phase 2': + # return True + return False + +def filter_trial(row): + label = int(row[3]) + if label == 0 and ('phase 1' in row[4] or 'phase 2' in row[4]): + return True + if ('phase 3' in row[4] or 'phase 4' in row[4]) and label==1: ### label == 1 + return True + return False + +# def filter_chronic(row): +# if 'chronic' in row[5]: +# return True +# return False + +# def filter_cardio(row): +# if 'cardio' in row[5]: +# return True +# return False + +# def filter_cancer(row): +# if 'cancer' in row[5] or 'neoplasm' in row[5] or 'tumor' in row[5]: +# return True +# return False + +# def filter_pain(row): +# if 'pain' in row[5]: +# return True +# return False + +def icdcode_text_2_lst_of_lst(text): + text = text[2:-2] + lst_lst = [] + for i in text.split('", "'): + i = i[1:-1] + lst_lst.append([j.strip()[1:-1] for j in i.split(',')]) + return lst_lst + +def row2icdcodelst(row): + icdcode_text = row[6] + icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text) + icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2) + icdcode_lst = [i.replace('.', '') for i in icdcode_lst] + return icdcode_lst + + + + +# def filter_heart(row): +# icdcode_text = row[6] +# icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text) +# icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2) +# icdcode_lst = [i.replace('.', '') for i in icdcode_lst] +# for icdcode in icdcode_lst: +# try: +# ccs = icd2ccs[icdcode] +# description = ccscode2description[ccs].lower() +# if 'heart' in description: +# return True +# except: +# pass +# return False + +# def filter_infection(row): +# icdcode_text = row[6] +# icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text) +# icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2) +# icdcode_lst = [i.replace('.', '') for i in icdcode_lst] +# for icdcode in icdcode_lst: +# try: +# ccs = icd2ccs[icdcode] +# description = ccscode2description[ccs].lower() +# if 'infect' in description: +# return True +# except: +# pass +# return False + + + +# def filter_disorder(row): +# icdcode_text = row[6] +# icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text) +# icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2) +# icdcode_lst = [i.replace('.', '') for i in icdcode_lst] +# for icdcode in icdcode_lst: +# try: +# ccs = icd2ccs[icdcode] +# description = ccscode2description[ccs].lower() +# if 'disorder' in description: +# return True +# except: +# pass +# return False + +def filter_nervous(row): + icdcode_lst = row2icdcodelst(row) + for icdcode in icdcode_lst: + try: + ccsr = icd2ccsr[icdcode] + if ccsr == 'NVS': + return True + except: + pass + return False + +def filter_cancer(row): + icdcode_lst = row2icdcodelst(row) + for icdcode in icdcode_lst: + try: + ccsr = icd2ccsr[icdcode] + if ccsr == 'NEO': + return True + except: + pass + return False + + + +# def filter_cancer(row): +# icdcode_text = row[6] +# if 'cancer' in icdcode_text.lower() or 'neoplasm' in icdcode_text.lower() \ +# or 'oncology' in icdcode_text.lower() or 'tumor' in icdcode_text.lower(): +# return True +# icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text) +# icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2) +# icdcode_lst = [i.replace('.', '') for i in icdcode_lst] +# for icdcode in icdcode_lst: +# try: +# ccs = icd2ccs[icdcode] +# description = ccscode2description[ccs].lower() +# if 'cancer' in description or 'neoplasm' in description or 'oncology' in description or 'tumor' in description: +# return True +# except: +# pass +# return False + +def filter_infect(row): + icdcode_lst = row2icdcodelst(row) + for icdcode in icdcode_lst: + try: + ccsr = icd2ccsr[icdcode] + if ccsr == 'INF': + return True + except: + pass + return False + + +def filter_respiratory(row): + icdcode_lst = row2icdcodelst(row) + for icdcode in icdcode_lst: + try: + ccsr = icd2ccsr[icdcode] + if ccsr == 'RSP': + return True + except: + pass + return False + +def filter_digest(row): + icdcode_lst = row2icdcodelst(row) + for icdcode in icdcode_lst: + try: + ccsr = icd2ccsr[icdcode] + if ccsr == 'DIG': + return True + except: + pass + return False + + + + +def write_row_to_csvfile(rows, fieldname, output_file): + with open(output_file, 'w') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldname) + writer.writeheader() + for row in rows: + dic = {k:row[i] for i,k in enumerate(fieldname)} + writer.writerow(dic) + return + + +nctid2year = dict() +with open('data/nctid_date.txt', 'r') as fin: + lines = fin.readlines() +for line in lines: + nctid, start_year, completion_year = line.strip('\n').split('\t') + start_year = 0 if start_year=='' else int(start_year.split()[-1]) + completion_year = 0 if completion_year == '' else int(completion_year.split()[-1]) + nctid2year[nctid] = start_year, completion_year #### 0, 2018 + +def row2year(row): + nctid = row[0] + start_year, completion_year = nctid2year[nctid] + return start_year, completion_year + + +def select_and_split_data(input_file, filter_func, output_file_name): + rows = csvfile2rows(input_file) + rows = list(filter(filter_func, rows)) + fieldname = ['nctid', 'status', 'why_stop', 'label', 'phase', + 'diseases', 'icdcodes', 'drugs', 'smiless', 'criteria', 'lead_sponsor', 'collaborator'] + write_row_to_csvfile(rows, fieldname, output_file_name) + return + + +def smiles_txt_to_lst(text): + """ + "['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=CC=C12', 'CNCCC=C1C2=CC=CC=C2CCC2=CC=CC=C12']" + """ + text = text[1:-1] + lst = [i.strip()[1:-1] for i in text.split(',')] + return lst + +from copy import deepcopy + +# nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria,lead_sponsor,collaborator + +def clean_data(input_file, clean_file): + """ + remove placebo + """ + rows = csvfile2rows(input_file) + newrows = [] + fieldname = ['nctid','status','why_stop','label','phase','diseases','icdcodes','drugs','smiless','criteria','lead_sponsor','collaborator'] + with open(clean_file, 'w') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldname) + writer.writeheader() + for row in rows: + # drugs = row[7] + # if 'placebo' not in drugs.lower(): + # newrows.append(row) + # continue + # smiless = row[8] + # newdrug, newsmiles = [], [] + # # assert len(smiles_txt_to_lst(drugs)) == len(smiles_txt_to_lst(smiless)) + # for drug, smiles in zip(smiles_txt_to_lst(drugs), smiles_txt_to_lst(smiless)): + # if 'placebo' not in drug.lower(): + # newdrug.append(drug) + # newsmiles.append(smiles) + # else: + # print(smiles) + # newdrug = str(newdrug) + # newsmiles = str(smiles) + # assert len(newdrug) > 0 + + # smiless = row[8] + # if '[O--].[Mg++]' in smiless: + # smiles_lst = smiles_txt_to_lst(smiless) + # smiles_lst = set(smiles_lst) + # smiles_lst.remove('[O--].[Mg++]') + # if len(smiles_lst)==0: + # continue + # smiles_lst = str(list(smiles_lst)) + # newrow = row[:8] + [smiles_lst] + row[9:] + # else: + # newrow = row + newrow = row + + dic = {k:newrow[i] for i,k in enumerate(fieldname)} + writer.writerow(dic) + return + + + +if __name__ == "__main__": + input_file = 'data/ongoing_data.csv' + clean_file = "data/clean_ongoing_data.csv" + + clean_data(input_file, clean_file) + #### remove placebo + + print("------------ phase I -------------") + select_and_split_data(clean_file, filter_phase_I, 'data/ongoing_phase_I.csv') + print("----------- phase II -------------") + select_and_split_data(clean_file, filter_phase_II, 'data/ongoing_phase_II.csv') + print("----------- phase III ----------") + select_and_split_data(clean_file, filter_phase_III, 'data/ongoing_phase_III.csv') + + + +''' +origin + phase I + pos = 997 neg = 558 + phase II + pos = 1069 neg = 2377 + phase III + pos = 2399 neg = 1458 + indication + pos = 3146 neg = 2146 + + + + + +''' + + + + + + + +