Diff of /generate_data.py [000000] .. [1de6ed]

--- a
+++ b/generate_data.py
@@ -0,0 +1,225 @@
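+"""Generates NER / RE input files from EHR data (and, optionally, the ADE
+corpus) for the BioBERT pipelines in biobert_ner and biobert_re."""
+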
+import argparse
+import os
+import re
+import warnings
+from typing import Dict, Iterator, List
+
+from biobert_ner.utils_ner import generate_input_files
+from biobert_re.utils_re import generate_re_input_files
+from utils import read_data, save_pickle, read_ade_data
+
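+# BIO tags for the medication-related entity types (STR = strength,
+# DUR = duration, ROU = route, FOR = form, DOS = dosage, REA = reason,
+# FRE = frequency), plus 'O' for tokens outside any entity.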
+labels = ['B-DRUG', 'I-DRUG', 'B-STR', 'I-STR', 'B-DUR', 'I-DUR',
+          'B-ROU', 'I-ROU', 'B-FOR', 'I-FOR', 'B-ADE', 'I-ADE',
+          'B-DOS', 'I-DOS', 'B-REA', 'I-REA', 'B-FRE', 'I-FRE', 'O']
+
+
+def parse_arguments():
+    """Parses program arguments"""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--task", type=str,
+                        help="Task to be completed. 'NER', 'RE'. Default is 'NER'.",
+                        default="NER")
+
+    parser.add_argument("--input_dir", type=str,
+                        help="Directory with txt and ann files. Default is 'data/'.",
+                        default="data/")
+
+    parser.add_argument("--ade_dir", type=str,
+                        help="Directory with ADE corpus. Default is None.",
+                        default=None)
+
+    parser.add_argument("--target_dir", type=str,
+                        help="Directory to save files. Default is 'dataset/'.",
+                        default='dataset/')
+
+    parser.add_argument("--max_seq_len", type=int,
+                        help="Maximum sequence length. Default is 512.",
+                        default=512)
+
+    parser.add_argument("--dev_split", type=float,
+                        help="Ratio of dev data. Default is 0.1",
+                        default=0.1)
+
+    parser.add_argument("--tokenizer", type=str,
+                        help="The tokenizer to use. 'scispacy', 'scispacy_plus', 'biobert-base', 'biobert-large', 'default'.",
+                        default="scispacy")
+
+    parser.add_argument("--ext", type=str,
+                        help="Extension of target file. Default is txt.",
+                        default="txt")
+
+    parser.add_argument("--sep", type=str,
+                        help="Token-label separator. Default is a space.",
+                        default=" ")
+
+    arguments = parser.parse_args()
+    return arguments
+
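+# Example invocation (illustrative paths):
+#   python generate_data.py --task NER --input_dir data/ \
+#       --target_dir dataset/ --tokenizer biobert-base --sep tab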
+
+def default_tokenizer(sequence: str) -> List[str]:
+    """A tokenizer that splits the sequence on whitespace characters."""
+    words = re.split(r"[ \n\t]", sequence)
+    tokens = []
+    for word in words:
+        word = word.strip()
+
+        if not word:
+            continue
+
+        tokens.append(word)
+
+    return tokens
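+
+# Example: default_tokenizer("Aspirin 81 mg\tdaily")
+# -> ['Aspirin', '81', 'mg', 'daily']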
+
+
+def scispacy_plus_tokenizer(sequence: str, scispacy_tok=None) -> Iterator[str]:
+    """
+    Runs the scispacy tokenizer and removes all tokens that
+    contain whitespace characters (in practice, whitespace-only tokens).
+    """
+    if scispacy_tok is None:
+        import en_ner_bc5cdr_md
+        scispacy_tok = en_ner_bc5cdr_md.load().tokenizer
+
+    scispacy_tokens = list(map(str, scispacy_tok(sequence)))
+    tokens = filter(lambda t: not (' ' in t or '\n' in t or '\t' in t), scispacy_tokens)
+
+    return tokens
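+
+# Illustrative example (actual tokens depend on the loaded scispacy model):
+# list(scispacy_plus_tokenizer("Aspirin  81 mg")) -> ['Aspirin', '81', 'mg']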
+
+
+def ner_generator(files: Dict[str, tuple], args) -> None:
+    """Generates files for NER"""
+    # Generate train, dev, test files
+    for filename, data in files.items():
+        generate_input_files(ehr_records=data[0], ade_records=data[1],
+                             filename=args.target_dir + filename + '.' + args.ext,
+                             max_len=args.max_seq_len, sep=args.sep)
+        save_pickle(args.target_dir + filename, {"EHR": data[0], "ADE": data[1]})
+
+    # Generate labels file
+    with open(args.target_dir + 'labels.txt', 'w') as file:
+        output_labels = map(lambda x: x + '\n', labels)
+        file.writelines(output_labels)
+
+    filenames = [key + suffix
+                 for key in files
+                 for suffix in ('.' + args.ext, '.pkl')]
+
+    print("\nGenerating files successful. Files generated: ",
+          ', '.join(filenames), ', labels.txt', sep='')
+
+
+def re_generator(files: Dict[str, tuple], args) -> None:
+    """Generates files for RE"""
+    for filename, data in files.items():
+        generate_re_input_files(ehr_records=data[0], ade_records=data[1],
+                                filename=args.target_dir + filename + '.' + args.ext,
+                                max_len=args.max_seq_len, sep=args.sep,
+                                is_test=data[2], is_label=data[3])
+
+    save_pickle(args.target_dir + 'train', {"EHR": files['train'][0], "ADE": files['train'][1]})
+    save_pickle(args.target_dir + 'test',  {"EHR": files['test'][0],  "ADE": files['test'][1]})
+
+    print("\nGenerating files successful. Files generated: ",
+          'train.tsv,', 'dev.tsv,', 'test.tsv,',
+          'test_labels.tsv,', 'train_rel.pkl,', 'test_rel.pkl,', 'test_labels_rel.pkl', sep=' ')
+
+
+def main():
+    args = parse_arguments()
+
+    if args.target_dir[-1] != '/':
+        args.target_dir += '/'
+
+    if args.sep == "tab":
+        args.sep = "\t"
+
+    os.makedirs(args.target_dir, exist_ok=True)
+
+    if args.tokenizer == "default":
+        tokenizer = default_tokenizer
+        is_bert_tokenizer = False
+
+    elif args.tokenizer == "scispacy":
+        import en_ner_bc5cdr_md
+        tokenizer = en_ner_bc5cdr_md.load().tokenizer
+        is_bert_tokenizer = False
+
+    elif args.tokenizer == 'scispacy_plus':
+        import en_ner_bc5cdr_md
+        scispacy_tok = en_ner_bc5cdr_md.load().tokenizer
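+        # Bind the loaded tokenizer as the function's default argument so
+        # scispacy_plus_tokenizer does not re-load the model on every call.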
+        scispacy_plus_tokenizer.__defaults__ = (scispacy_tok,)
+
+        tokenizer = scispacy_plus_tokenizer
+        is_bert_tokenizer = False
+
+    elif args.tokenizer == 'biobert-large':
+        from transformers import AutoTokenizer
+        biobert = AutoTokenizer.from_pretrained(
+            "dmis-lab/biobert-large-cased-v1.1")
+
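+        # Reserve room for the special tokens ([CLS] and [SEP]) that the
+        # BERT tokenizer adds to every sequence.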
+        args.max_seq_len -= biobert.num_special_tokens_to_add()
+        tokenizer = biobert.tokenize
+        is_bert_tokenizer = True
+
+    elif args.tokenizer == 'biobert-base':
+        from transformers import AutoTokenizer
+        biobert = AutoTokenizer.from_pretrained(
+            "dmis-lab/biobert-base-cased-v1.1")
+
+        args.max_seq_len -= biobert.num_special_tokens_to_add()
+        tokenizer = biobert.tokenize
+        is_bert_tokenizer = True
+
+    else:
+        warnings.warn("Tokenizer named " + args.tokenizer + " not found."
+                      "Using default tokenizer instead. Acceptable values"
+                      "include 'scispacy', 'biobert-base', 'biobert-large',"
+                      "and 'default'.")
+        tokenizer = default_tokenizer
+        is_bert_tokenizer = False
+
+    print("\nReading data\n")
+    train_dev, test = read_data(data_dir=args.input_dir,
+                                tokenizer=tokenizer,
+                                is_bert_tokenizer=is_bert_tokenizer,
+                                verbose=1)
+
+    if args.ade_dir is not None:
+        ade_train_dev = read_ade_data(ade_data_dir=args.ade_dir, verbose=1)
+
+        ade_dev_split_idx = int((1 - args.dev_split) * len(ade_train_dev))
+        ade_train = ade_train_dev[:ade_dev_split_idx]
+        ade_devel = ade_train_dev[ade_dev_split_idx:]
+
+    else:
+        ade_train_dev = None
+        ade_train = None
+        ade_devel = None
+
+    print('\n')
+
+    # Data is already shuffled, just split for dev set
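+    # (e.g. with dev_split=0.1, the first 90% of records go to train and
+    # the final 10% to devel)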
+    dev_split_idx = int((1 - args.dev_split) * len(train_dev))
+    train = train_dev[:dev_split_idx]
+    devel = train_dev[dev_split_idx:]
+
+    # Data for NER
+    if args.task.lower() == 'ner':
+        files = {'train': (train, ade_train), 'train_dev': (train_dev, ade_train_dev),
+                 'devel': (devel, ade_devel), 'test': (test, None)}
+
+        ner_generator(files, args)
+
+    # Data for RE
+    elif args.task.lower() == 're':
+        # {dataset_name: (ehr_data, ade_data, is_test, is_label)}
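+        # 'test' is written without gold labels (for inference); 'test_labels'
+        # re-writes the test set with labels for evaluation.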
+        files = {'train': (train, ade_train, False, True), 'dev': (devel, ade_devel, False, True),
+                 'test': (test, None, True, False), 'test_labels': (test, None, True, True)}
+
+        re_generator(files, args)
+
+    else:
+        warnings.warn("Unknown task '" + args.task + "'. Expected 'NER' or 'RE'.")
+
+
+if __name__ == '__main__':
+    main()