--- /dev/null
+++ b/HINT/icdcode_encode.py
@@ -0,0 +1,241 @@
+'''
+input:
+	data/raw_data.csv
+
+output:
+	data/icdcode2ancestor_dict.pkl (icdcode -> its ICD-10 ancestors)
+	icdcode embedding (GRAM model defined below)
+'''
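+### To (re)build the ancestor dictionary only, run this module directly, a sketch
+### assuming data/raw_data.csv resolves from the working directory:
+###   python HINT/icdcode_encode.py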
+
+import csv, re, pickle, os 
+from functools import reduce 
+import icd10
+from collections import defaultdict
+
+
+import torch 
+torch.manual_seed(0)
+from torch import nn 
+from torch.autograd import Variable
+import torch.nn.functional as F
+from torch.utils import data  #### data.Dataset 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def text_2_lst_of_lst(text):
+	"""
+		Parse the ICD-code column of raw_data.csv, stored as e.g.
+			"[""['F53.0', 'P91.4', 'Z13.31', 'Z13.32']""]"
+		into a list of ICD-code lists.
+	"""
+	text = text[2:-2]	### strip the outer [" ... "]
+	code_sublst = []
+	for i in text.split('", "'):
+		i = i[1:-1]	### strip the [ ... ] around each sublist
+		code_sublst.append([j.strip()[1:-1] for j in i.split(',')])	### strip the quotes around each code
+	return code_sublst
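+# Example (illustrative): csv.reader has already unescaped the doubled quotes,
+# so this function receives a string such as
+#   text = '["[\'F53.0\', \'P91.4\']", "[\'Z13.31\']"]'
+#   text_2_lst_of_lst(text)   ### -> [['F53.0', 'P91.4'], ['Z13.31']]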
+
+def get_icdcode_lst():
+	input_file = 'data/raw_data.csv'
+	with open(input_file, 'r') as csvfile:
+		rows = list(csv.reader(csvfile, delimiter = ','))[1:]
+	code_lst = []
+	for row in rows:
+		code_sublst = text_2_lst_of_lst(row[6])
+		code_lst.append(code_sublst)
+	return code_lst 
+
+def combine_lst_of_lst(lst_of_lst):
+	lst = list(reduce(lambda x,y:x+y, lst_of_lst))
+	lst = list(set(lst))
+	return lst
+
+def collect_all_icdcodes():
+	code_lst = get_icdcode_lst()
+	code_lst = list(map(combine_lst_of_lst, code_lst))
+	code_lst = list(reduce(lambda x,y:x+y, code_lst))
+	code_lst = list(set(code_lst))
+	return code_lst
+
+
+def find_ancestor_for_icdcode(icdcode, icdcode2ancestor):
+	"""Collect every prefix of icdcode that is itself a valid ICD-10 code."""
+	if icdcode in icdcode2ancestor:
+		return 
+	icdcode2ancestor[icdcode] = []
+	ancestor = icdcode[:]
+	while len(ancestor) > 2:
+		ancestor = ancestor[:-1]	### drop the last character
+		if ancestor[-1]=='.':
+			ancestor = ancestor[:-1]	### a candidate never ends with the dot
+		if icd10.find(ancestor) is not None:	### keep it only if it is a real ICD-10 code
+			icdcode2ancestor[icdcode].append(ancestor)
+	return
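+# Usage sketch (illustrative; the exact ancestors depend on the `icd10` package's code table):
+#   d = {}
+#   find_ancestor_for_icdcode('Z13.31', d)
+#   # d == {'Z13.31': ['Z13.3', 'Z13']}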
+
+
+def build_icdcode2ancestor_dict():
+	pkl_file = "data/icdcode2ancestor_dict.pkl"
+	if os.path.exists(pkl_file):
+		with open(pkl_file, 'rb') as fin:
+			icdcode2ancestor = pickle.load(fin)
+		return icdcode2ancestor 
+	all_code = collect_all_icdcodes() 
+	icdcode2ancestor = defaultdict(list)
+	for code in all_code:
+		find_ancestor_for_icdcode(code, icdcode2ancestor)
+	with open(pkl_file, 'wb') as fout:
+		pickle.dump(icdcode2ancestor, fout)
+	return icdcode2ancestor 
+
+
+def collect_all_code_and_ancestor():
+	icdcode2ancestor = build_icdcode2ancestor_dict()
+	all_code = set(icdcode2ancestor.keys())
+	ancestor_lst = list(icdcode2ancestor.values())
+	ancestor_set = set(reduce(lambda x,y:x+y, ancestor_lst))
+	all_code_lst = sorted(all_code.union(ancestor_set))
+	return all_code_lst	
+
+
+'''
+Assign each code (and each of its ancestors) an index,
+then encode a code via an embedding lookup weighted over its ancestors (GRAM below).
+'''
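+### Minimal sketch of the plain lookup this enables (names `all_codes` and
+### `embedding_dim` are placeholders; the GRAM class below adds ancestor attention):
+###   code2index = {code: idx for idx, code in enumerate(sorted(all_codes))}
+###   embedding  = nn.Embedding(len(code2index), embedding_dim)
+###   vec = embedding(torch.LongTensor([code2index['Z13.31']]))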
+
+
+class GRAM(nn.Sequential):
+	"""	
+		return a weighted embedding 
+	"""
+
+	def __init__(self, embedding_dim, icdcode2ancestor, device):
+		super(GRAM, self).__init__()		
+		self.icdcode2ancestor = icdcode2ancestor 
+		self.all_code_lst = GRAM.codedict_2_allcode(self.icdcode2ancestor)
+		self.code_num = len(self.all_code_lst)
+		self.maxlength = 5
+		self.code2index = {code:idx for idx,code in enumerate(self.all_code_lst)}
+		self.index2code = {idx:code for idx,code in enumerate(self.all_code_lst)}
+		self.padding_matrix = torch.zeros(self.code_num, self.maxlength).long() 
+		self.mask_matrix = torch.zeros(self.code_num, self.maxlength)
+		for idx in range(self.code_num):
+			code = self.index2code[idx]
+			ancestor_code_lst = self.icdcode2ancestor[code]
+			ancestor_idx_lst = [idx] + [self.code2index[code] for code in ancestor_code_lst]
+			ancestor_mask_lst = [1 for i in ancestor_idx_lst] + [0] * (self.maxlength - len(ancestor_idx_lst))
+			ancestor_idx_lst = ancestor_idx_lst + [0]*(self.maxlength-len(ancestor_idx_lst))
+			self.padding_matrix[idx,:] = torch.Tensor(ancestor_idx_lst)
+			self.mask_matrix[idx,:] = torch.Tensor(ancestor_mask_lst)
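+		### e.g. (illustrative) a code whose ancestors are ['Z13.3', 'Z13'] gets the padding row
+		### [idx(code), idx('Z13.3'), idx('Z13'), 0, 0] and the mask row [1, 1, 1, 0, 0]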
+
+		self.embedding_dim = embedding_dim 
+		self.embedding = nn.Embedding(self.code_num, self.embedding_dim)
+		self.attention_model = nn.Linear(2*embedding_dim, 1)
+
+		self.device = device
+		self = self.to(device)
+		### keep the (code_num x maxlength) lookup tables on cpu; selected rows are moved to device inside forward_*
+		self.padding_matrix = self.padding_matrix.to('cpu')
+		self.mask_matrix = self.mask_matrix.to('cpu')
+
+	@property
+	def embedding_size(self):
+		return self.embedding_dim
+
+
+	@staticmethod
+	def codedict_2_allcode(icdcode2ancestor):
+		all_code = set(icdcode2ancestor.keys())
+		ancestor_lst = list(icdcode2ancestor.values())
+		ancestor_set = set(reduce(lambda x,y:x+y, ancestor_lst))
+		all_code_lst = sorted(all_code.union(ancestor_set))	### sort so the code->index mapping is reproducible across runs
+		return all_code_lst		
+
+
+	def forward_single_code(self, single_code):
+		idx = self.code2index[single_code]	### plain int index
+		ancestor_vec = self.padding_matrix[idx,:].to(self.device)  #### (5,)
+		mask_vec = self.mask_matrix[idx,:].to(self.device)
+
+		embeded = self.embedding(ancestor_vec)  ### 5, 50
+		current_vec = self.embedding(torch.LongTensor([idx]).to(self.device)).view(1,-1).repeat(self.maxlength, 1) ### 1,50 -> 5,50
+		attention_input = torch.cat([embeded, current_vec], 1)  ### 5, 100
+		attention_weight = self.attention_model(attention_input)  ##### 5, 1
+		attention_weight = torch.exp(attention_weight)  #### 5, 1
+		attention_output = attention_weight * mask_vec.view(-1,1)  #### 5, 1
+		attention_output = attention_output / torch.sum(attention_output)  #### 5, 1
+		output = embeded * attention_output ### 5, 50 
+		output = torch.sum(output, 0) ### 50
+		return output 
+
+
+	def forward_code_lst(self, code_lst):
+		"""
+			code_lst is a list of ICD codes, e.g.
+				['C05.2', 'C10.0', 'C16.0', 'C16.4', 'C17.0', 'C17.1', 'C17.2']
+			In the shape comments below, 32 is len(code_lst), 5 is maxlength, 50 is embedding_dim.
+		"""
+		idx_lst = [self.code2index[code] for code in code_lst if code in self.code2index] ### 32 
+		if not idx_lst:
+			idx_lst = [0]	### fall back to index 0 when none of the codes are in the vocabulary
+		ancestor_mat = self.padding_matrix[idx_lst,:].to(self.device)  ##### 32,5
+		mask_mat = self.mask_matrix[idx_lst,:].to(self.device)  #### 32,5
+		embeded = self.embedding(ancestor_mat)  #### 32,5,50
+		current_vec = self.embedding(torch.Tensor(idx_lst).long().to(self.device)) #### 32,50
+		current_vec = current_vec.unsqueeze(1) ### 32,1,50
+		current_vec = current_vec.repeat(1, self.maxlength, 1) #### 32,5,50
+		attention_input = torch.cat([embeded, current_vec], 2)  #### 32,5,100
+		attention_weight = self.attention_model(attention_input)  #### 32,5,1 
+		attention_weight = torch.exp(attention_weight).squeeze(-1)  #### 32,5 
+		attention_output = attention_weight * mask_mat  #### 32,5 
+		attention_output = attention_output / torch.sum(attention_output, 1).view(-1,1)  #### 32,5 
+		attention_output = attention_output.unsqueeze(-1)  #### 32,5,1 
+		output = embeded * attention_output ##### 32,5,50 
+		output = torch.sum(output,1) ##### 32,50
+		return output 
+
+	def forward_code_lst2(self, code_lst_lst):
+		### all code lists within one sample
+		code_lst = reduce(lambda x,y:x+y, code_lst_lst)
+		code_embed = self.forward_code_lst(code_lst)
+		### mean-pool over all codes in the sample
+		code_embed = torch.mean(code_embed, 0).view(1,-1)  #### (1, embedding_dim)
+		return code_embed 
+		
+	def forward_code_lst3(self, code_lst_lst_lst):
+		code_embed_lst = [self.forward_code_lst2(code_lst_lst) for code_lst_lst in code_lst_lst_lst]
+		code_embed = torch.cat(code_embed_lst, 0)
+		return code_embed 
+
+
+
+
+
+if __name__ == '__main__':
+	dic = build_icdcode2ancestor_dict() 
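+	### Usage sketch (illustrative, commented out so the script only builds the ancestor dict):
+	# gram_model = GRAM(embedding_dim = 50, icdcode2ancestor = dic, device = device)
+	# code_lst = ['C05.2', 'C10.0', 'C16.0']
+	# code_embed = gram_model.forward_code_lst(code_lst)	### (3, 50) if all three codes are in the vocabulary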
+
+
+
+
+# if __name__ == "__main__":
+# 	# code_lst = collect_all_icdcodes()  ### 5k code
+# 	# all_code = collect_all_code_and_ancestor() ### 10k 
+# 	# icdcode2ancestor = build_icdcode2ancestor_dict()
+# 	# maxlength = 0
+# 	# for icdcode, ancestor in icdcode2ancestor.items():
+# 	# 	if len(ancestor) > maxlength:
+# 	# 		maxlength = len(ancestor)
+# 	# print(maxlength) 
+# 	# assert maxlength == 4
+
+# 	icdcode2ancestor = build_icdcode2ancestor_dict()
+# 	gram_model = GRAM(embedding_dim = 50, icdcode2ancestor = icdcode2ancestor, device = device)
+# 	# output = gram_model.forward_single_code('S33.121S')
+# 	code_lst = ['C05.2', 'C10.0', 'C16.0', 'C16.4', 'C17.0', 'C17.1', 'C17.2']
+# 	output = gram_model.forward_code_lst(code_lst)
+
+