Diff of /Track2-evaluate-ver4.py [000000] .. [1de6ed]

#!/usr/bin/env python3

"""Inter-annotator agreement calculator.

To run this script, please use:

    python Track2-evaluate-ver4.py <gold standard folder> <system output folder>

e.g.: python Track2-evaluate-ver4.py gold_annotations system_annotations

Please note that you must use Python 3 to get the correct results with this script.
"""

import argparse
import glob
import os
from collections import defaultdict
# ElementTree is aliased because xml.etree.cElementTree was removed in Python 3.9.
from xml.etree import ElementTree as cElementTree


class ClinicalCriteria(object):
    """Criteria in the Track 1 documents."""

    def __init__(self, tid, value):
        """Init."""
        self.tid = tid.strip().upper()
        self.ttype = self.tid
        self.value = value.lower().strip()

    def equals(self, other, mode='strict'):
        """Return whether the current criteria is equal to the one provided."""
        if other.tid == self.tid and other.value == self.value:
            return True
        return False


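# Note on matching modes (illustrative offsets, not taken from the corpus):
# given a gold Drug span (10, 18) and a system Drug span (10, 22), 'strict'
# matching requires identical start/end offsets (no match here), while
# 'lenient' matching only requires the two character spans to overlap
# (a match here). See ClinicalConcept.span_matches below.
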
class ClinicalConcept(object):
    """Named Entity Tag class."""

    def __init__(self, tid, start, end, ttype, text=''):
        """Init."""
        self.tid = str(tid).strip()
        self.start = int(start)
        self.end = int(end)
        self.text = str(text).strip()
        self.ttype = str(ttype).strip()

    def span_matches(self, other, mode='strict'):
        """Return whether the current tag overlaps with the one provided."""
        assert mode in ('strict', 'lenient')
        if mode == 'strict':
            if self.start == other.start and self.end == other.end:
                return True
        else:   # lenient
            if (self.end > other.start and self.start < other.end) or \
               (self.start < other.end and other.start < self.end):
                return True
        return False

    def equals(self, other, mode='strict'):
        """Return whether the current tag is equal to the one provided."""
        assert mode in ('strict', 'lenient')
        return other.ttype == self.ttype and self.span_matches(other, mode)

    def __str__(self):
        """String representation."""
        return '{}\t{}\t({}:{})'.format(self.ttype, self.text, self.start, self.end)


class Relation(object):
    """Relation class."""

    def __init__(self, rid, arg1, arg2, rtype):
        """Init."""
        assert isinstance(arg1, ClinicalConcept)
        assert isinstance(arg2, ClinicalConcept)
        self.rid = str(rid).strip()
        self.arg1 = arg1
        self.arg2 = arg2
        self.rtype = str(rtype).strip()

    def equals(self, other, mode='strict'):
        """Return whether the current relation is equal to the one provided."""
        assert mode in ('strict', 'lenient')
        if self.arg1.equals(other.arg1, mode) and \
                self.arg2.equals(other.arg2, mode) and \
                self.rtype == other.rtype:
            return True
        return False

    def __str__(self):
        """String representation."""
        return '{} ({}->{})'.format(self.rtype, self.arg1.ttype,
                                    self.arg2.ttype)


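# Track 1 records are XML files: each eligibility criterion is a child of the
# <TAGS> element with a "met" attribute, e.g. (hypothetical snippet):
#   <TAGS>
#     <ABDOMINAL met="not met" />
#     <ADVANCED-CAD met="met" />
#   </TAGS>
# RecordTrack1 below reads these values; anything other than "met"/"not met"
# is rejected.
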
class RecordTrack1(object):
    """Record for Track 1 class."""

    def __init__(self, file_path):
        """Initialize."""
        self.path = os.path.abspath(file_path)
        self.basename = os.path.basename(self.path)
        self.annotations = self._get_annotations()
        self.text = None

    @property
    def tags(self):
        return self.annotations['tags']

    def _get_annotations(self):
        """Return a dictionary with all the annotations in the .xml file."""
        annotations = defaultdict(dict)
        annotation_file = cElementTree.parse(self.path)
        for tag in annotation_file.findall('.//TAGS/*'):
            criterion = ClinicalCriteria(tag.tag.upper(), tag.attrib['met'])
            annotations['tags'][tag.tag.upper()] = criterion
            assert tag.attrib['met'] in ('met', 'not met'), \
                '{}: Unexpected value ("{}") for the {} tag!'.format(
                    self.path, criterion.value, criterion.ttype)
        return annotations


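# Track 2 records are brat standoff (.ann) files: entity lines start with 'T'
# and relation lines with 'R', e.g. (hypothetical snippet, <TAB> marks a tab):
#   T1<TAB>Drug 9733 9742<TAB>Dilantin
#   T2<TAB>Strength 9743 9748<TAB>100mg
#   R1<TAB>Strength-Drug Arg1:T2 Arg2:T1
# Offset fields that split into 4 or 5 pieces (discontinuous spans) are
# collapsed to their first start and last end offset by the parser below.
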
class RecordTrack2(object):
    """Record for Track 2 class."""

    def __init__(self, file_path):
        """Initialize."""
        self.path = os.path.abspath(file_path)
        self.basename = os.path.basename(self.path)
        self.annotations = self._get_annotations()
        # self.text = self._get_text()

    @property
    def tags(self):
        return self.annotations['tags']

    @property
    def relations(self):
        return self.annotations['relations']

    def _get_annotations(self):
        """Return a dictionary with all the annotations in the .ann file."""
        annotations = defaultdict(dict)
        with open(self.path) as annotation_file:
            lines = annotation_file.readlines()
            # First pass: entity annotations (lines starting with 'T')
            for line_num, line in enumerate(lines):
                if line.strip().startswith('T'):
                    try:
                        tag_id, tag_m, tag_text = line.strip().split('\t')
                    except ValueError:
                        print(self.path, line)
                        continue
                    if len(tag_m.split(' ')) == 3:
                        tag_type, tag_start, tag_end = tag_m.split(' ')
                    elif len(tag_m.split(' ')) == 4:
                        tag_type, tag_start, _, tag_end = tag_m.split(' ')
                    elif len(tag_m.split(' ')) == 5:
                        tag_type, tag_start, _, _, tag_end = tag_m.split(' ')
                    else:
                        print(self.path)
                        print(line)
                        continue
                    tag_start, tag_end = int(tag_start), int(tag_end)
                    annotations['tags'][tag_id] = ClinicalConcept(tag_id,
                                                                  tag_start,
                                                                  tag_end,
                                                                  tag_type,
                                                                  tag_text)
            # Second pass: relation annotations (lines starting with 'R')
            for line_num, line in enumerate(lines):
                if line.strip().startswith('R'):
                    rel_id, rel_m = line.strip().split('\t')
                    rel_type, rel_arg1, rel_arg2 = rel_m.split(' ')
                    rel_arg1 = rel_arg1.split(':')[1]
                    rel_arg2 = rel_arg2.split(':')[1]
                    arg1 = annotations['tags'][rel_arg1]
                    arg2 = annotations['tags'][rel_arg2]
                    annotations['relations'][rel_id] = Relation(rel_id, arg1,
                                                                arg2, rel_type)
        return annotations

    def _get_text(self):
        """Return the text in the corresponding txt file."""
        path = self.path.replace('.ann', '.txt')
        with open(path) as text_file:
            text = text_file.read()
        return text

    def search_by_id(self, key):
        """Search by id among both tags and relations."""
        try:
            return self.annotations['tags'][key]
        except KeyError:
            try:
                return self.annotations['relations'][key]
            except KeyError:
                return None


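# Measures below are computed from a standard contingency table (tp, tn, fp, fn):
#   precision   = tp / (tp + fp)
#   recall      = tp / (tp + fn)              (= sensitivity)
#   specificity = tn / (tn + fp)
#   F1          = 2 * precision * recall / (precision + recall)
#   AUC         = (sensitivity + specificity) / 2
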
class Measures(object):
    """Evaluation measures computed from tp, tn, fp, fn counts."""

    def __init__(self, tp=0, tn=0, fp=0, fn=0):
        """Initialize."""
        assert type(tp) == int
        assert type(tn) == int
        assert type(fp) == int
        assert type(fn) == int
        self.tp = tp
        self.tn = tn
        self.fp = fp
        self.fn = fn

    def precision(self):
        """Compute Precision score."""
        try:
            return self.tp / (self.tp + self.fp)
        except ZeroDivisionError:
            return 0.0

    def recall(self):
        """Compute Recall score."""
        try:
            return self.tp / (self.tp + self.fn)
        except ZeroDivisionError:
            return 0.0

    def f_score(self, beta=1):
        """Compute the F(beta) score."""
        assert beta > 0.
        try:
            num = (1 + beta**2) * (self.precision() * self.recall())
            den = beta**2 * self.precision() + self.recall()
            return num / den
        except ZeroDivisionError:
            return 0.0

    def f1(self):
        """Compute the F1-score (beta=1)."""
        return self.f_score(beta=1)

    def specificity(self):
        """Compute Specificity score."""
        try:
            return self.tn / (self.fp + self.tn)
        except ZeroDivisionError:
            return 0.0

    def sensitivity(self):
        """Compute Sensitivity score."""
        return self.recall()

    def auc(self):
        """Compute AUC score."""
        return (self.sensitivity() + self.specificity()) / 2


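# Per-document scoring (SingleEvaluator below): a system annotation is a true
# positive if it equals some gold annotation under the chosen mode, leftover
# system annotations are false positives, and unmatched gold annotations are
# false negatives. When several system annotations match the same gold
# annotation, only one is kept before counting; e.g. (hypothetical) two system
# Drug spans both matching one gold Drug span give tp=1, fp=0, not tp=2.
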
class SingleEvaluator(object):
    """Evaluate two single files."""

    def __init__(self, doc1, doc2, track, mode='strict', key=None, verbose=False):
        """Initialize."""
        assert isinstance(doc1, RecordTrack2) or isinstance(doc1, RecordTrack1)
        assert isinstance(doc2, RecordTrack2) or isinstance(doc2, RecordTrack1)
        assert mode in ('strict', 'lenient')
        assert doc1.basename == doc2.basename
        self.scores = {'tags': {'tp': 0, 'fp': 0, 'fn': 0, 'tn': 0},
                       'relations': {'tp': 0, 'fp': 0, 'fn': 0, 'tn': 0}}
        self.doc1 = doc1
        self.doc2 = doc2
        if key:
            gol = [t for t in doc1.tags.values() if t.ttype == key]
            sys = [t for t in doc2.tags.values() if t.ttype == key]
            sys_check = [t for t in doc2.tags.values() if t.ttype == key]
        else:
            gol = [t for t in doc1.tags.values()]
            sys = [t for t in doc2.tags.values()]
            sys_check = [t for t in doc2.tags.values()]

        # pare down matches -- if multiple system tags overlap with only one
        # gold standard tag, only keep one sys tag
        gol_matched = []
        for s in sys:
            for g in gol:
                if g.equals(s, mode):
                    if g not in gol_matched:
                        gol_matched.append(g)
                    else:
                        if s in sys_check:
                            sys_check.remove(s)

        sys = sys_check
        # now evaluate
        self.scores['tags']['tp'] = len({s.tid for s in sys for g in gol if g.equals(s, mode)})
        self.scores['tags']['fp'] = len({s.tid for s in sys}) - self.scores['tags']['tp']
        self.scores['tags']['fn'] = len({g.tid for g in gol}) - self.scores['tags']['tp']
        self.scores['tags']['tn'] = 0

        if verbose and track == 2:
            tps = {s for s in sys for g in gol if g.equals(s, mode)}
            fps = set(sys) - tps
            fns = set()
            for g in gol:
                if not len([s for s in sys if s.equals(g, mode)]):
                    fns.add(g)
            for e in fps:
                print('FP: ' + str(e))
            for e in fns:
                print('FN: ' + str(e))
        if track == 2:
            if key:
                gol = [r for r in doc1.relations.values() if r.rtype == key]
                sys = [r for r in doc2.relations.values() if r.rtype == key]
                sys_check = [r for r in doc2.relations.values() if r.rtype == key]
            else:
                gol = [r for r in doc1.relations.values()]
                sys = [r for r in doc2.relations.values()]
                sys_check = [r for r in doc2.relations.values()]

            # pare down matches -- if multiple system relations overlap with only
            # one gold standard relation, only keep one sys relation
            gol_matched = []
            for s in sys:
                for g in gol:
                    if g.equals(s, mode):
                        if g not in gol_matched:
                            gol_matched.append(g)
                        else:
                            if s in sys_check:
                                sys_check.remove(s)
            sys = sys_check
            # now evaluate
            self.scores['relations']['tp'] = len({s.rid for s in sys for g in gol if g.equals(s, mode)})
            self.scores['relations']['fp'] = len({s.rid for s in sys}) - self.scores['relations']['tp']
            self.scores['relations']['fn'] = len({g.rid for g in gol}) - self.scores['relations']['tp']
            self.scores['relations']['tn'] = 0
            if verbose:
                tps = {s for s in sys for g in gol if g.equals(s, mode)}
                fps = set(sys) - tps
                fns = set()
                for g in gol:
                    if not len([s for s in sys if s.equals(g, mode)]):
                        fns.add(g)
                for e in fps:
                    print('FP: ' + str(e))
                for e in fns:
                    print('FN: ' + str(e))


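# Corpus-level aggregation (MultipleEvaluator.track2): micro-averaged scores
# are computed from the tp/fp/fn counts pooled over all documents, while
# macro-averaged scores are the unweighted mean of the per-document precision,
# recall and F1 values.
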
class MultipleEvaluator(object):
    """Evaluate two sets of files."""

    def __init__(self, corpora, tag_type=None, mode='strict',
                 verbose=False):
        """Initialize."""
        assert isinstance(corpora, Corpora)
        assert mode in ('strict', 'lenient')
        self.scores = None
        if corpora.track == 1:
            self.track1(corpora)
        else:
            self.track2(corpora, tag_type, mode, verbose)

    def track1(self, corpora):
        """Compute measures for Track 1."""
        self.tags = ('ABDOMINAL', 'ADVANCED-CAD', 'ALCOHOL-ABUSE',
                     'ASP-FOR-MI', 'CREATININE', 'DIETSUPP-2MOS',
                     'DRUG-ABUSE', 'ENGLISH', 'HBA1C', 'KETO-1YR',
                     'MAJOR-DIABETES', 'MAKES-DECISIONS', 'MI-6MOS')
        self.scores = defaultdict(dict)
        metrics = ('p', 'r', 'f1', 'specificity', 'auc')
        values = ('met', 'not met')
        self.values = {'met': {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0},
                       'not met': {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0}}

        def evaluation(corpora, value, scores):
            predictions = defaultdict(list)
            for g, s in corpora.docs:
                for tag in self.tags:
                    predictions[tag].append(
                        (g.tags[tag].value == value, s.tags[tag].value == value))
            for tag in self.tags:
                # accumulate for micro overall measure
                self.values[value]['tp'] += predictions[tag].count((True, True))
                self.values[value]['fp'] += predictions[tag].count((False, True))
                self.values[value]['tn'] += predictions[tag].count((False, False))
                self.values[value]['fn'] += predictions[tag].count((True, False))

                # compute per-tag measures
                measures = Measures(tp=predictions[tag].count((True, True)),
                                    fp=predictions[tag].count((False, True)),
                                    tn=predictions[tag].count((False, False)),
                                    fn=predictions[tag].count((True, False)))
                scores[(tag, value, 'p')] = measures.precision()
                scores[(tag, value, 'r')] = measures.recall()
                scores[(tag, value, 'f1')] = measures.f1()
                scores[(tag, value, 'specificity')] = measures.specificity()
                scores[(tag, value, 'auc')] = measures.auc()
            return scores

        self.scores = evaluation(corpora, 'met', self.scores)
        self.scores = evaluation(corpora, 'not met', self.scores)

        for measure in metrics:
            for value in values:
                self.scores[('macro', value, measure)] = sum(
                    [self.scores[(t, value, measure)] for t in self.tags]) / len(self.tags)

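    # Track 1 aggregation (track1 above): every document contributes one
    # met / not met decision per criterion; micro counts are pooled over all
    # criteria and documents, while the ('macro', value, measure) entries
    # average each measure over the 13 criteria.
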
    def track2(self, corpora, tag_type=None, mode='strict', verbose=False):
        """Compute measures for Track 2."""
        self.scores = {'tags': {'tp': 0,
                                'fp': 0,
                                'fn': 0,
                                'tn': 0,
                                'micro': {'precision': 0,
                                          'recall': 0,
                                          'f1': 0},
                                'macro': {'precision': 0,
                                          'recall': 0,
                                          'f1': 0}},
                       'relations': {'tp': 0,
                                     'fp': 0,
                                     'fn': 0,
                                     'tn': 0,
                                     'micro': {'precision': 0,
                                               'recall': 0,
                                               'f1': 0},
                                     'macro': {'precision': 0,
                                               'recall': 0,
                                               'f1': 0}}}
        self.tags = ('Drug', 'Strength', 'Duration', 'Route', 'Form',
                     'ADE', 'Dosage', 'Reason', 'Frequency')
        self.relations = ('Strength-Drug', 'Dosage-Drug', 'Duration-Drug',
                          'Frequency-Drug', 'Form-Drug', 'Route-Drug',
                          'Reason-Drug', 'ADE-Drug')
        for g, s in corpora.docs:
            evaluator = SingleEvaluator(g, s, 2, mode, tag_type, verbose=verbose)
            for target in ('tags', 'relations'):
                for score in ('tp', 'fp', 'fn'):
                    self.scores[target][score] += evaluator.scores[target][score]
                measures = Measures(tp=evaluator.scores[target]['tp'],
                                    fp=evaluator.scores[target]['fp'],
                                    fn=evaluator.scores[target]['fn'],
                                    tn=evaluator.scores[target]['tn'])
                for score in ('precision', 'recall', 'f1'):
                    fn = getattr(measures, score)
                    self.scores[target]['macro'][score] += fn()

        for target in ('tags', 'relations'):
            # normalize the accumulated macro scores by the number of documents
            for key in self.scores[target]['macro'].keys():
                self.scores[target]['macro'][key] = \
                    self.scores[target]['macro'][key] / len(corpora.docs)

            measures = Measures(tp=self.scores[target]['tp'],
                                fp=self.scores[target]['fp'],
                                fn=self.scores[target]['fn'],
                                tn=self.scores[target]['tn'])
            for key in self.scores[target]['micro'].keys():
                fn = getattr(measures, key)
                self.scores[target]['micro'][key] = fn()


def evaluate(corpora, mode='strict', verbose=False):
    """Run the evaluation by considering only files in the two folders."""
    assert mode in ('strict', 'lenient')
    evaluator_s = MultipleEvaluator(corpora, verbose=verbose)
    if corpora.track == 1:
        macro_f1, macro_auc = 0, 0
        print('{:*^96}'.format(' TRACK 1 '))
        print('{:20}  {:-^30}    {:-^22}    {:-^14}'.format('', ' met ',
                                                            ' not met ',
                                                            ' overall '))
        print('{:20}  {:6}  {:6}  {:6}  {:6}    {:6}  {:6}  {:6}    {:6}  {:6}'.format(
            '', 'Prec.', 'Rec.', 'Speci.', 'F(b=1)', 'Prec.', 'Rec.', 'F(b=1)', 'F(b=1)', 'AUC'))
        for tag in evaluator_s.tags:
            print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}'.format(
                tag.capitalize(),
                evaluator_s.scores[(tag, 'met', 'p')],
                evaluator_s.scores[(tag, 'met', 'r')],
                evaluator_s.scores[(tag, 'met', 'specificity')],
                evaluator_s.scores[(tag, 'met', 'f1')],
                evaluator_s.scores[(tag, 'not met', 'p')],
                evaluator_s.scores[(tag, 'not met', 'r')],
                evaluator_s.scores[(tag, 'not met', 'f1')],
                (evaluator_s.scores[(tag, 'met', 'f1')] + evaluator_s.scores[(tag, 'not met', 'f1')])/2,
                evaluator_s.scores[(tag, 'met', 'auc')]))
            macro_f1 += (evaluator_s.scores[(tag, 'met', 'f1')] + evaluator_s.scores[(tag, 'not met', 'f1')])/2
            macro_auc += evaluator_s.scores[(tag, 'met', 'auc')]
        print('{:20}  {:-^30}    {:-^22}    {:-^14}'.format('', '', '', ''))
        m = Measures(tp=evaluator_s.values['met']['tp'],
                     fp=evaluator_s.values['met']['fp'],
                     fn=evaluator_s.values['met']['fn'],
                     tn=evaluator_s.values['met']['tn'])
        nm = Measures(tp=evaluator_s.values['not met']['tp'],
                      fp=evaluator_s.values['not met']['fp'],
                      fn=evaluator_s.values['not met']['fn'],
                      tn=evaluator_s.values['not met']['tn'])
        print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}'.format(
            'Overall (micro)', m.precision(), m.recall(), m.specificity(),
            m.f1(), nm.precision(), nm.recall(), nm.f1(),
            (m.f1() + nm.f1()) / 2, m.auc()))
        print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}'.format(
            'Overall (macro)',
            evaluator_s.scores[('macro', 'met', 'p')],
            evaluator_s.scores[('macro', 'met', 'r')],
            evaluator_s.scores[('macro', 'met', 'specificity')],
            evaluator_s.scores[('macro', 'met', 'f1')],
            evaluator_s.scores[('macro', 'not met', 'p')],
            evaluator_s.scores[('macro', 'not met', 'r')],
            evaluator_s.scores[('macro', 'not met', 'f1')],
            macro_f1 / len(evaluator_s.tags),
            evaluator_s.scores[('macro', 'met', 'auc')]))
        print()
        print('{:>20}  {:^74}'.format('', '  {} files found  '.format(len(corpora.docs))))
    else:
        evaluator_l = MultipleEvaluator(corpora, mode='lenient', verbose=verbose)
        print('{:*^70}'.format(' TRACK 2 '))
        print('{:20}  {:-^22}    {:-^22}'.format('', ' strict ', ' lenient '))
        print('{:20}  {:6}  {:6}  {:6}    {:6}  {:6}  {:6}'.format('', 'Prec.',
                                                                   'Rec.',
                                                                   'F(b=1)',
                                                                   'Prec.',
                                                                   'Rec.',
                                                                   'F(b=1)'))
        for tag in evaluator_s.tags:
            evaluator_tag_s = MultipleEvaluator(corpora, tag, verbose=verbose)
            evaluator_tag_l = MultipleEvaluator(corpora, tag, mode='lenient', verbose=verbose)
            print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}'.format(
                tag.capitalize(),
                evaluator_tag_s.scores['tags']['micro']['precision'],
                evaluator_tag_s.scores['tags']['micro']['recall'],
                evaluator_tag_s.scores['tags']['micro']['f1'],
                evaluator_tag_l.scores['tags']['micro']['precision'],
                evaluator_tag_l.scores['tags']['micro']['recall'],
                evaluator_tag_l.scores['tags']['micro']['f1']))
        print('{:>20}  {:-^48}'.format('', ''))
        print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}'.format(
            'Overall (micro)',
            evaluator_s.scores['tags']['micro']['precision'],
            evaluator_s.scores['tags']['micro']['recall'],
            evaluator_s.scores['tags']['micro']['f1'],
            evaluator_l.scores['tags']['micro']['precision'],
            evaluator_l.scores['tags']['micro']['recall'],
            evaluator_l.scores['tags']['micro']['f1']))
        print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}'.format(
            'Overall (macro)',
            evaluator_s.scores['tags']['macro']['precision'],
            evaluator_s.scores['tags']['macro']['recall'],
            evaluator_s.scores['tags']['macro']['f1'],
            evaluator_l.scores['tags']['macro']['precision'],
            evaluator_l.scores['tags']['macro']['recall'],
            evaluator_l.scores['tags']['macro']['f1']))
        print()

        print('{:*^70}'.format(' RELATIONS '))
        for rel in evaluator_s.relations:
            evaluator_tag_s = MultipleEvaluator(corpora, rel, mode='strict', verbose=verbose)
            evaluator_tag_l = MultipleEvaluator(corpora, rel, mode='lenient', verbose=verbose)
            print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}'.format(
                '{} -> {}'.format(rel.split('-')[0], rel.split('-')[1].capitalize()),
                evaluator_tag_s.scores['relations']['micro']['precision'],
                evaluator_tag_s.scores['relations']['micro']['recall'],
                evaluator_tag_s.scores['relations']['micro']['f1'],
                evaluator_tag_l.scores['relations']['micro']['precision'],
                evaluator_tag_l.scores['relations']['micro']['recall'],
                evaluator_tag_l.scores['relations']['micro']['f1']))
        print('{:>20}  {:-^48}'.format('', ''))
        print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}'.format(
            'Overall (micro)',
            evaluator_s.scores['relations']['micro']['precision'],
            evaluator_s.scores['relations']['micro']['recall'],
            evaluator_s.scores['relations']['micro']['f1'],
            evaluator_l.scores['relations']['micro']['precision'],
            evaluator_l.scores['relations']['micro']['recall'],
            evaluator_l.scores['relations']['micro']['f1']))
        print('{:>20}  {:<5.4f}  {:<5.4f}  {:<5.4f}    {:<5.4f}  {:<5.4f}  {:<5.4f}'.format(
            'Overall (macro)',
            evaluator_s.scores['relations']['macro']['precision'],
            evaluator_s.scores['relations']['macro']['recall'],
            evaluator_s.scores['relations']['macro']['f1'],
            evaluator_l.scores['relations']['macro']['precision'],
            evaluator_l.scores['relations']['macro']['recall'],
            evaluator_l.scores['relations']['macro']['f1']))
        print()
        print('{:20}{:^48}'.format('', '  {} files found  '.format(len(corpora.docs))))


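# File pairing (Corpora below): only files whose basenames appear in both
# folders are evaluated; files present in just one folder are reported as
# skipped.
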
class Corpora(object):
    """Pair the gold and system files found in the two folders."""

    def __init__(self, folder1, folder2, track_num):
        extensions = {1: '*.xml', 2: '*.ann'}
        file_ext = extensions[track_num]
        self.track = track_num
        self.folder1 = folder1
        self.folder2 = folder2
        files1 = set([os.path.basename(f) for f in glob.glob(
            os.path.join(folder1, file_ext))])
        files2 = set([os.path.basename(f) for f in glob.glob(
            os.path.join(folder2, file_ext))])
        common_files = files1 & files2     # intersection
        if not common_files:
            print('ERROR: None of the files match.')
        else:
            if files1 - common_files:
                print('Files skipped in {}:'.format(self.folder1))
                print(', '.join(sorted(list(files1 - common_files))))
            if files2 - common_files:
                print('Files skipped in {}:'.format(self.folder2))
                print(', '.join(sorted(list(files2 - common_files))))
        self.docs = []
        for file in common_files:
            if track_num == 1:
                g = RecordTrack1(os.path.join(self.folder1, file))
                s = RecordTrack1(os.path.join(self.folder2, file))
            else:
                g = RecordTrack2(os.path.join(self.folder1, file))
                s = RecordTrack2(os.path.join(self.folder2, file))
            self.docs.append((g, s))


def main(f1, f2, track, verbose):
    """Where the magic begins."""
    corpora = Corpora(f1, f2, track)
    if corpora.docs:
        evaluate(corpora, verbose=verbose)


632
    parser = argparse.ArgumentParser(description='n2c2: Evaluation script for Track 2')
633
    parser.add_argument('folder1', help='First data folder path (gold)')
634
    parser.add_argument('folder2', help='Second data folder path (system)')
635
    args = parser.parse_args()
636
    main(os.path.abspath(args.folder1), os.path.abspath(args.folder2), 2, False)