Switch to side-by-side view

--- a
+++ b/benchmark/data_split_ongoing.py
@@ -0,0 +1,379 @@
+# -*- coding: utf-8 -*- 
+
+'''
+
+input: 	9k data ?  
+	1. ctgov_data/raw_data.csv 
+
+nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria
+
+processing:  
+	1. phase I
+	2. phase II
+	3. phase III
+	4. indication 
+	5. train/test split 
+
+
+output:
+	1. ctgov_data/phase_I.csv 
+	2. ctgov_data/phase_II.csv 
+	3. ctgov_data/phase_III.csv 
+	4. ctgov_data/trial.csv 
+
+requires ~10 minutes. 
+
+'''
+
+import csv 
+from random import shuffle 
+## no shuffle 
+from functools import reduce
+
+
+from ccs_utils import file2_icd2ccs_and_ccs2description, file2_icd2ccsr
+# icd2ccs, ccscode2description = file2_icd2ccs_and_ccs2description() 
+icd2ccsr = file2_icd2ccsr()
+
+
+def csvfile2rows(input_file):
+	with open(input_file, 'r') as csvfile:
+		rows = list(csv.reader(csvfile, delimiter = ','))[1:]
+	return rows 
+
+def filter_phase_I(row):
+	if "phase 1" in row[4]:
+		return True	
+	## enhance 
+	# if int(row[3])==1 and row[4]=='phase 4':
+	# 	return True 
+
+	return False 
+
+def filter_phase_II(row):
+	phase = row[4] 
+	label = int(row[3])
+
+	if "phase 2" in row[4]:
+		return True
+	## enhance
+	# if int(row[3])==1 and 'phase 4' in row[4]:
+	# 	return True
+
+	return False
+
+def filter_phase_III(row):
+	if "phase 3" in row[4]:
+		return True
+	### enhance 
+	# if "phase 4" in row[4] and int(row[3])==1:
+	# 	return True
+	# if int(row[3])==0 and row[4] =='phase 2':
+	# 	return True 
+	return False 
+
+def filter_trial(row):
+	label = int(row[3])
+	if label == 0 and ('phase 1' in row[4] or 'phase 2' in row[4]):
+		return True 
+	if ('phase 3' in row[4] or 'phase 4' in row[4]) and label==1:  ### label == 1
+		return True 
+	return False 
+
+# def filter_chronic(row):
+# 	if 'chronic' in row[5]:
+# 		return True 
+# 	return False 
+
+# def filter_cardio(row):
+# 	if 'cardio' in row[5]:
+# 		return True 
+# 	return False 
+
+# def filter_cancer(row):
+# 	if 'cancer' in row[5] or 'neoplasm' in row[5] or 'tumor' in row[5]:
+# 		return True 
+# 	return False 
+
+# def filter_pain(row):
+# 	if 'pain' in row[5]:
+# 		return True 
+# 	return False 
+
+def icdcode_text_2_lst_of_lst(text):
+	text = text[2:-2]
+	lst_lst = []
+	for i in text.split('", "'):
+		i = i[1:-1]
+		lst_lst.append([j.strip()[1:-1] for j in i.split(',')])
+	return lst_lst 
+
+def row2icdcodelst(row):
+	icdcode_text = row[6]
+	icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text)
+	icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2)
+	icdcode_lst = [i.replace('.', '') for i in icdcode_lst]
+	return icdcode_lst 
+
+
+
+
+# def filter_heart(row):
+# 	icdcode_text = row[6]
+# 	icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text)
+# 	icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2)
+# 	icdcode_lst = [i.replace('.', '') for i in icdcode_lst]
+# 	for icdcode in icdcode_lst:
+# 		try:
+# 			ccs = icd2ccs[icdcode]
+# 			description = ccscode2description[ccs].lower() 	
+# 			if 'heart' in description:
+# 				return True 
+# 		except:
+# 			pass 
+# 	return False 
+
+# def filter_infection(row):
+# 	icdcode_text = row[6]
+# 	icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text)
+# 	icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2)
+# 	icdcode_lst = [i.replace('.', '') for i in icdcode_lst]
+# 	for icdcode in icdcode_lst:
+# 		try:
+# 			ccs = icd2ccs[icdcode]
+# 			description = ccscode2description[ccs].lower() 	
+# 			if 'infect' in description:
+# 				return True 
+# 		except:
+# 			pass
+# 	return False 
+
+
+
+# def filter_disorder(row):
+# 	icdcode_text = row[6]
+# 	icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text)
+# 	icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2)
+# 	icdcode_lst = [i.replace('.', '') for i in icdcode_lst]
+# 	for icdcode in icdcode_lst:
+# 		try:
+# 			ccs = icd2ccs[icdcode]
+# 			description = ccscode2description[ccs].lower() 	
+# 			if 'disorder' in description:
+# 				return True 
+# 		except:
+# 			pass 
+# 	return False 
+
+def filter_nervous(row):
+	icdcode_lst = row2icdcodelst(row)
+	for icdcode in icdcode_lst:
+		try:
+			ccsr = icd2ccsr[icdcode]
+			if ccsr == 'NVS':
+				return True 
+		except:
+			pass 
+	return False 
+
+def filter_cancer(row):
+	icdcode_lst = row2icdcodelst(row)
+	for icdcode in icdcode_lst:
+		try:
+			ccsr = icd2ccsr[icdcode]
+			if ccsr == 'NEO':
+				return True 
+		except:
+			pass 
+	return False 
+
+
+
+# def filter_cancer(row):
+# 	icdcode_text = row[6]
+# 	if 'cancer' in icdcode_text.lower() or 'neoplasm' in icdcode_text.lower() \
+# 		or 'oncology' in icdcode_text.lower() or 'tumor' in icdcode_text.lower():
+# 		return True 
+# 	icdcode_lst2 = icdcode_text_2_lst_of_lst(icdcode_text)
+# 	icdcode_lst = reduce(lambda x,y:x+y, icdcode_lst2)
+# 	icdcode_lst = [i.replace('.', '') for i in icdcode_lst]
+# 	for icdcode in icdcode_lst:
+# 		try:
+# 			ccs = icd2ccs[icdcode]
+# 			description = ccscode2description[ccs].lower() 	
+# 			if 'cancer' in description or 'neoplasm' in description or 'oncology' in description or 'tumor' in description:
+# 				return True 
+# 		except:
+# 			pass 
+# 	return False 
+
+def filter_infect(row):
+	icdcode_lst = row2icdcodelst(row)
+	for icdcode in icdcode_lst:
+		try:
+			ccsr = icd2ccsr[icdcode]
+			if ccsr == 'INF':
+				return True 
+		except:
+			pass 
+	return False 
+
+
+def filter_respiratory(row):
+	icdcode_lst = row2icdcodelst(row)
+	for icdcode in icdcode_lst:
+		try:
+			ccsr = icd2ccsr[icdcode]
+			if ccsr == 'RSP':
+				return True 
+		except:
+			pass 
+	return False 
+
+def filter_digest(row):
+	icdcode_lst = row2icdcodelst(row)
+	for icdcode in icdcode_lst:
+		try:
+			ccsr = icd2ccsr[icdcode]
+			if ccsr == 'DIG':
+				return True 
+		except:
+			pass 
+	return False 
+
+
+
+
+def write_row_to_csvfile(rows, fieldname, output_file):
+	with open(output_file, 'w') as csvfile:
+		writer = csv.DictWriter(csvfile, fieldnames=fieldname)
+		writer.writeheader()	
+		for row in rows:
+			dic = {k:row[i] for i,k in enumerate(fieldname)}
+			writer.writerow(dic)
+	return	
+
+
+nctid2year = dict() 
+with open('data/nctid_date.txt', 'r') as fin:
+	lines = fin.readlines()
+for line in lines:
+	nctid, start_year, completion_year = line.strip('\n').split('\t')
+	start_year = 0 if start_year=='' else int(start_year.split()[-1])
+	completion_year = 0 if completion_year == '' else int(completion_year.split()[-1])
+	nctid2year[nctid] = start_year, completion_year  #### 0, 2018
+
+def row2year(row):
+	nctid = row[0]
+	start_year, completion_year = nctid2year[nctid]
+	return start_year, completion_year 
+
+
+def select_and_split_data(input_file, filter_func, output_file_name):
+	rows = csvfile2rows(input_file)
+	rows = list(filter(filter_func, rows))
+	fieldname = ['nctid', 'status', 'why_stop', 'label', 'phase', 
+				 'diseases', 'icdcodes', 'drugs', 'smiless', 'criteria', 'lead_sponsor', 'collaborator']
+	write_row_to_csvfile(rows, fieldname, output_file_name)
+	return
+
+
+def smiles_txt_to_lst(text):
+	"""
+		"['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=CC=C12', 'CNCCC=C1C2=CC=CC=C2CCC2=CC=CC=C12']" 
+	"""
+	text = text[1:-1]
+	lst = [i.strip()[1:-1] for i in text.split(',')]
+	return lst 
+
+from copy import deepcopy
+
+# nctid,status,why_stop,label,phase,diseases,icdcodes,drugs,smiless,criteria,lead_sponsor,collaborator
+
+def clean_data(input_file, clean_file):
+	"""
+		remove placebo 
+	"""
+	rows = csvfile2rows(input_file)
+	newrows = []
+	fieldname = ['nctid','status','why_stop','label','phase','diseases','icdcodes','drugs','smiless','criteria','lead_sponsor','collaborator']
+	with open(clean_file, 'w') as csvfile:
+		writer = csv.DictWriter(csvfile, fieldnames=fieldname)
+		writer.writeheader()	
+		for row in rows:
+			# drugs = row[7]
+			# if 'placebo' not in drugs.lower():
+			# 	newrows.append(row)
+			# 	continue 
+			# smiless = row[8]
+			# newdrug, newsmiles = [], []
+			# # assert len(smiles_txt_to_lst(drugs)) == len(smiles_txt_to_lst(smiless))
+			# for drug, smiles in zip(smiles_txt_to_lst(drugs), smiles_txt_to_lst(smiless)):
+			# 	if 'placebo' not in drug.lower():
+			# 		newdrug.append(drug)
+			# 		newsmiles.append(smiles)
+			# 	else:
+			# 		print(smiles)
+			# newdrug = str(newdrug)
+			# newsmiles = str(smiles)
+			# assert len(newdrug) > 0
+
+			# smiless = row[8]
+			# if '[O--].[Mg++]' in smiless:
+			# 	smiles_lst = smiles_txt_to_lst(smiless)
+			# 	smiles_lst = set(smiles_lst)
+			# 	smiles_lst.remove('[O--].[Mg++]')
+			# 	if len(smiles_lst)==0:
+			# 		continue 
+			# 	smiles_lst = str(list(smiles_lst))
+			# 	newrow = row[:8] + [smiles_lst] + row[9:]
+			# else:
+			# 	newrow = row 
+			newrow = row 
+
+			dic = {k:newrow[i] for i,k in enumerate(fieldname)}
+			writer.writerow(dic)
+	return	
+
+
+
+if __name__ == "__main__":
+	input_file = 'data/ongoing_data.csv'
+	clean_file = "data/clean_ongoing_data.csv"
+
+	clean_data(input_file, clean_file)
+	#### remove placebo 
+
+	print("------------ phase I -------------")
+	select_and_split_data(clean_file, filter_phase_I, 'data/ongoing_phase_I.csv')
+	print("----------- phase II -------------")
+	select_and_split_data(clean_file, filter_phase_II, 'data/ongoing_phase_II.csv')
+	print("----------- phase III ----------")
+	select_and_split_data(clean_file, filter_phase_III, 'data/ongoing_phase_III.csv')
+
+
+
+'''
+origin 
+	phase I
+		pos = 997   neg = 558
+	phase II
+		pos = 1069   neg = 2377
+	phase III
+		pos = 2399   neg = 1458
+	indication
+		pos = 3146   neg = 2146
+
+
+
+
+
+'''
+
+
+
+
+
+
+
+