Switch to side-by-side view

--- a
+++ b/medacy/data/annotations.py
@@ -0,0 +1,264 @@
+import logging
+import os
+import re
+from collections import Counter, namedtuple
+from math import ceil
+
+
+EntTuple = namedtuple('EntTuple', 'tag start end text')  # one entity: label, start index, end index, mention text
+
+
+class Annotations:
+    """
+    Stores the entity annotations of one document as a list of EntTuple(tag, start, end, text).
+    medaCy uses this object to structure the input to models and the output from models.
+
+    :ivar ann_path: the path to the .ann file; None when constructed from a list of tuples
+    :ivar source_text_path: path to the related .txt file, or None if it was not given
+    :ivar annotations: the annotation tuples, kept sorted by (start, end)
+    """
+
+    # One brat text-bound line: T<id> TAB <tag> <start> <end>[;<start> <end>...] TAB <mention>
+    brat_pattern = re.compile(r'T(\d+)\t(\S+) ((\d+ \d+;)*\d+ \d+)\t(.+)')
+
+    def __init__(self, annotation_data, source_text_path=None):
+        """
+        :param annotation_data: a file path to an annotation file, or a list of annotation tuples.
+        Construction from a list of tuples is intended for internal use.
+        :param source_text_path: optional; path to the text file from which the annotations were derived.
+        :raises FileNotFoundError: if annotation_data is a path that is not an existing file
+        """
+        self.source_text_path = source_text_path
+        if isinstance(annotation_data, list) and all(isinstance(e, tuple) for e in annotation_data):
+            self.ann_path = None  # no backing file, but the attribute must always exist
+            self.annotations = annotation_data
+            return
+        if not os.path.isfile(annotation_data):
+            raise FileNotFoundError("annotation_data must be a list of tuples or a valid file path, but is %s" % repr(annotation_data))
+        self.ann_path = annotation_data
+        self.annotations = self._init_from_file(annotation_data)
+
+    @staticmethod
+    def _init_from_file(file_path):
+        """
+        Creates a list of annotation tuples from a file path.
+        Only text matching brat_pattern is read; anything else in the file is ignored.
+        :param file_path: the path to an ann file, read as UTF-8
+        :return: a list of annotation tuples
+        """
+        with open(file_path, 'r', encoding='utf-8') as f:
+            data = f.read()
+
+        annotations = []
+        for match in Annotations.brat_pattern.finditer(data):
+            tag, spans, mention = match.group(2, 3, 5)
+            # The span group may hold several "start end" fragments separated by ';'.
+            # Only the first start and the last end are kept, giving one covering span.
+            offsets = re.findall(r'\d+', spans)
+            start, end = int(offsets[0]), int(offsets[-1])
+            annotations.append(EntTuple(tag, start, end, mention))
+
+        return annotations
+
+    @property
+    def annotations(self):
+        return self._annotations
+
+    @annotations.setter
+    def annotations(self, value):
+        """Converts every item to an EntTuple and stores them sorted by (start, end)"""
+        self._annotations = sorted((EntTuple(*e) for e in value), key=lambda x: (x.start, x.end))
+
+    def get_labels(self, as_list=False):
+        """
+        Get the set of labels from this collection of annotations.
+        :param as_list: bool for if to return the results as a list; defaults to False
+        :return: the distinct labels, as a list (in arbitrary order) if as_list is true, otherwise as a set
+        """
+        labels = {e.tag for e in self.annotations}
+        return list(labels) if as_list else labels
+
+    def add_entity(self, label, start, end, text=""):
+        """
+        Adds an entity to the Annotations, keeping the annotations sorted by (start, end)
+        :param label: the label of the annotation you are appending
+        :param start: the start index in the document of the annotation you are appending.
+        :param end: the end index in the document of the annotation you are appending
+        :param text: the raw text of the annotation you are appending
+        """
+        # Appending alone would break the ordering established by the setter, so re-sort in place.
+        self._annotations.append(EntTuple(label, start, end, text))
+        self._annotations.sort(key=lambda x: (x.start, x.end))
+
+    def to_ann(self, write_location=None):
+        """
+        Formats the Annotations object into a string representing a valid ANN file. Optionally writes the formatted
+        string to a destination.
+        :param write_location: path of location to write ann file to
+        :return: the ann-formatted string; it is also written to write_location when one is given
+        """
+        lines = []
+        for num, tup in enumerate(self.annotations, 1):
+            mention = tup.text.replace('\n', ' ')  # keep each entity on a single line
+            lines.append(f"T{num}\t{tup.tag} {tup.start} {tup.end}\t{mention}\n")
+        ann_string = "".join(lines)
+
+        if write_location is not None:
+            if os.path.isfile(write_location):
+                logging.warning("Overwriting file at: %s", write_location)
+            # Write as UTF-8 so the file can be read back by _init_from_file on any platform.
+            with open(write_location, 'w', encoding='utf-8') as f:
+                f.write(ann_string)
+
+        return ann_string
+
+    def _lenient_matches(self, other, leniency):
+        """
+        Finds the annotations of this object that have a lenient match in another Annotations object.
+        An annotation matches when a same-label annotation of other lies entirely inside this annotation's
+        span after widening it by :code:`ceil(leniency * (end - start))` characters on each side.
+        :param other: Another Annotations object.
+        :param leniency: the leniency factor; callers validate that it lies in [0,1]
+        :return: the set of matching annotations taken from this object
+        """
+        matches = set()
+        for ann in self.annotations:
+            label, start, end, _ = ann
+            window = ceil(leniency * (end - start))
+            low, high = start - window, end + window
+            if any(c_label == label and low <= c_start and c_end <= high
+                   for c_label, c_start, c_end, _ in other.annotations):
+                matches.add(ann)
+        return matches
+
+    def difference(self, other, leniency=0):
+        """
+        Identifies the annotations of this object that have no match in another Annotations object. Useful for
+        checking if an unverified annotation matches an annotation known to be accurate. This is a set difference.
+        :param other: Another Annotations object.
+        :param leniency: a floating point value between [0,1] defining the leniency of the character spans to count
+        as different. A value of zero considers only exact tuple matches, while a positive value also accepts a
+        same-label annotation of other that lies within this span widened by :code:`ceil(leniency * (end - start))`
+        characters on either side.
+        :return: A set of tuples of non-matching annotations.
+        :raises ValueError: if other is not an Annotations object or leniency is outside [0,1]
+        """
+        if not isinstance(other, Annotations):
+            raise ValueError("Annotations.difference() can only accept another Annotations object as an argument.")
+        if leniency == 0:
+            return set(self.annotations) - set(other.annotations)
+        if not 0 <= leniency <= 1:
+            raise ValueError("Leniency must be a floating point between [0,1]")
+        return set(self.annotations) - self._lenient_matches(other, leniency)
+
+    def intersection(self, other, leniency=0):
+        """
+        Computes the intersection of the operated annotation object with the operand annotation object.
+        :param other: Another Annotations object.
+        :param leniency: a floating point value between [0,1]. Zero keeps only tuples present in both objects; a positive
+        value keeps annotations of this object matched within :code:`ceil(leniency * (end - start))` on either side.
+        :return: A set of annotations that appear in both Annotation objects, taken from this object
+        :raises ValueError: if other is not an Annotations object or leniency is outside [0,1]
+        """
+        if not isinstance(other, Annotations):
+            raise ValueError("An Annotations object is required as an argument.")
+        if leniency == 0:
+            return set(self.annotations) & set(other.annotations)
+        if not 0 <= leniency <= 1:
+            raise ValueError("Leniency must be a floating point between [0,1]")
+        return self._lenient_matches(other, leniency)
+
+    def compute_ambiguity(self, other):
+        """
+        Finds spans of 'other' that overlap a span of this object by at least one character but carry a different
+        label. If 'other' comprises a model's predictions, this is a strong indicator of the model's
+        inability to disambiguate between entities. For a full analysis, compute a confusion matrix.
+        :param other: Another Annotations object.
+        :return: a dictionary mapping each affected (label, start, end, text) tuple of this object to the list of
+        all differently-labelled overlapping tuples from other
+        :raises ValueError: if other is not an Annotations object
+        """
+        if not isinstance(other, Annotations):
+            raise ValueError("An Annotations object is required as an argument.")
+
+        ambiguity_dict = {}
+        for label, start, end, text in self.annotations:
+            for c_label, c_start, c_end, c_text in other.annotations:
+                if label == c_label:
+                    continue
+                # Positive only when the two character ranges share at least one position.
+                if min(end, c_end) - max(start, c_start) > 0:
+                    # Accumulate rather than overwrite: one span can conflict with several spans of other.
+                    ambiguity_dict.setdefault((label, start, end, text), []).append((c_label, c_start, c_end, c_text))
+
+        return ambiguity_dict
+
+    def compute_confusion_matrix(self, other, entities, leniency=0):
+        """
+        Computes a confusion matrix representing span level ambiguity between this annotation and the argument annotation.
+        An annotation in 'other' is ambiguous if it overlaps with a span in this Annotation but does not have the
+        same entity label (see compute_ambiguity). The main diagonal counts the entities of this Annotation
+        that match a span in 'other' with the same label (see intersection).
+        :param other: Another Annotations object.
+        :param entities: a list of entities to use in computing matrix ambiguity; it fixes the row/column order.
+        :param leniency: leniency to utilize when computing overlapping entities. This is the same definition of leniency as in intersection.
+        :return: a square matrix with dimension len(entities) where matrix[i][j] is how many times entities[i] in this annotation was predicted as entities[j] in 'other'.
+        :raises ValueError: if other is not an Annotations object, entities is not a list, or leniency is outside [0,1]
+        :raises KeyError: if a label that has to be counted does not appear in entities
+        """
+        if not isinstance(other, Annotations):
+            raise ValueError("An Annotations object is required as an argument.")
+        if not isinstance(entities, list):
+            raise ValueError("entities must be a list of entities, but is %s" % repr(entities))
+
+        entity_encoding = {entity: i for i, entity in enumerate(entities)}
+        confusion_matrix = [[0] * len(entities) for _ in entities]
+
+        ambiguity_dict = self.compute_ambiguity(other)
+        intersection = self.intersection(other, leniency=leniency)
+
+        # Off-diagonal scores: one count per differently-labelled overlapping span.
+        for gold_annotation, ambiguous_annotations in ambiguity_dict.items():
+            row = confusion_matrix[entity_encoding[gold_annotation[0]]]
+            for ambiguous_annotation in ambiguous_annotations:
+                row[entity_encoding[ambiguous_annotation[0]]] += 1
+
+        # Diagonal scores (correctly predicted entities with correct spans)
+        for matching_annotation in intersection:
+            index = entity_encoding[matching_annotation[0]]
+            confusion_matrix[index][index] += 1
+
+        return confusion_matrix
+
+    def compute_counts(self):
+        """
+        Computes counts of each entity type in this annotation.
+        :return: a Counter of the entity counts
+        """
+        return Counter(e.tag for e in self.annotations)
+
+    def __str__(self):
+        return str(self.annotations)
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __iter__(self):
+        return iter(self.annotations)
+
+    def __or__(self, other):
+        """
+        Creates an Annotations object containing annotations from two instances
+        :param other: the other Annotations instance
+        :return: a new Annotations object containing entities from both; its ann_path is the string 'None'
+        """
+        merged = list(set(self.annotations) | set(other.annotations))
+        result = Annotations(merged, source_text_path=self.source_text_path or other.source_text_path)
+        result.ann_path = 'None'
+        return result
+
+    def __ior__(self, other):
+        """In-place union: adds the annotations of other to this object and sets ann_path to the string 'None'"""
+        self.annotations = list(set(self.annotations) | set(other.annotations))
+        self.ann_path = 'None'
+        return self