Switch to side-by-side view

--- a
+++ b/test/embeddings_reimplement/icd9.py
@@ -0,0 +1,117 @@
+import csv
+import json
+from collections import *
+
+class Node(object):
+  def __init__(self, depth, code, descr=None):
+    self.depth = depth
+    self.descr = descr or code
+    self.code = code
+    self.parent = None
+    self.children = []
+
+  def add_child(self, child):
+    if child not in self.children:
+      self.children.append(child)
+
+  def search(self, code):
+    if code == self.code: return [self]
+    ret = []
+    for child in self.children:
+      ret.extend(child.search(code))
+    return ret
+
+  def find(self, code):
+    nodes = self.search(code)
+    if nodes:
+      return nodes[0]
+    return None
+
+  @property
+  def root(self):
+    return self.parents[0]
+
+  @property
+  def description(self):
+    return self.descr
+
+  @property
+  def codes(self):
+    return [n.code for n in self.leaves]
+
+  @property
+  def parents(self):
+    n = self
+    ret = []
+    while n:
+      ret.append(n)
+      n = n.parent
+    ret.reverse()
+    return ret
+
+
+  @property
+  def leaves(self):
+    leaves = set()
+    if not self.children:
+      return [self]
+    for child in self.children:
+      leaves.update(child.leaves)
+    return list(leaves)
+
+  # return all leaf notes with a depth of @depth
+  def leaves_at_depth(self, depth):
+    return [n for n in self.leaves if n.depth == depth]
+
+  @property
+  def siblings(self):
+    parent = self.parent
+    if not parent:
+      return []
+    return list(parent.children)
+
+  def __str__(self):
+    return '%s\t%s' % (self.depth, self.code)
+
+  def __hash__(self):
+    return hash(str(self))
+
+
+class ICD9(Node):
+  def __init__(self, codesfname):
+    # dictionary of depth -> dictionary of code->node
+    self.depth2nodes = defaultdict(dict)
+    super(ICD9, self).__init__(-1, 'ROOT')
+
+    with open(codesfname, 'r') as f:
+      allcodes = json.loads(f.read())
+      self.process(allcodes)
+
+  def process(self, allcodes):
+    for hierarchy in allcodes:
+      self.add(hierarchy)
+
+  def get_node(self, depth, code, descr):
+    d = self.depth2nodes[depth]
+    if code not in d:
+      d[code] = Node(depth, code, descr)
+    return d[code]
+
+  def add(self, hierarchy):
+    prev_node = self
+    for depth, link in enumerate(hierarchy):
+      if not link['code']: continue
+
+      code = link['code']
+      descr = 'descr' in link and link['descr'] or code
+      node = self.get_node(depth, code, descr)
+      node.parent = prev_node
+      prev_node.add_child(node)
+      prev_node = node
+
+
+if __name__ == '__main__':
+  tree = ICD9('codes.json')
+  counter = Counter(list(map(str, tree.leaves)))
+  import pdb
+  pdb.set_trace()