Diff of /biobert_ner/run_ner.py [000000] .. [1de6ed]

--- a
+++ b/biobert_ner/run_ner.py
@@ -0,0 +1,284 @@
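+"""Fine-tune a token-classification model (e.g. BioBERT) for named entity
+recognition on CoNLL-2003-formatted data, with optional evaluation and
+prediction on a test set."""
+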
+import logging
+import os
+import sys
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+from seqeval.metrics import f1_score, precision_score, recall_score
+from torch import nn
+
+from transformers import (
+    AutoConfig,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+    EvalPrediction,
+    HfArgumentParser,
+    Trainer,
+    TrainingArguments,
+    set_seed
+)
+from utils_ner import NerDataset, Split, get_labels
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
+    # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
+    # or just modify its tokenizer_config.json.
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to the data we are going to use for training and evaluation.
+    """
+    data_dir: str = field(
+        metadata={"help": "The input data dir. Should contain the .txt files for a CoNLL-2003-formatted task."}
+    )
+    labels: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."},
+    )
+    max_seq_length: int = field(
+        default=128,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+
+
+def main():
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
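+    # A hypothetical args.json, covering all three argument groups, might look
+    # like:
+    #   {
+    #     "model_name_or_path": "dmis-lab/biobert-base-cased-v1.1",
+    #     "data_dir": "data/BC5CDR-disease",
+    #     "labels": "data/BC5CDR-disease/labels.txt",
+    #     "output_dir": "out/bc5cdr-disease",
+    #     "max_seq_length": 128,
+    #     "do_train": true,
+    #     "do_eval": true
+    #   }
+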
+    if (
+        os.path.exists(training_args.output_dir)
+        and os.listdir(training_args.output_dir)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
+    ):
+        raise ValueError(
+            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+        )
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+    )
+    logger.warning(
+        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+        training_args.local_rank,
+        training_args.device,
+        training_args.n_gpu,
+        bool(training_args.local_rank != -1),
+        training_args.fp16,
+    )
+    logger.info("Training/evaluation parameters %s", training_args)
+
+    # Set seed
+    set_seed(training_args.seed)
+
+    # Prepare the CoNLL-2003-formatted NER task
+    labels = get_labels(data_args.labels)
+    label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
+    num_labels = len(labels)
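+    # With the stock utils_ner, get_labels(None) falls back to the CoNLL-2003 set
+    # ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC",
+    # "I-LOC"], so label_map maps 0 -> "O", 1 -> "B-MISC", and so on; biomedical
+    # corpora are expected to pass their own label file via --labels
+    # (e.g. B-Chemical/I-Chemical for BC5CDR).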
+
+    # Load pretrained model and tokenizer
+    #
+    # Distributed training:
+    # The .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        num_labels=num_labels,
+        id2label=label_map,
+        label2id={label: i for i, label in enumerate(labels)},
+        cache_dir=model_args.cache_dir,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=model_args.use_fast,
+    )
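+    # Note: the original BioBERT weights were released as TensorFlow checkpoints,
+    # which is why from_tf is inferred from a ".ckpt" suffix in the path below.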
+    model = AutoModelForTokenClassification.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+    )
+
+    # Get datasets
+    train_dataset = (
+        NerDataset(
+            data_dir=data_args.data_dir,
+            tokenizer=tokenizer,
+            labels=labels,
+            model_type=config.model_type,
+            max_seq_length=data_args.max_seq_length,
+            overwrite_cache=data_args.overwrite_cache,
+            mode=Split.train,
+        )
+        if training_args.do_train
+        else None
+    )
+    eval_dataset = (
+        NerDataset(
+            data_dir=data_args.data_dir,
+            tokenizer=tokenizer,
+            labels=labels,
+            model_type=config.model_type,
+            max_seq_length=data_args.max_seq_length,
+            overwrite_cache=data_args.overwrite_cache,
+            mode=Split.dev,
+        )
+        if training_args.do_eval
+        else None
+    )
+
+    def align_predictions(
+        predictions: np.ndarray, label_ids: np.ndarray
+    ) -> Tuple[List[List[str]], List[List[str]]]:
+        preds = np.argmax(predictions, axis=2)
+
+        batch_size, seq_len = preds.shape
+
+        out_label_list = [[] for _ in range(batch_size)]
+        preds_list = [[] for _ in range(batch_size)]
+
+        # Positions labelled with the loss's ignore_index (-100 for
+        # CrossEntropyLoss, i.e. padding and non-initial sub-tokens) are skipped.
+        ignore_index = nn.CrossEntropyLoss().ignore_index
+        for i in range(batch_size):
+            for j in range(seq_len):
+                if label_ids[i, j] != ignore_index:
+                    out_label_list[i].append(label_map[label_ids[i, j]])
+                    preds_list[i].append(label_map[preds[i, j]])
+        return preds_list, out_label_list
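+
+    # Hypothetical worked example: if label_ids is [[-100, 0, 8, -100]] and the
+    # argmax of predictions is [[3, 0, 7, 0]], only positions 1 and 2 pass the
+    # ignore_index filter, giving out_label_list [["O", "I-LOC"]] and preds_list
+    # [["O", "B-LOC"]] under the default CoNLL-2003 label map.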
+
+    def compute_metrics(p: EvalPrediction) -> Dict:
+        preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
+        
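+        # seqeval scores at the entity level: a gold ["B-PER", "I-PER"] predicted
+        # as ["B-PER", "O"] counts as a missed entity, not a one-token error.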
+        return {
+            "precision": precision_score(out_label_list, preds_list),
+            "recall": recall_score(out_label_list, preds_list),
+            "f1": f1_score(out_label_list, preds_list),
+        }
+
+    # Initialize our Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        compute_metrics=compute_metrics,
+    )
+
+    # Training
+    if training_args.do_train:
+        trainer.train(
+            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
+        )
+        trainer.save_model()
+        if trainer.is_world_master():
+            tokenizer.save_pretrained(training_args.output_dir)
+
+    # Evaluation
+    results = {}
+    if training_args.do_eval:
+        logger.info("*** Evaluate ***")
+
+        result = trainer.evaluate()
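+        # `result` holds the evaluation loss together with the precision, recall
+        # and f1 computed by compute_metrics above.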
+        
+        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
+        if trainer.is_world_master():
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results *****")
+                for key, value in result.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
+
+            results.update(result)
+
+    # Predict
+    if training_args.do_predict:
+        test_dataset = NerDataset(
+            data_dir=data_args.data_dir,
+            tokenizer=tokenizer,
+            labels=labels,
+            model_type=config.model_type,
+            max_seq_length=data_args.max_seq_length,
+            overwrite_cache=data_args.overwrite_cache,
+            mode=Split.test,
+        )
+
+        predictions, label_ids, metrics = trainer.predict(test_dataset)
+        logger.info("Predictions shape: %s", predictions.shape)
+
+        preds_list, _ = align_predictions(predictions, label_ids)
+        
+        # Save predictions
+        output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
+        if trainer.is_world_master():
+            with open(output_test_results_file, "w") as writer:
+                logger.info("***** Test results *****")
+                for key, value in metrics.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
+
+        output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
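+        # The predictions file mirrors test.txt token-for-token, appending the
+        # predicted BIO tag; e.g. for a hypothetical WordPiece-split chemical:
+        #   ci B-Chemical
+        #   ##met I-Chemical
+        #   ##idine I-Chemical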
+        prev_pred = ""
+        if trainer.is_world_master():
+            with open(output_test_predictions_file, "w") as writer:
+                with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
+                    example_id = 0
+                    for line in f:
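+                        # "##"-prefixed lines are WordPiece continuation
+                        # sub-tokens: reuse the previous token's tag, demoting
+                        # "B-" to "I-" so the BIO spans stay well-formed.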
+                        if line.startswith("##"):
+                            if prev_pred != "O":
+                                prev_pred = "I-" + prev_pred.split('-')[-1]
+                            output_line = line.split()[0] + " " + prev_pred + "\n"
+                            writer.write(output_line)
+                        elif line.startswith("-DOCSTART-") or line == "" or line == "\n":
+                            writer.write(line)
+                            if not preds_list[example_id]:
+                                example_id += 1
+                        elif preds_list[example_id]:
+                            prev_pred = preds_list[example_id].pop(0)
+                            output_line = line.split()[0] + " " + prev_pred + "\n"
+                            writer.write(output_line)
+                        else:
+                            # No prediction left: the sentence was truncated
+                            # at max_seq_length.
+                            logger.warning(
+                                "Maximum sequence length exceeded: no prediction for %r",
+                                line.split()[0],
+                            )
+
+    return results
+
+
+if __name__ == "__main__":
+    main()