Switch to side-by-side view

--- a
+++ b/NER Preprocessing and Performance Analysis.ipynb
@@ -0,0 +1,1117 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Parsing Clinical Trial Eligibility Criteria Using Transformers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import json\n",
+    "import pandas as pd\n",
+    "from matplotlib import pyplot as plt\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from shutil import copyfile\n",
+    "import csv"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import spacy\n",
+    "from spacy.lang.en import English\n",
+    "\n",
+    "# Blank English pipeline (tokenizer only) with a rule-based sentencizer added.\n",
+    "nlp = English()\n",
+    "# spaCy v3 adds built-in components by their registered name and rejects component\n",
+    "# objects; spaCy v2 needs the object returned by create_pipe.\n",
+    "if int(spacy.__version__.split(\".\")[0]) >= 3:\n",
+    "    sentencizer = nlp.add_pipe(\"sentencizer\")\n",
+    "else:\n",
+    "    sentencizer = nlp.create_pipe(\"sentencizer\")\n",
+    "    nlp.add_pipe(sentencizer)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Chia Preprocessing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Chia corpus locations, relative to the notebook's working directory\n",
+    "inputpath = \"chia/chia_with_scope\"  # Brat .ann/.txt pairs are read from here\n",
+    "outputpath = \"chia/chia_bio\"        # one .bio.txt file per input file is written here\n",
+    "trainpath = \"chia/trains\"\n",
+    "testpath = \"chia/tests\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1000"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Load input files names: one id per .ann file, cut at the first '.' and then at the first '_'\n",
+    "inputfiles = {\n",
+    "    name.split('.')[0].split('_')[0]\n",
+    "    for name in os.listdir(inputpath)\n",
+    "    if name.endswith('.ann')\n",
+    "}\n",
+    "len(inputfiles)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# list of entity types to retain\n",
+    "select_types = [\n",
+    "    'Condition', 'Value', 'Drug', 'Procedure', 'Measurement', 'Temporal',\n",
+    "    'Observation', 'Person', 'Mood', 'Device', 'Pregnancy_considerations',\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# convert Brat format into BIO format\n",
+    "# function for getting entity annotations from the annotation file\n",
+    "def get_annotation_entities(ann_file, select_types=None):\n",
+    "    \"\"\"Read the text-bound ('T') annotations of a Brat .ann file.\n",
+    "\n",
+    "    Args:\n",
+    "        ann_file: path of the .ann file, read as UTF-8.\n",
+    "        select_types: entity types to keep; None keeps every type.\n",
+    "\n",
+    "    Returns:\n",
+    "        (start, end, type) tuples sorted by (start, end). start is the first and\n",
+    "        end the last offset listed on the line, so an annotation with several\n",
+    "        fragments becomes one covering span.\n",
+    "    \"\"\"\n",
+    "    entities = []\n",
+    "    with open(ann_file, \"r\", encoding=\"utf-8\") as f:\n",
+    "        for line in f:\n",
+    "            if not line.startswith('T'):\n",
+    "                continue\n",
+    "            # second tab-separated field: \"<type> <start> ... <end>\"\n",
+    "            term = line.strip().split('\\t')[1].split()\n",
+    "            ent_type = term[0]\n",
+    "            if select_types is not None and ent_type not in select_types:\n",
+    "                continue\n",
+    "            start, end = int(term[1]), int(term[-1])\n",
+    "            # drop empty or inverted spans\n",
+    "            if end <= start:\n",
+    "                continue\n",
+    "            entities.append((start, end, ent_type))\n",
+    "    return sorted(entities, key=lambda x: (x[0], x[1]))\n",
+    "\n",
+    "# function for handling overlap by keeping the entity with largest text span\n",
+    "def remove_overlap_entities(sorted_entities):\n",
+    "    \"\"\"Reduce (start, end, type) spans, sorted by start, to non-overlapping spans.\n",
+    "\n",
+    "    A span that starts inside the last kept span replaces it only when it is\n",
+    "    strictly longer; a span that starts exactly where the last kept span ends\n",
+    "    extends that span (keeping its start and type); any other span is appended.\n",
+    "    \"\"\"\n",
+    "    kept = []\n",
+    "    for candidate in sorted_entities:\n",
+    "        if not kept:\n",
+    "            kept.append(candidate)\n",
+    "            continue\n",
+    "        prev = kept[-1]\n",
+    "        if candidate[0] > prev[1]:\n",
+    "            # starts after the previous span: a new entity\n",
+    "            kept.append(candidate)\n",
+    "        elif candidate[0] == prev[1]:\n",
+    "            # touches the previous span: stretch it to the candidate's end\n",
+    "            kept[-1] = (prev[0], candidate[1], prev[-1])\n",
+    "        elif candidate[1] - candidate[0] > prev[1] - prev[0]:\n",
+    "            # overlaps and is longer: takes the previous span's place\n",
+    "            kept[-1] = candidate\n",
+    "    return kept\n",
+    "\n",
+    "# inverse index of entity annotations\n",
+    "def entity_dictionary(keep_entities, txt_file):\n",
+    "    \"\"\"Map the document offset of every entity token to its annotation.\n",
+    "\n",
+    "    Args:\n",
+    "        keep_entities: (start, end, type) spans with offsets into the joined text.\n",
+    "        txt_file: path of the criteria .txt file, read as UTF-8.\n",
+    "\n",
+    "    Returns:\n",
+    "        dict: offset -> [token index inside its entity, token text, entity type].\n",
+    "        The first entity that claims an offset wins.\n",
+    "    \"\"\"\n",
+    "    # Take the file id from the path rather than from the notebook-global `file`,\n",
+    "    # so the function no longer depends on a loop variable of the calling cell.\n",
+    "    file_id = os.path.splitext(os.path.basename(txt_file))[0]\n",
+    "    # stripped lines are re-joined with one space for these three files and with\n",
+    "    # two spaces for every other file\n",
+    "    if file_id in ['NCT02348918_exc', 'NCT02348918_inc', 'NCT01735955_exc']:\n",
+    "        joiner = ' '\n",
+    "    else:\n",
+    "        joiner = '  '\n",
+    "    with open(txt_file, \"r\", encoding=\"utf-8\") as f:\n",
+    "        text = joiner.join(i.strip() for i in f)\n",
+    "    f_ann = {}\n",
+    "    for entity in keep_entities:\n",
+    "        term_offset, term_type = entity[0], entity[-1]\n",
+    "        doc = nlp(text[entity[0]:entity[1]])\n",
+    "        for i, token in enumerate(doc):\n",
+    "            # token.idx is the token's character offset inside the entity text\n",
+    "            ann_offset = token.idx + term_offset\n",
+    "            if ann_offset not in f_ann:\n",
+    "                f_ann[ann_offset] = [i, token.text, term_type]\n",
+    "    return f_ann\n",
+    "\n",
+    "# Brat -> BIO format conversion\n",
+    "for infile in inputfiles:\n",
+    "    for t in [\"exc\", \"inc\"]:\n",
+    "        file = f\"{infile}_{t}\"\n",
+    "        ann_file = f\"{inputpath}/{file}.ann\"\n",
+    "        txt_file = f\"{inputpath}/{file}.txt\"\n",
+    "        out_file = f\"{outputpath}/{file}.bio.txt\"\n",
+    "        sorted_entities = get_annotation_entities(ann_file, select_types)\n",
+    "        keep_entities = remove_overlap_entities(sorted_entities)\n",
+    "        f_ann = entity_dictionary(keep_entities, txt_file)\n",
+    "        # 3 trials with inconsistent offsets: their lines sit 1 character apart, all others 2\n",
+    "        line_gap = 1 if file in ['NCT02348918_exc', 'NCT02348918_inc', 'NCT01735955_exc'] else 2\n",
+    "        with open(out_file, \"w\", encoding=\"utf-8\") as f_out, open(txt_file, \"r\", encoding=\"utf-8\") as f:\n",
+    "            sent_offset = 0  # document offset of the current line's first character\n",
+    "            for line in f:\n",
+    "                # swap the fraction-slash character for a plain '/'\n",
+    "                line = line.replace('⁄', '/')\n",
+    "                sent_text = line.strip()\n",
+    "                doc = nlp(sent_text)\n",
+    "                for token in doc:\n",
+    "                    token_sent_offset = doc[token.i:].start_char\n",
+    "                    token_doc_offset = token_sent_offset + sent_offset\n",
+    "                    token_len = len(token.text)\n",
+    "                    hit = f_ann.get(token_doc_offset)\n",
+    "                    if hit is None:\n",
+    "                        label = \"O\"\n",
+    "                    elif hit[0] == 0:\n",
+    "                        # first token of its entity\n",
+    "                        label = f\"B-{hit[2]}\"\n",
+    "                    else:\n",
+    "                        label = f\"I-{hit[2]}\"\n",
+    "                    f_out.write(f\"{token.text} {token_sent_offset} {token_sent_offset+token_len} {token_doc_offset} {token_doc_offset+token_len} {label}\\n\")\n",
+    "                # blank line closes the sentence\n",
+    "                f_out.write('\\n')\n",
+    "                sent_offset += len(sent_text) + line_gap"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "800 100 100\n"
+     ]
+    }
+   ],
+   "source": [
+    "# dataset separation: 800 trials (80%) for training, 100 trials (10%) for validation and 100 trials (10%) for testing\n",
+    "# Sort the ids first: the iteration order of a set of str can change from one interpreter\n",
+    "# run to the next, which would make the seeded split differ between kernels.\n",
+    "train_ids, dev_ids = train_test_split(sorted(inputfiles), train_size=0.8, random_state=13, shuffle=True)\n",
+    "dev_ids, test_ids = train_test_split(dev_ids, train_size=0.5, random_state=13, shuffle=True)\n",
+    "print(len(train_ids), len(dev_ids), len(test_ids))\n",
+    "chia_datasets = {\"train\":train_ids, \"dev\":dev_ids, \"test\":test_ids}\n",
+    "# context manager so the JSON file is flushed and closed right away\n",
+    "with open(\"chia/chia_datasets.json\", \"w\", encoding=\"utf-8\") as f_json:\n",
+    "    json.dump(chia_datasets, f_json)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Merge BIO format train, validation and test datasets\n",
+    "# chia_datasets = json.load(open(\"chia/chia_datasets.json\", \"r\", encoding=\"utf-8\"))\n",
+    "# One loop over (split name, directory receiving the per-file copies) replaces three\n",
+    "# copy-pasted loops. NOTE(review): the dev files are copied into trainpath, exactly as\n",
+    "# before - confirm that this is intended.\n",
+    "for split, copy_dir in [(\"train\", trainpath), (\"dev\", trainpath), (\"test\", testpath)]:\n",
+    "    with open(f\"chia/{split}.txt\", \"w\", encoding=\"utf-8\") as f:\n",
+    "        for fid in chia_datasets[split]:\n",
+    "            names = [f\"{fid}_exc.bio.txt\", f\"{fid}_inc.bio.txt\"]\n",
+    "            for name in names:\n",
+    "                copyfile(f\"{outputpath}/{name}\", f\"{copy_dir}/{name}\")\n",
+    "            for name in names:\n",
+    "                with open(f\"{outputpath}/{name}\", \"r\", encoding=\"utf-8\") as fr:\n",
+    "                    txt = fr.read().strip()\n",
+    "                # files without any token add nothing to the merged file\n",
+    "                if txt != '':\n",
+    "                    f.write(txt)\n",
+    "                    f.write(\"\\n\\n\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# convert Chia in Brat into format for Att-BiLSTM-CRF model\n",
+    "out_file = \"chia/chia_ner.tsv\"\n",
+    "with open(out_file, \"w\", encoding=\"utf-8\") as f_out:\n",
+    "    for infile in inputfiles:\n",
+    "        for t in [\"exc\", \"inc\"]:\n",
+    "            file = f\"{infile}_{t}\"\n",
+    "            ann_file = f\"{inputpath}/{file}.ann\"\n",
+    "            txt_file = f\"{inputpath}/{file}.txt\"\n",
+    "            sorted_entities = get_annotation_entities(ann_file, select_types)\n",
+    "            keep_entities = remove_overlap_entities(sorted_entities)\n",
+    "            # lines sit 1 character apart in three files and 2 characters apart elsewhere\n",
+    "            line_gap = 1 if file in ['NCT02348918_exc', 'NCT02348918_inc', 'NCT01735955_exc'] else 2\n",
+    "            with open(txt_file, \"r\", encoding=\"utf-8\") as f:\n",
+    "                sent_offset = 0\n",
+    "                for line in f:\n",
+    "                    line = line.replace('⁄', '/')\n",
+    "                    sent_text = line.strip()\n",
+    "                    sent_end = sent_offset + len(line)\n",
+    "                    text_end = sent_offset + len(sent_text)\n",
+    "                    sent_ents = []\n",
+    "                    for ent in keep_entities:\n",
+    "                        # entity starts or ends before this line\n",
+    "                        if ent[0] < sent_offset or ent[1] < sent_offset:\n",
+    "                            continue\n",
+    "                        # entity starts after this line or runs past its text: stop scanning\n",
+    "                        if ent[0] >= sent_end or ent[1] > text_end:\n",
+    "                            break\n",
+    "                        # offsets relative to the line, shifted by +1\n",
+    "                        sent_ents.append(f\"{ent[0]-sent_offset+1}:{ent[1]-sent_offset+1}:{ent[2].lower()}\")\n",
+    "                    # lines without entities are left out of the TSV\n",
+    "                    if sent_ents:\n",
+    "                        f_out.write(f\"{file}\\t{','.join(sent_ents)}\\t{sent_text}\\n\")\n",
+    "                    sent_offset += len(sent_text) + line_gap"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# split Chia in format for Att-BiLSTM-CRF model into train, validation and test datasets\n",
+    "# chia_datasets = json.load(open(\"chia/chia_datasets.json\", \"r\", encoding=\"utf-8\"))\n",
+    "# sets turn the per-line membership test into a hash lookup instead of a list scan\n",
+    "chia_train_set = set(chia_datasets[\"train\"])\n",
+    "chia_dev_set = set(chia_datasets[\"dev\"])\n",
+    "with open(\"chia/chia_ner_train.tsv\", \"w\", encoding=\"utf-8\") as ftrain, \\\n",
+    "        open(\"chia/chia_ner_dev.tsv\", \"w\", encoding=\"utf-8\") as fdev, \\\n",
+    "        open(\"chia/chia_ner_test.tsv\", \"w\", encoding=\"utf-8\") as ftest, \\\n",
+    "        open(\"chia/chia_ner.tsv\", \"r\", encoding=\"utf-8\") as fread:\n",
+    "    for line in fread:\n",
+    "        # first column is \"<trial id>_<exc|inc>\"; only the trial id decides the split\n",
+    "        trial_id = line.split('\\t', 1)[0].split(\"_\")[0]\n",
+    "        if trial_id in chia_train_set:\n",
+    "            ftrain.write(line)\n",
+    "        elif trial_id in chia_dev_set:\n",
+    "            fdev.write(line)\n",
+    "        else:\n",
+    "            # every id that is neither train nor dev goes to the test file\n",
+    "            ftest.write(line)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Facebook Research Data (FRD) Preprocessing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# FRD locations. outputpath, trainpath and testpath are rebound here, so from this\n",
+    "# cell on they no longer point at the Chia folders.\n",
+    "fbnerfile = \"fbner/medical_ner.tsv\"\n",
+    "outputpath = \"fbner/fb_bio\"\n",
+    "trainpath = \"fbner/trains\"\n",
+    "testpath = \"fbner/tests\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# loading FRD data\n",
+    "# fbner maps the first TSV column to a list of {'text', 'entities'} records, one per row\n",
+    "fbner = {}\n",
+    "with open(fbnerfile, \"r\", encoding='utf-8') as f:\n",
+    "    for line in f:\n",
+    "        fields = line.strip().split('\\t')\n",
+    "        # second column: comma-separated entities, each split on ':' into its parts\n",
+    "        ents = [ent.split(':') for ent in fields[1].split(',')]\n",
+    "        fbner.setdefault(fields[0], []).append({'text':fields[-1], 'entities':ents})\n",
+    "# context manager so the JSON file is flushed and closed right away\n",
+    "with open(\"./fbner/medical_ner.json\", \"w\", encoding=\"utf-8\") as f_json:\n",
+    "    json.dump(fbner, f_json)\n",
+    "inputfiles = list(fbner.keys())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# convert FRD into BIO format\n",
+    "for k, v in fbner.items():\n",
+    "    out_file = f\"{outputpath}/{k}.bio.txt\"\n",
+    "    with open(out_file, \"w\", encoding=\"utf-8\") as f_out:\n",
+    "        sent_offset = 0\n",
+    "        for sent in v:\n",
+    "            sent_text = sent['text']\n",
+    "            # the entity location dictionary: document offset -> [index in entity, token text, type]\n",
+    "            f_ann = {}\n",
+    "            for ent in sent['entities']:\n",
+    "                # stored offsets are shifted by +1; undo that before slicing\n",
+    "                ent_start = int(ent[0]) - 1\n",
+    "                ent_stop = int(ent[1]) - 1\n",
+    "                ent_doc = nlp(sent_text[ent_start:ent_stop])\n",
+    "                term_type = ent[-1]\n",
+    "                term_offset = ent_start + sent_offset\n",
+    "                for i, token in enumerate(ent_doc):\n",
+    "                    ann_offset = ent_doc[i:].start_char + term_offset\n",
+    "                    if ann_offset not in f_ann:\n",
+    "                        f_ann[ann_offset] = [i, token.text, term_type]\n",
+    "            # convert to bio format\n",
+    "            doc = nlp(sent_text)\n",
+    "            for token in doc:\n",
+    "                token_sent_offset = doc[token.i:].start_char\n",
+    "                token_doc_offset = token_sent_offset + sent_offset\n",
+    "                token_len = len(token.text)\n",
+    "                hit = f_ann.get(token_doc_offset)\n",
+    "                if hit is None:\n",
+    "                    label = \"O\"\n",
+    "                elif hit[0] == 0:\n",
+    "                    # first token of its entity\n",
+    "                    label = f\"B-{hit[2]}\"\n",
+    "                else:\n",
+    "                    label = f\"I-{hit[2]}\"\n",
+    "                f_out.write(f\"{token.text} {token_sent_offset} {token_sent_offset+token_len} {token_doc_offset} {token_doc_offset+token_len} {label}\\n\")\n",
+    "            # blank line closes the sentence\n",
+    "            f_out.write('\\n')\n",
+    "            sent_offset += len(sent_text) + 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2651 331 332\n"
+     ]
+    }
+   ],
+   "source": [
+    "# split FRD into train, validation and test datasets\n",
+    "train_ids, dev_ids = train_test_split(inputfiles, train_size=0.8, random_state=13, shuffle=True)\n",
+    "dev_ids, test_ids = train_test_split(dev_ids, train_size=0.5, random_state=13, shuffle=True)\n",
+    "print(len(train_ids), len(dev_ids), len(test_ids))\n",
+    "fbner_datasets = {\"train\":train_ids, \"dev\":dev_ids, \"test\":test_ids}\n",
+    "# context manager so the JSON file is flushed and closed right away\n",
+    "with open(\"fbner/fbner_datasets.json\", \"w\", encoding=\"utf-8\") as f_json:\n",
+    "    json.dump(fbner_datasets, f_json)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Merge BIO format train, validation and test datasets\n",
+    "# fbner_datasets = json.load(open(\"fbner/fbner_datasets.json\", \"r\", encoding=\"utf-8\"))\n",
+    "# One loop over (split name, directory receiving the per-file copies) replaces three\n",
+    "# copy-pasted loops. NOTE(review): the dev files are copied into trainpath, exactly as\n",
+    "# before - confirm that this is intended.\n",
+    "for split, copy_dir in [(\"train\", trainpath), (\"dev\", trainpath), (\"test\", testpath)]:\n",
+    "    with open(f\"fbner/{split}.txt\", \"w\", encoding=\"utf-8\") as f:\n",
+    "        for fid in fbner_datasets[split]:\n",
+    "            copyfile(f\"{outputpath}/{fid}.bio.txt\", f\"{copy_dir}/{fid}.bio.txt\")\n",
+    "            with open(f\"{outputpath}/{fid}.bio.txt\", \"r\", encoding=\"utf-8\") as fr:\n",
+    "                txt = fr.read().strip()\n",
+    "            # files without any token add nothing to the merged file\n",
+    "            if txt != '':\n",
+    "                f.write(txt)\n",
+    "                f.write(\"\\n\\n\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# resplit the processed data\n",
+    "# fbner_datasets = json.load(open(\"fbner/fbner_datasets.json\", \"r\", encoding=\"utf-8\"))\n",
+    "\n",
+    "# sets turn the per-line membership test into a hash lookup instead of a list scan\n",
+    "fb_train_set = set(fbner_datasets[\"train\"])\n",
+    "fb_dev_set = set(fbner_datasets[\"dev\"])\n",
+    "with open(\"fbner/fbner_ner_train.tsv\", \"w\", encoding=\"utf-8\") as ftrain, \\\n",
+    "        open(\"fbner/fbner_ner_dev.tsv\", \"w\", encoding=\"utf-8\") as fdev, \\\n",
+    "        open(\"fbner/fbner_ner_test.tsv\", \"w\", encoding=\"utf-8\") as ftest, \\\n",
+    "        open(\"fbner/medical_ner.tsv\", \"r\", encoding=\"utf-8\") as fread:\n",
+    "    for line in fread:\n",
+    "        # the first column is the id the split was made on\n",
+    "        doc_id = line.split('\\t', 1)[0]\n",
+    "        if doc_id in fb_train_set:\n",
+    "            ftrain.write(line)\n",
+    "        elif doc_id in fb_dev_set:\n",
+    "            fdev.write(line)\n",
+    "        else:\n",
+    "            # every id that is neither train nor dev goes to the test file\n",
+    "            ftest.write(line)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Att-BiLSTM-CRF Model Performance Analysis"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def abc_strict_match(gs, pred, s_idx, e_idx, ent_type):\n",
+    "    \"\"\"True when pred equals gs on [s_idx, e_idx) and neither sequence carries\n",
+    "    ent_type directly before s_idx or at e_idx.\"\"\"\n",
+    "    # left neighbour (there is none when the span starts at index 0)\n",
+    "    if s_idx != 0 and (gs[s_idx-1] == ent_type or pred[s_idx-1] == ent_type):\n",
+    "        return False\n",
+    "    # every position of the span must agree\n",
+    "    if any(gs[idx] != pred[idx] for idx in range(s_idx, e_idx)):\n",
+    "        return False\n",
+    "    # right neighbour, if the span does not end the sequence\n",
+    "    if e_idx < len(gs) and (gs[e_idx] == ent_type or pred[e_idx] == ent_type):\n",
+    "        return False\n",
+    "    return True\n",
+    "\n",
+    "def abc_relax_match(gs, pred, s_idx, e_idx, ent_type):\n",
+    "    \"\"\"True when at least one position in [s_idx, e_idx) carries ent_type in both gs and pred.\"\"\"\n",
+    "    return any(gs[idx] == pred[idx] == ent_type for idx in range(s_idx, e_idx))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = 'chia' # one of the labels_dict keys: 'chia' or 'frd'\n",
+    "outfolder = \"attbilstmcrf\"\n",
+    "outfile = f\"{dataset}_attbilstmcrf_results.txt\"\n",
+    "# label names per dataset; the evaluation looks a name up as labels[tag-1]\n",
+    "labels_dict = {\n",
+    "    'chia': ['Mood', 'Condition', 'Procedure', 'Measurement', 'Value', 'Drug', 'Temporal', 'Observation', 'Pregnancy', 'Person', 'Device'],\n",
+    "    'frd': ['chronic_disease', 'treatment', 'upper_bound', 'pregnancy', 'clinical_variable', 'lower_bound', 'cancer', 'age', 'language_fluency', 'gender', 'contraception_consent', 'technology_access', 'allergy_name', 'bmi', 'ethnicity'],\n",
+    "}\n",
+    "labels = labels_dict[dataset]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _label_runs(seq, size):\n",
+    "    \"\"\"Yield (start, end, label) for every maximal run of one non-zero label in seq[:size].\"\"\"\n",
+    "    pos = 0\n",
+    "    while pos < size:\n",
+    "        if seq[pos] == 0:\n",
+    "            pos += 1\n",
+    "            continue\n",
+    "        stop = pos + 1\n",
+    "        while stop < size and seq[stop] == seq[pos]:\n",
+    "            stop += 1\n",
+    "        yield pos, stop, seq[pos]\n",
+    "        pos = stop\n",
+    "\n",
+    "def _match_level(gs, pred, s_idx, e_idx, ent_type):\n",
+    "    \"\"\"Classify one span as 'strict', 'relax' or 'miss'; the strict test is tried first.\"\"\"\n",
+    "    if abc_strict_match(gs, pred, s_idx, e_idx, ent_type):\n",
+    "        return \"strict\"\n",
+    "    if abc_relax_match(gs, pred, s_idx, e_idx, ent_type):\n",
+    "        return \"relax\"\n",
+    "    return \"miss\"\n",
+    "\n",
+    "def _bump(counter, key):\n",
+    "    \"\"\"Add one to counter[key], starting from zero when the key is new.\"\"\"\n",
+    "    counter[key] = counter.get(key, 0) + 1\n",
+    "\n",
+    "# NOTE(review): eval() executes whatever the results file contains. If both columns are\n",
+    "# plain Python literals, ast.literal_eval would be the safer parser - confirm the format.\n",
+    "eval_metrics = {\"category\":{},\"overall\":{}, \"prediction\":{}}\n",
+    "overall = eval_metrics[\"overall\"]\n",
+    "with open(f\"{outfolder}/{outfile}\", \"r\", encoding=\"utf-8\") as f:\n",
+    "    next(f)  # the first line of the file is skipped\n",
+    "    for line in f:\n",
+    "        columns = line.split('\\t')\n",
+    "        gs = eval(columns[1])\n",
+    "        pred = eval(columns[0])\n",
+    "        # position-wise agreement\n",
+    "        for gold_tag, pred_tag in zip(gs, pred):\n",
+    "            _bump(overall, \"acc_true\" if gold_tag == pred_tag else \"acc_false\")\n",
+    "        llen = len(gs)\n",
+    "        # runs found in gs: counted under 'gs', then tallied in \"category\" and the *_predicted keys\n",
+    "        for start_idx, end_idx, cate in _label_runs(gs, llen):\n",
+    "            gold_counts = overall.setdefault('gs', {})\n",
+    "            _bump(gold_counts, 'count')\n",
+    "            _bump(gold_counts, labels[cate-1])\n",
+    "            level = _match_level(gs, pred, start_idx, end_idx, cate)\n",
+    "            _bump(overall, f\"{level}_predicted\")\n",
+    "            eval_metrics[\"category\"].setdefault(labels[cate-1], {\"strict\":0, \"relax\":0, \"miss\":0})[level] += 1\n",
+    "        # runs found in pred: tallied in \"prediction\" and the *_predict keys\n",
+    "        for start_idx, end_idx, cate in _label_runs(pred, llen):\n",
+    "            level = _match_level(gs, pred, start_idx, end_idx, cate)\n",
+    "            _bump(overall, f\"{level}_predict\")\n",
+    "            eval_metrics[\"prediction\"].setdefault(labels[cate-1], {\"strict\":0, \"relax\":0, \"miss\":0})[level] += 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Overall Relax Level: Precision: 0.706389088298636, Recall: 0.734399375975039, F1: 0.7201219589214525\n",
+      "Overall Strict Level: Precision: 0.35857860732232594, Recall: 0.38962558502340094, F1: 0.3734579439252337\n",
+      "\n",
+      "\n",
+      "Relax Level for Condition: Precision: 0.8081587651598677, Recall: 0.8273045507584598, F1: 0.8176195915182294\n",
+      "Strict Level for Condition: Precision: 0.43439911797133407, Recall: 0.4597432905484247, F1: 0.4467120181405896\n",
+      "\n",
+      "\n",
+      "Relax Level for Procedure: Precision: 0.6728110599078341, Recall: 0.4857142857142857, F1: 0.5641550176156381\n",
+      "Strict Level for Procedure: Precision: 0.2626728110599078, Recall: 0.20357142857142857, F1: 0.22937625754527163\n",
+      "\n",
+      "\n",
+      "Relax Level for Temporal: Precision: 0.5663956639566395, Recall: 0.853448275862069, F1: 0.6809049773755657\n",
+      "Strict Level for Temporal: Precision: 0.2601626016260163, Recall: 0.41379310344827586, F1: 0.3194675540765391\n",
+      "\n",
+      "\n",
+      "Relax Level for Pregnancy: Precision: 0.38596491228070173, Recall: 0.5555555555555556, F1: 0.45548654244306414\n",
+      "Strict Level for Pregnancy: Precision: 0.03508771929824561, Recall: 0.1111111111111111, F1: 0.05333333333333334\n",
+      "\n",
+      "\n",
+      "Relax Level for Observation: Precision: 0.5342465753424658, Recall: 0.42011834319526625, F1: 0.47035841685068797\n",
+      "Strict Level for Observation: Precision: 0.18493150684931506, Recall: 0.15976331360946747, F1: 0.17142857142857143\n",
+      "\n",
+      "\n",
+      "Relax Level for Drug: Precision: 0.6677115987460815, Recall: 0.8167330677290837, F1: 0.7347422975315082\n",
+      "Strict Level for Drug: Precision: 0.2884012539184953, Recall: 0.3665338645418327, F1: 0.32280701754385965\n",
+      "\n",
+      "\n",
+      "Relax Level for Person: Precision: 0.7818181818181819, Recall: 0.7589285714285714, F1: 0.7702033505426193\n",
+      "Strict Level for Person: Precision: 0.6363636363636364, Recall: 0.625, F1: 0.6306306306306306\n",
+      "\n",
+      "\n",
+      "Relax Level for Value: Precision: 0.7839506172839507, Recall: 0.7961783439490446, F1: 0.790017168877056\n",
+      "Strict Level for Value: Precision: 0.5030864197530864, Recall: 0.5191082802547771, F1: 0.5109717868338558\n",
+      "\n",
+      "\n",
+      "Relax Level for Measurement: Precision: 0.7153284671532847, Recall: 0.735632183908046, F1: 0.7253382676072626\n",
+      "Strict Level for Measurement: Precision: 0.3467153284671533, Recall: 0.36398467432950193, F1: 0.3551401869158879\n",
+      "\n",
+      "\n",
+      "Relax Level for Mood: Precision: 0.3888888888888889, Recall: 0.30434782608695654, F1: 0.34146341463414637\n",
+      "Strict Level for Mood: Precision: 0.05555555555555555, Recall: 0.043478260869565216, F1: 0.04878048780487805\n",
+      "\n",
+      "\n",
+      "Relax Level for Device: Precision: 0.6296296296296297, Recall: 0.5416666666666666, F1: 0.5823451910408431\n",
+      "Strict Level for Device: Precision: 0.037037037037037035, Recall: 0.041666666666666664, F1: 0.03921568627450981\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Turn the raw match counts in `eval_metrics` into precision / recall / F1\n",
+    "# (overall and per entity type), print them, and save them as JSON.\n",
+    "attbc_metrics = {\"category\":{},\"overall\":{}}\n",
+    "attbc_metrics[\"overall\"][\"acc\"] = eval_metrics[\"overall\"][\"acc_true\"]/(eval_metrics[\"overall\"][\"acc_true\"]+eval_metrics[\"overall\"][\"acc_false\"])\n",
+    "# precision denominator: strict + relax + miss counts on the prediction side\n",
+    "pred_all = eval_metrics[\"overall\"]['strict_predict'] + eval_metrics[\"overall\"]['relax_predict'] + eval_metrics[\"overall\"]['miss_predict']\n",
+    "pre_relax_all = (eval_metrics[\"overall\"]['strict_predict'] + eval_metrics[\"overall\"]['relax_predict'])/ pred_all if pred_all != 0 else 0.0\n",
+    "rec_relax_all = (eval_metrics[\"overall\"]['strict_predicted'] + eval_metrics[\"overall\"]['relax_predicted'])/ eval_metrics[\"overall\"]['gs']['count']\n",
+    "# F1 is undefined when precision and recall are both 0; report 0.0 instead of raising\n",
+    "f1_relax_all = (2*pre_relax_all*rec_relax_all)/(pre_relax_all+rec_relax_all) if (pre_relax_all+rec_relax_all) != 0 else 0.0\n",
+    "print(f\"Overall Relax Level: Precision: {pre_relax_all}, Recall: {rec_relax_all}, F1: {f1_relax_all}\")\n",
+    "attbc_metrics[\"overall\"][\"relax\"] = {\"f_score\": f1_relax_all, \"precision\":pre_relax_all, \"recall\":rec_relax_all}\n",
+    "\n",
+    "pre_strict_all = eval_metrics[\"overall\"]['strict_predict'] / pred_all if pred_all != 0 else 0.0\n",
+    "rec_strict_all = eval_metrics[\"overall\"]['strict_predicted'] / eval_metrics[\"overall\"]['gs']['count']\n",
+    "f1_strict_all = (2*pre_strict_all*rec_strict_all)/(pre_strict_all+rec_strict_all) if (pre_strict_all+rec_strict_all) != 0 else 0.0\n",
+    "print(f\"Overall Strict Level: Precision: {pre_strict_all}, Recall: {rec_strict_all}, F1: {f1_strict_all}\")\n",
+    "print('\\n')\n",
+    "attbc_metrics[\"overall\"][\"strict\"] = {\"f_score\": f1_strict_all, \"precision\":pre_strict_all, \"recall\":rec_strict_all}\n",
+    "\n",
+    "for i in eval_metrics[\"category\"].keys():\n",
+    "    # tt: number of gold-standard entities of type i (recall denominator)\n",
+    "    tt = eval_metrics[\"overall\"]['gs'][i]\n",
+    "    # an entity type found in the gold standard may have no prediction-side counts at all\n",
+    "    pred_counts = eval_metrics[\"prediction\"][i] if i in eval_metrics[\"prediction\"] else {\"strict\":0, \"relax\":0, \"miss\":0}\n",
+    "    # tp: total number of predicted entities of type i (precision denominator)\n",
+    "    tp = pred_counts['strict'] + pred_counts['relax'] + pred_counts['miss']\n",
+    "\n",
+    "    pre_relax = (pred_counts['strict']+pred_counts['relax'])/tp if tp != 0 else 0.0\n",
+    "    rec_relax = (eval_metrics[\"category\"][i]['strict']+eval_metrics[\"category\"][i]['relax'])/tt\n",
+    "    f1_relax = (2*pre_relax*rec_relax)/(pre_relax+rec_relax) if (pre_relax+rec_relax) != 0 else 0.0\n",
+    "    print(f\"Relax Level for {i}: Precision: {pre_relax}, Recall: {rec_relax}, F1: {f1_relax}\")\n",
+    "    attbc_metrics[\"category\"].setdefault(\"relax\", {})[i] = {\"f_score\": f1_relax, \"precision\":pre_relax, \"recall\":rec_relax}\n",
+    "\n",
+    "    pre_strict = pred_counts['strict']/tp if tp != 0 else 0.0\n",
+    "    rec_strict = eval_metrics[\"category\"][i]['strict']/tt\n",
+    "    f1_strict = (2*pre_strict*rec_strict)/(pre_strict+rec_strict) if (pre_strict+rec_strict) != 0 else 0.0\n",
+    "    print(f\"Strict Level for {i}: Precision: {pre_strict}, Recall: {rec_strict}, F1: {f1_strict}\")\n",
+    "    print('\\n')\n",
+    "    attbc_metrics[\"category\"].setdefault(\"strict\", {})[i] = {\"f_score\": f1_strict, \"precision\":pre_strict, \"recall\":rec_strict}\n",
+    "# write through a context manager so the file is flushed and closed before anything reads it back\n",
+    "with open(f\"{outfolder}/{outfile}.json\", \"w\", encoding=\"utf-8\") as metrics_fp:\n",
+    "    json.dump(attbc_metrics, metrics_fp, indent=4)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Transformer-based Model Performance and Error Analysis"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# strict matching function\n",
+    "def bio_strict_match(gs, pred, s_idx, e_idx, en_type):\n",
+    "    \"\"\"Return True when the span [s_idx, e_idx) is labelled identically in gs and pred.\n",
+    "\n",
+    "    gs and pred are parallel lists of BIO tags. The span must open with\n",
+    "    'B-<en_type>' in both lists, every tag inside it must agree, and neither\n",
+    "    list may carry on with 'I-<en_type>' at position e_idx.\n",
+    "    \"\"\"\n",
+    "    begin_tag = f\"B-{en_type}\"\n",
+    "    inside_tag = f\"I-{en_type}\"\n",
+    "    # both sequences have to open the entity at the same token\n",
+    "    if not (gs[s_idx] == begin_tag and pred[s_idx] == begin_tag):\n",
+    "        return False\n",
+    "    # any disagreement inside the span rules out an exact match\n",
+    "    if any(gs[pos] != pred[pos] for pos in range(s_idx, e_idx)):\n",
+    "        return False\n",
+    "    # the entity must stop at e_idx on both sides, otherwise the boundaries differ\n",
+    "    if e_idx < len(gs):\n",
+    "        continues_past_end = pred[e_idx] == inside_tag or gs[e_idx] == inside_tag\n",
+    "        if continues_past_end:\n",
+    "            return False\n",
+    "    return True\n",
+    "\n",
+    "# relax matching function\n",
+    "def bio_relax_match(gs, pred, s_idx, e_idx, en_type):\n",
+    "    \"\"\"Return True when gs and pred both carry en_type on at least one token of [s_idx, e_idx).\n",
+    "\n",
+    "    The B-/I- prefix is ignored: only the part after the last '-' is compared,\n",
+    "    and a tag without '-' is treated as 'O'.\n",
+    "    \"\"\"\n",
+    "    def tag_category(tag):\n",
+    "        return tag.split(\"-\")[-1] if \"-\" in tag else \"O\"\n",
+    "\n",
+    "    return any(\n",
+    "        tag_category(gs[pos]) == tag_category(pred[pos]) == en_type\n",
+    "        for pos in range(s_idx, e_idx)\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# evaluation settings for the transformer runs\n",
+    "dataset = 'chia' # or 'frd'\n",
+    "outfolder = \"transformer\"\n",
+    "# os.listdir returns entries in arbitrary order; sort them so the files are processed\n",
+    "# (and later written to the CSV reports) in the same order on every run\n",
+    "test_files = sorted(os.listdir(f\"{dataset}/tests/\"))\n",
+    "models = ['bert', 'bert_mimic', 'albert', 'albert_mimic', 'roberta', 'roberta_mimic', 'electra', 'electra_mimic'] #'distilbert'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def iter_bio_spans(tags):\n",
+    "    \"\"\"Yield (start, end, category) for each entity span in a list of BIO tags.\n",
+    "\n",
+    "    A span opens at any tag other than 'O' and runs over the directly following\n",
+    "    'I-<category>' tags; `end` is exclusive. Raises ValueError if the opening\n",
+    "    tag does not split into exactly two parts on '-'.\n",
+    "    \"\"\"\n",
+    "    n_tags = len(tags)\n",
+    "    cur_idx = 0\n",
+    "    while cur_idx < n_tags:\n",
+    "        if tags[cur_idx].strip() == 'O':\n",
+    "            cur_idx += 1\n",
+    "            continue\n",
+    "        start_idx = cur_idx\n",
+    "        end_idx = start_idx + 1\n",
+    "        _, cate = tags[start_idx].strip().split('-')\n",
+    "        while end_idx < n_tags and tags[end_idx].strip() == f\"I-{cate}\":\n",
+    "            end_idx += 1\n",
+    "        yield start_idx, end_idx, cate\n",
+    "        cur_idx = end_idx\n",
+    "\n",
+    "\n",
+    "def match_level(gs, pred, s_idx, e_idx, en_type):\n",
+    "    \"\"\"Classify a span as 'strict', 'relax' or 'miss'; a strict match takes precedence.\"\"\"\n",
+    "    if bio_strict_match(gs, pred, s_idx, e_idx, en_type):\n",
+    "        return 'strict'\n",
+    "    if bio_relax_match(gs, pred, s_idx, e_idx, en_type):\n",
+    "        return 'relax'\n",
+    "    return 'miss'\n",
+    "\n",
+    "\n",
+    "model = models[5]  # index 5 of `models` is 'roberta_mimic'\n",
+    "predictions = {}\n",
+    "eval_metrics = {\"category\":{},\"overall\":{}, \"prediction\":{}}\n",
+    "for file in test_files:\n",
+    "    file_id = file.split('.')[0]\n",
+    "    # load the annotation and prediction files: sentences are separated by a blank line, tokens by a newline\n",
+    "    with open(f\"{dataset}/tests/{file}\", \"r\", encoding=\"utf-8\") as f:\n",
+    "        test_anno = [sent.split('\\n') for sent in f.read().strip().split('\\n\\n')]\n",
+    "    with open(f\"{outfolder}/{dataset}_results/{dataset}_{model}_results/{file}\", \"r\", encoding=\"utf-8\") as f:\n",
+    "        test_pred = [sent.split('\\n') for sent in f.read().strip().split('\\n\\n')]\n",
+    "    assert len(test_anno) == len(test_pred), f\"{file}: annotation and prediction differ in sentence count\"\n",
+    "    # compare annotation label and prediction label sentence by sentence\n",
+    "    file_preds = {\"predicted\":{}, \"prediction\":{}}\n",
+    "    for anno, pred in zip(test_anno, test_pred):\n",
+    "        assert len(anno) == len(pred), f\"{file}: annotation and prediction differ in token count\"\n",
+    "        # the BIO label is the last whitespace-separated field of each token line\n",
+    "        anno_bio = [line.split()[-1] for line in anno]\n",
+    "        pred_bio = [line.split()[-1] for line in pred]\n",
+    "        # token-level accuracy counters\n",
+    "        for gold_tag, pred_tag in zip(anno_bio, pred_bio):\n",
+    "            acc_key = \"acc_true\" if gold_tag == pred_tag else \"acc_false\"\n",
+    "            eval_metrics[\"overall\"][acc_key] = eval_metrics[\"overall\"].get(acc_key, 0) + 1\n",
+    "        # process gold standard: how well is each annotated entity recovered (recall side)\n",
+    "        for start_idx, end_idx, cate in iter_bio_spans(anno_bio):\n",
+    "            match_entity = [f\"{anno[idx]} {pred_bio[idx]}\" for idx in range(start_idx, end_idx)]\n",
+    "            gs_counts = eval_metrics[\"overall\"].setdefault('gs', {})\n",
+    "            gs_counts['count'] = gs_counts.get('count', 0) + 1\n",
+    "            gs_counts[cate] = gs_counts.get(cate, 0) + 1\n",
+    "            level = match_level(anno_bio, pred_bio, start_idx, end_idx, cate)\n",
+    "            file_preds[\"predicted\"].setdefault(level, []).append(match_entity)\n",
+    "            eval_metrics[\"overall\"][f\"{level}_predicted\"] = eval_metrics[\"overall\"].get(f\"{level}_predicted\", 0) + 1\n",
+    "            cate_counts = eval_metrics[\"category\"].setdefault(cate, {\"strict\":0, \"relax\":0, \"miss\":0})\n",
+    "            cate_counts[level] += 1\n",
+    "        # process predictions: how many predicted entities are correct (precision side)\n",
+    "        for start_idx, end_idx, cate in iter_bio_spans(pred_bio):\n",
+    "            match_entity = [f\"{anno[idx]} {pred_bio[idx]}\" for idx in range(start_idx, end_idx)]\n",
+    "            level = match_level(anno_bio, pred_bio, start_idx, end_idx, cate)\n",
+    "            file_preds[\"prediction\"].setdefault(level, []).append(match_entity)\n",
+    "            eval_metrics[\"overall\"][f\"{level}_predict\"] = eval_metrics[\"overall\"].get(f\"{level}_predict\", 0) + 1\n",
+    "            cate_counts = eval_metrics[\"prediction\"].setdefault(cate, {\"strict\":0, \"relax\":0, \"miss\":0})\n",
+    "            cate_counts[level] += 1\n",
+    "    predictions[file_id] = file_preds\n",
+    "# write through a context manager so the file is flushed and closed\n",
+    "with open(f\"{outfolder}/{dataset}_results/{dataset}_{model}_prediction.json\", \"w\", encoding=\"utf-8\") as prediction_fp:\n",
+    "    json.dump(eval_metrics, prediction_fp, indent=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Overall Relax Level: Precision: 0.771102433163112, Recall: 0.8174917491749175, F1: 0.7936197726115226\n",
+      "Overall Strict Level: Precision: 0.6158005407029138, Recall: 0.6765676567656765, F1: 0.6447554646957069\n",
+      "\n",
+      "\n",
+      "Relax Level for Condition: Precision: 0.8694915254237288, Recall: 0.8998178506375227, F1: 0.884394788316812\n",
+      "Strict Level for Condition: Precision: 0.7186440677966102, Recall: 0.7723132969034608, F1: 0.7445127304653206\n",
+      "\n",
+      "\n",
+      "Relax Level for Observation: Precision: 0.546875, Recall: 0.38333333333333336, F1: 0.45072788353863386\n",
+      "Strict Level for Observation: Precision: 0.3671875, Recall: 0.2611111111111111, F1: 0.3051948051948052\n",
+      "\n",
+      "\n",
+      "Relax Level for Temporal: Precision: 0.6457142857142857, Recall: 0.8120300751879699, F1: 0.7193845972471926\n",
+      "Strict Level for Temporal: Precision: 0.4828571428571429, Recall: 0.6353383458646616, F1: 0.5487012987012988\n",
+      "\n",
+      "\n",
+      "Relax Level for Drug: Precision: 0.8448275862068966, Recall: 0.9228295819935691, F1: 0.8821075740944017\n",
+      "Strict Level for Drug: Precision: 0.7011494252873564, Recall: 0.7845659163987139, F1: 0.7405159332321699\n",
+      "\n",
+      "\n",
+      "Relax Level for Procedure: Precision: 0.6602209944751382, Recall: 0.721875, F1: 0.6896728335686\n",
+      "Strict Level for Procedure: Precision: 0.5082872928176796, Recall: 0.575, F1: 0.5395894428152493\n",
+      "\n",
+      "\n",
+      "Relax Level for Value: Precision: 0.8157894736842105, Recall: 0.8501529051987767, F1: 0.8326167817979807\n",
+      "Strict Level for Value: Precision: 0.672514619883041, Recall: 0.7033639143730887, F1: 0.6875934230194319\n",
+      "\n",
+      "\n",
+      "Relax Level for Measurement: Precision: 0.767515923566879, Recall: 0.8, F1: 0.7834213734254368\n",
+      "Strict Level for Measurement: Precision: 0.5414012738853503, Recall: 0.6071428571428571, F1: 0.5723905723905723\n",
+      "\n",
+      "\n",
+      "Relax Level for Person: Precision: 0.7639751552795031, Recall: 0.8785714285714286, F1: 0.8172757475083056\n",
+      "Strict Level for Person: Precision: 0.7329192546583851, Recall: 0.8428571428571429, F1: 0.7840531561461795\n",
+      "\n",
+      "\n",
+      "Relax Level for Mood: Precision: 0.36507936507936506, Recall: 0.4489795918367347, F1: 0.4027059291683247\n",
+      "Strict Level for Mood: Precision: 0.1746031746031746, Recall: 0.22448979591836735, F1: 0.19642857142857142\n",
+      "\n",
+      "\n",
+      "Relax Level for Device: Precision: 0.5892857142857143, Recall: 0.8048780487804879, F1: 0.6804123711340206\n",
+      "Strict Level for Device: Precision: 0.5178571428571429, Recall: 0.7073170731707317, F1: 0.5979381443298969\n",
+      "\n",
+      "\n",
+      "Relax Level for Pregnancy_considerations: Precision: 0.52, Recall: 0.3333333333333333, F1: 0.40625000000000006\n",
+      "Strict Level for Pregnancy_considerations: Precision: 0.0, Recall: 0.0, F1: 0.0\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Turn the raw match counts in `eval_metrics` into precision / recall / F1\n",
+    "# (overall and per entity type), print them, and save them as JSON.\n",
+    "attbc_metrics = {\"category\":{},\"overall\":{}}\n",
+    "# the matching loop creates each overall counter only when it is first incremented,\n",
+    "# so read them with a default of 0 instead of risking a KeyError\n",
+    "overall_counts = eval_metrics[\"overall\"]\n",
+    "acc_true = overall_counts.get(\"acc_true\", 0)\n",
+    "acc_false = overall_counts.get(\"acc_false\", 0)\n",
+    "attbc_metrics[\"overall\"][\"acc\"] = acc_true/(acc_true+acc_false) if (acc_true+acc_false) != 0 else 0.0\n",
+    "strict_pred_all = overall_counts.get('strict_predict', 0)\n",
+    "relax_pred_all = overall_counts.get('relax_predict', 0)\n",
+    "# precision denominator: every predicted entity (strict + relax + miss)\n",
+    "pred_all = strict_pred_all + relax_pred_all + overall_counts.get('miss_predict', 0)\n",
+    "strict_gs_all = overall_counts.get('strict_predicted', 0)\n",
+    "relax_gs_all = overall_counts.get('relax_predicted', 0)\n",
+    "# recall denominator: every gold-standard entity\n",
+    "gs_all = overall_counts.get('gs', {}).get('count', 0)\n",
+    "\n",
+    "pre_relax_all = (strict_pred_all + relax_pred_all)/ pred_all if pred_all != 0 else 0.0\n",
+    "rec_relax_all = (strict_gs_all + relax_gs_all)/ gs_all if gs_all != 0 else 0.0\n",
+    "# F1 is undefined when precision and recall are both 0; report 0.0 instead of raising\n",
+    "f1_relax_all = (2*pre_relax_all*rec_relax_all)/(pre_relax_all+rec_relax_all) if (pre_relax_all+rec_relax_all) != 0 else 0.0\n",
+    "print(f\"Overall Relax Level: Precision: {pre_relax_all}, Recall: {rec_relax_all}, F1: {f1_relax_all}\")\n",
+    "attbc_metrics[\"overall\"][\"relax\"] = {\"f_score\": f1_relax_all, \"precision\":pre_relax_all, \"recall\":rec_relax_all}\n",
+    "\n",
+    "pre_strict_all = strict_pred_all / pred_all if pred_all != 0 else 0.0\n",
+    "rec_strict_all = strict_gs_all / gs_all if gs_all != 0 else 0.0\n",
+    "f1_strict_all = (2*pre_strict_all*rec_strict_all)/(pre_strict_all+rec_strict_all) if (pre_strict_all+rec_strict_all) != 0 else 0.0\n",
+    "print(f\"Overall Strict Level: Precision: {pre_strict_all}, Recall: {rec_strict_all}, F1: {f1_strict_all}\")\n",
+    "print('\\n')\n",
+    "attbc_metrics[\"overall\"][\"strict\"] = {\"f_score\": f1_strict_all, \"precision\":pre_strict_all, \"recall\":rec_strict_all}\n",
+    "for i in eval_metrics[\"category\"].keys():\n",
+    "    # tt: number of gold-standard entities of type i (recall denominator)\n",
+    "    tt = eval_metrics[\"overall\"]['gs'][i]\n",
+    "    # an entity type found in the gold standard may never have been predicted\n",
+    "    pred_counts = eval_metrics[\"prediction\"][i] if i in eval_metrics[\"prediction\"] else {\"strict\":0, \"relax\":0, \"miss\":0}\n",
+    "    # tp: total number of predicted entities of type i (precision denominator)\n",
+    "    tp = pred_counts['strict'] + pred_counts['relax'] + pred_counts['miss']\n",
+    "\n",
+    "    pre_relax = (pred_counts['strict']+pred_counts['relax'])/tp if tp != 0 else 0.0\n",
+    "    rec_relax = (eval_metrics[\"category\"][i]['strict']+eval_metrics[\"category\"][i]['relax'])/tt\n",
+    "    f1_relax = (2*pre_relax*rec_relax)/(pre_relax+rec_relax) if (pre_relax+rec_relax) != 0 else 0.0\n",
+    "    print(f\"Relax Level for {i}: Precision: {pre_relax}, Recall: {rec_relax}, F1: {f1_relax}\")\n",
+    "    attbc_metrics[\"category\"].setdefault(\"relax\", {})[i] = {\"f_score\": f1_relax, \"precision\":pre_relax, \"recall\":rec_relax}\n",
+    "\n",
+    "    pre_strict = pred_counts['strict']/tp if tp != 0 else 0.0\n",
+    "    rec_strict = eval_metrics[\"category\"][i]['strict']/tt\n",
+    "    f1_strict = (2*pre_strict*rec_strict)/(pre_strict+rec_strict) if (pre_strict+rec_strict) != 0 else 0.0\n",
+    "    print(f\"Strict Level for {i}: Precision: {pre_strict}, Recall: {rec_strict}, F1: {f1_strict}\")\n",
+    "    print('\\n')\n",
+    "    attbc_metrics[\"category\"].setdefault(\"strict\", {})[i] = {\"f_score\": f1_strict, \"precision\":pre_strict, \"recall\":rec_strict}\n",
+    "# write through a context manager so the file is flushed and closed before it is read back\n",
+    "with open(f\"{outfolder}/{dataset}_results/{dataset}_{model}_eval_metric.json\", \"w\", encoding=\"utf-8\") as metrics_fp:\n",
+    "    json.dump(attbc_metrics, metrics_fp, indent=4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# function for loading performance output file\n",
+    "def get_perform_metric(perform):\n",
+    "    \"\"\"Flatten a metrics dict into column lists suitable for a DataFrame.\n",
+    "\n",
+    "    `perform` must provide perform['overall'][level] and\n",
+    "    perform['category'][level][entity_type] for level in ('strict', 'relax'),\n",
+    "    each holding 'precision', 'recall' and 'f_score'.\n",
+    "\n",
+    "    Returns a dict of equally long lists keyed by 'type', 'pre_strict',\n",
+    "    'rec_strict', 'f1_strict', 'pre_relax', 'rec_relax' and 'f1_relax'. The\n",
+    "    first element of every list is the 'Overall' entry; the entity types\n",
+    "    follow in the order of perform['category']['strict'].\n",
+    "    \"\"\"\n",
+    "    strict_by_type = perform['category']['strict']\n",
+    "    relax_by_type = perform['category']['relax']\n",
+    "    ent_type = ['Overall'] + list(strict_by_type)\n",
+    "    # one metrics record per row, overall first; relax rows follow the strict ordering\n",
+    "    strict_rows = [perform['overall']['strict']] + [strict_by_type[name] for name in ent_type[1:]]\n",
+    "    relax_rows = [perform['overall']['relax']] + [relax_by_type[name] for name in ent_type[1:]]\n",
+    "    return {\n",
+    "        'type': ent_type,\n",
+    "        'pre_strict': [row['precision'] for row in strict_rows],\n",
+    "        'rec_strict': [row['recall'] for row in strict_rows],\n",
+    "        'f1_strict': [row['f_score'] for row in strict_rows],\n",
+    "        'pre_relax': [row['precision'] for row in relax_rows],\n",
+    "        'rec_relax': [row['recall'] for row in relax_rows],\n",
+    "        'f1_relax': [row['f_score'] for row in relax_rows],\n",
+    "    }"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "                        type  pre_strict  rec_strict  f1_strict  pre_relax  \\\n",
+      "0                    Overall    0.615801    0.676568   0.644755   0.771102   \n",
+      "1                  Condition    0.718644    0.772313   0.744513   0.869492   \n",
+      "10                    Device    0.517857    0.707317   0.597938   0.589286   \n",
+      "4                       Drug    0.701149    0.784566   0.740516   0.844828   \n",
+      "7                Measurement    0.541401    0.607143   0.572391   0.767516   \n",
+      "9                       Mood    0.174603    0.224490   0.196429   0.365079   \n",
+      "2                Observation    0.367188    0.261111   0.305195   0.546875   \n",
+      "8                     Person    0.732919    0.842857   0.784053   0.763975   \n",
+      "11  Pregnancy_considerations    0.000000    0.000000   0.000000   0.520000   \n",
+      "5                  Procedure    0.508287    0.575000   0.539589   0.660221   \n",
+      "3                   Temporal    0.482857    0.635338   0.548701   0.645714   \n",
+      "6                      Value    0.672515    0.703364   0.687593   0.815789   \n",
+      "\n",
+      "    rec_relax  f1_relax  \n",
+      "0    0.817492  0.793620  \n",
+      "1    0.899818  0.884395  \n",
+      "10   0.804878  0.680412  \n",
+      "4    0.922830  0.882108  \n",
+      "7    0.800000  0.783421  \n",
+      "9    0.448980  0.402706  \n",
+      "2    0.383333  0.450728  \n",
+      "8    0.878571  0.817276  \n",
+      "11   0.333333  0.406250  \n",
+      "5    0.721875  0.689673  \n",
+      "3    0.812030  0.719385  \n",
+      "6    0.850153  0.832617  \n"
+     ]
+    }
+   ],
+   "source": [
+    "# load performance output file (context manager closes the handle once parsed)\n",
+    "with open(f\"{outfolder}/{dataset}_results/{dataset}_{model}_eval_metric.json\", \"r\", encoding=\"utf-8\") as perf_fp:\n",
+    "    perf_file = json.load(perf_fp)\n",
+    "perf_metrics = pd.DataFrame(data=get_perform_metric(perf_file))\n",
+    "# keep the 'Overall' row first, then list the entity types alphabetically\n",
+    "df = pd.concat([perf_metrics.loc[[0]], perf_metrics[1:].sort_values(by=['type'])])\n",
+    "# print performance by entity type\n",
+    "print(df)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# plot performance by entity type\n",
+    "# the title is built from the active `model` and `dataset` so it cannot go stale when either changes\n",
+    "df.plot(x='type', y=['f1_strict', 'f1_relax'], title=f'Strict and Relax F1_Score of {model} on {dataset.upper()}', kind=\"bar\", rot=90)\n",
+    "plt.xlabel(\"\");"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Export the gold-standard entities of every test file to three CSV reports,\n",
+    "# one per match level (strict / relax / miss); a blank row separates entities.\n",
+    "strict_path = f\"{outfolder}/{dataset}_results/{dataset}_strict_match_{model}.csv\"\n",
+    "relax_path = f\"{outfolder}/{dataset}_results/{dataset}_relax_match_{model}.csv\"\n",
+    "miss_path = f\"{outfolder}/{dataset}_results/{dataset}_miss_match_{model}.csv\"\n",
+    "csv_header = ['NCT_ID', 'Entity', 'Offsets', 'Golden_Label', 'Prediction']\n",
+    "with open(strict_path, \"w\", encoding=\"utf-8\", newline='') as fstrict, open(relax_path, \"w\", encoding=\"utf-8\", newline='') as frelax, open(miss_path, \"w\", encoding=\"utf-8\", newline='') as fmiss:\n",
+    "    # one writer per match level, looked up by the level name stored in `predictions`\n",
+    "    level_writers = {\"strict\": csv.writer(fstrict), \"relax\": csv.writer(frelax), \"miss\": csv.writer(fmiss)}\n",
+    "    for level_writer in level_writers.values():\n",
+    "        level_writer.writerow(csv_header)\n",
+    "    for nct_id, nct in predictions.items():\n",
+    "        # only the gold-standard side (\"predicted\") is exported; the \"prediction\" side is skipped\n",
+    "        for level, entities in nct.get(\"predicted\", {}).items():\n",
+    "            level_writer = level_writers.get(level)\n",
+    "            if level_writer is None:\n",
+    "                continue\n",
+    "            for ent in entities:\n",
+    "                for token_line in ent:\n",
+    "                    # field 0 is the token, fields 1-4 are joined as offsets, 5 is the gold label, 6 the prediction\n",
+    "                    fields = token_line.split()\n",
+    "                    level_writer.writerow([nct_id, fields[0], ' '.join(fields[1:5]), fields[5], fields[6]])\n",
+    "                level_writer.writerow([])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}