# biobert_re/run_re.py
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

import numpy as np

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)

from data_processor import glue_output_modes, glue_tasks_num_labels
from utils_re import REDataset, get_eval_results

from metrics import glue_compute_metrics

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    warmup_proportion: Optional[float] = field(
        default=0.1,
        metadata={"help": "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% of training."},
    )

def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
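    # Example invocation (illustrative only; flag names come from the dataclasses
    # above and from HuggingFace's TrainingArguments / GlueDataTrainingArguments,
    # and valid task names are defined in data_processor.glue_tasks_num_labels):
    #
    #   python run_re.py \
    #       --model_name_or_path dmis-lab/biobert-base-cased-v1.1 \
    #       --task_name <task> \
    #       --data_dir ./data \
    #       --output_dir ./output \
    #       --do_train --do_eval \
    #       --num_train_epochs 3 \
    #       --per_device_train_batch_size 32
    #
    # The same fields can instead be supplied in a single JSON file:
    #   python run_re.py config.json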
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )

    # Set seed
    set_seed(training_args.seed)
    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % data_args.task_name)
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (
        REDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
    )
    eval_dataset = (
        REDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
        if training_args.do_eval
        else None
    )
    test_dataset = (
        REDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
        if training_args.do_predict
        else None
    )
    # Load pretrained model.
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    # Note: this script does not currently support distributed training.
    # Derive warmup_steps from the warmup proportion. This estimates total optimization
    # steps from the per-device batch size alone; it does not account for gradient
    # accumulation or multiple devices. The guard keeps evaluation-only runs (where
    # train_dataset is None) from crashing.
    if training_args.do_train:
        training_args.warmup_steps = int(
            model_args.warmup_proportion
            * (len(train_dataset) / training_args.per_device_train_batch_size)
            * training_args.num_train_epochs
        )
    logger.info("Training/evaluation parameters %s", training_args)
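    # Worked example (illustrative numbers): with 10,000 training examples,
    # per_device_train_batch_size=32, num_train_epochs=3, and the default
    # warmup_proportion=0.1, warmup_steps = int(0.1 * (10000 / 32) * 3)
    # = int(93.75) = 93, out of roughly 937 total optimization steps.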
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )
    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
        # Build a task-specific metrics callback. Predictions arrive as logits,
        # so take the argmax over the label dimension before scoring.
        def compute_metrics_fn(p: EvalPrediction):
            preds = np.argmax(p.predictions, axis=1)
            return glue_compute_metrics(preds, p.label_ids)

        return compute_metrics_fn
    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn(data_args.task_name),
    )
    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)
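    # The fine-tuned model and tokenizer can later be reloaded from output_dir
    # with the standard from_pretrained API, e.g. (sketch):
    #   model = AutoModelForSequenceClassification.from_pretrained(training_args.output_dir)
    #   tokenizer = AutoTokenizer.from_pretrained(training_args.output_dir)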
    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Single eval dataset here; the list form is kept for parity with the GLUE
        # script, which loops over the matched/mismatched MNLI dev sets.
        eval_datasets = [eval_dataset]
        for eval_dataset in eval_datasets:
            trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name)
            eval_result = trainer.evaluate(eval_dataset=eval_dataset)

            output_eval_file = os.path.join(
                training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
            )
            if trainer.is_world_master():
                with open(output_eval_file, "w") as writer:
                    logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
                    for key, value in eval_result.items():
                        logger.info("  %s = %s", key, value)
                        writer.write("%s = %s\n" % (key, value))

            eval_results.update(eval_result)
    if training_args.do_predict:
        logger.info("*** Test ***")
        test_datasets = [test_dataset]

        for test_dataset in test_datasets:
            predictions = trainer.predict(test_dataset=test_dataset).predictions
            if output_mode == "classification":
                # Convert logits to predicted label ids.
                predictions = np.argmax(predictions, axis=1)
            output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
            if trainer.is_world_master():
                with open(output_test_file, "w") as writer:
                    logger.info("***** Test results {} *****".format(test_dataset.args.task_name))
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        # Map the predicted label id back to its string label.
                        item = test_dataset.get_labels()[item]
                        writer.write("%d\t%s\n" % (index, item))
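                # The predictions file is tab-separated, one row per example,
                # e.g. (rows illustrative only):
                #   index    prediction
                #   0        <label from test_dataset.get_labels()>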
                # Gold labels for the test set are expected at <data_dir>/test_labels.tsv.
                output_label_file = os.path.join(data_args.data_dir, "test_labels.tsv")

                output_test_result_file = os.path.join(training_args.output_dir, "test_results.txt")
                # Score the written predictions against the gold labels.
                test_result = get_eval_results(output_label_file, output_test_file)
                with open(output_test_result_file, "w") as writer:
                    for key, value in test_result.items():
                        logger.info("  %s = %s", key, value)
                        writer.write("%s = %s\n" % (key, value))

    return eval_results

if __name__ == "__main__":
    main()