# bert_train_predict.py
import transformers
import torch
import pandas as pd
import argparse
import random
import numpy as np
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_fscore_support
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, BertForSequenceClassification
from ray.tune.schedulers import PopulationBasedTraining, ASHAScheduler
import ray
from ray import tune
from ray.tune import CLIReporter
from datasets import Dataset, load_dataset, DatasetDict, concatenate_datasets
from functools import partial
from utils import grade_preproc, group_labels, undersample_dataset, data_split
import os
from collections import Counter
import pathlib
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
from torch import nn
from ray.tune.search.bayesopt import BayesOptSearch
from ray.tune.search.hyperopt import HyperOptSearch
from sklearn.utils import class_weight

# Disable Ray Tune's auto-callback loggers. Ray will still create folders and
# JSON files for experiment state; they are small, but should be deleted
# afterwards (PATH: ./to_be_deleted_rayArtifact).
os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1"

parser = argparse.ArgumentParser()
parser.add_argument('--logdir', type=str, help='Path to the directory for temporarily storing checkpoints')
parser.add_argument('--evaldir', type=str, help='Path to the directory for storing model evaluation results')
parser.add_argument('--num_trials', type=int, help='Number of hyperparameter trials', default=5)
parser.add_argument('--seqlens', type=str, help='Comma-separated sequence lengths for the Ray search space', default='20,35,50')
parser.add_argument('--batches', type=str, help='Comma-separated batch sizes for the Ray search space', default='32,64,128')
parser.add_argument('--model', type=str, help='Model to use for classification (e.g. BERT, ROBERTA, BIOBERT)', default='bert-base-uncased')
parser.add_argument('--synth_data', type=str, help='Path to a synthetic data file', default='')
parser.add_argument('--undersample', type=float, default=0.0, help='Undersample the majority class in the train set by this proportion; e.g. 0.2 keeps 20 percent of majority-class rows')
parser.add_argument('--ray', action='store_true', help='Tune hyperparameters with Ray Tune')
parser.add_argument('--adverse', action='store_true', help='Use only adverse SDoH labels (and, with --synth_data, only adverse synthetic examples)')
parser.add_argument('--epochs', type=int, default=5)

args = parser.parse_args()
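
# Example invocation (paths and values are illustrative, not from the repo):
#   python bert_train_predict.py --logdir ./checkpoints --evaldir ./eval \
#       --model bert-base-uncased --ray --num_trials 10 --epochs 5
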
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix random seeds for reproducibility across runs
SEED_VAL = 42
random.seed(SEED_VAL)
np.random.seed(SEED_VAL)
torch.manual_seed(SEED_VAL)
torch.cuda.manual_seed_all(SEED_VAL)

MLB = MultiLabelBinarizer()
if args.adverse:
    LABELS = {'TRANSPORTATION_distance', 'TRANSPORTATION_resource',
        'TRANSPORTATION_other', 'HOUSING_poor', 'HOUSING_undomiciled', 'HOUSING_other',
        'RELATIONSHIP_divorced', 'RELATIONSHIP_widowed', 'RELATIONSHIP_single',
        'PARENT', 'EMPLOYMENT_underemployed', 'EMPLOYMENT_unemployed', 'EMPLOYMENT_disability', 'SUPPORT_minus'}
else:
    LABELS = {'TRANSPORTATION_distance', 'TRANSPORTATION_resource',
        'TRANSPORTATION_other', 'HOUSING_poor', 'HOUSING_undomiciled',
        'HOUSING_other', 'RELATIONSHIP_married', 'RELATIONSHIP_partnered',
        'RELATIONSHIP_divorced', 'RELATIONSHIP_widowed', 'RELATIONSHIP_single',
        'PARENT', 'EMPLOYMENT_employed', 'EMPLOYMENT_underemployed',
        'EMPLOYMENT_unemployed', 'EMPLOYMENT_disability', 'EMPLOYMENT_retired',
        'EMPLOYMENT_student', 'SUPPORT_plus', 'SUPPORT_minus'}

# Broad (category-level) labels, e.g. 'TRANSPORTATION_distance' -> 'TRANSPORTATION'
BROAD_LABELS = {lab.split('_')[0] for lab in LABELS}
BROAD_LABELS.add('<NO_SDOH>')

LABEL_BROAD_NARROW = LABELS.union(BROAD_LABELS)
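
# Illustrative multi-hot encoding of the broad labels (assumes MLB is fit on
# the training label lists, as done in main()):
#   MLB.fit_transform([['HOUSING', 'PARENT'], ['<NO_SDOH>']])
#   -> [[0, 1, 1], [1, 0, 0]] with MLB.classes_ == ['<NO_SDOH>', 'HOUSING', 'PARENT']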

if args.ray:
    ray.init(log_to_driver=False)


class BCETrainer(Trainer):
    """Trainer subclass using binary cross-entropy for multi-label targets."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Multi-hot target vectors, e.g. [0, 1, 0, 1, 0, 0] per example
        labels = inputs.get("labels").to(DEVICE)
        # Forward pass; include the attention mask so padding tokens are ignored
        outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs.get('attention_mask'))
        logits = outputs.get("logits")
        # Binary cross-entropy with logits: each label is an independent binary decision
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits, labels.float().to(logits.device))
        return (loss, outputs) if return_outputs else loss

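# Toy check of the loss above (illustrative values, not executed at import time):
#   logits = torch.tensor([[2.0, -1.0]]); targets = torch.tensor([[1.0, 0.0]])
#   nn.BCEWithLogitsLoss()(logits, targets)  # sigmoid + BCE averaged over labels
#   -> tensor(0.2201)
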
def undersample(df, label, keep_percent):
    """
    Undersamples the majority class in a Pandas dataframe to balance the classes.

    Parameters:
    df (pandas.DataFrame): The dataframe to undersample.
    label (str): The name of the label column used to determine class membership.
    keep_percent (float): The fraction of the majority class to keep.

    Returns:
    pandas.DataFrame: The undersampled dataframe.
    """
    # Find the majority class based on the label column
    counts = df[label].value_counts()
    majority_class = counts.idxmax()

    # Get the indices of rows in the majority class
    majority_indices = df[df[label] == majority_class].index

    # Calculate the number of majority-class rows to keep
    num_majority_keep = int(keep_percent * counts[majority_class])

    # Get a random subset of the majority-class rows to keep
    majority_keep_indices = np.random.choice(majority_indices, num_majority_keep, replace=False)

    # Get the indices of rows in the minority classes
    minority_indices = df[df[label] != majority_class].index

    # Combine the majority-class subset and the minority-class rows
    undersampled_indices = np.concatenate([majority_keep_indices, minority_indices])

    # Return the undersampled dataframe
    return df.loc[undersampled_indices]
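
# For example (hypothetical counts): with keep_percent=0.2, a train set with
# 10,000 majority-class rows and 1,500 minority rows keeps ~2,000 randomly
# chosen majority rows plus all 1,500 minority rows.
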
def generate_label_list(row: pd.Series) -> str:
    """
    Generate a label list based on the given row from a Pandas DataFrame.

    Only columns whose names appear in LABELS are considered; each matching
    column with a value of 1 contributes its broad (category-level) label.

    Args:
        row (pd.Series): A row from a Pandas DataFrame.

    Returns:
        str: A comma-separated string of broad labels extracted from the row.

    Examples:
        >>> df = pd.DataFrame({'HOUSING_poor': [1], 'PARENT': [0]})
        >>> generate_label_list(df.iloc[0])
        'HOUSING'

        >>> df = pd.DataFrame({'HOUSING_poor': [0], 'PARENT': [0]})
        >>> generate_label_list(df.iloc[0])
        '<NO_SDOH>'
    """
    labels = set()
    for col_name, value in row.items():
        if col_name in LABELS and value == 1:
            labels.add(col_name.split('_')[0])
    if len(labels) == 0:
        labels.add('<NO_SDOH>')
    return ','.join(list(labels))

def compute_metrics(pred):
    """
    Calculate evaluation metrics for multi-label predictions.
    """
    labels = pred.label_ids
    logits = torch.tensor(pred.predictions)
    # Sigmoid each logit independently, then threshold at 0.5 (multi-label)
    act = nn.Sigmoid()
    probs = act(logits)
    preds = (probs >= 0.5).int()

    prec, rec, f1, _ = precision_recall_fscore_support(labels, preds, zero_division=0)
    micro_f1 = precision_recall_fscore_support(labels, preds, average='micro', zero_division=0)[2]
    weight_f1 = precision_recall_fscore_support(labels, preds, average='weighted', zero_division=0)[2]
    macro_f1 = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)[2]

    metrics_out = {'macro_f1': macro_f1, 'micro_f1': micro_f1, 'weighted_f1': weight_f1}
    for i, lab in enumerate(list(MLB.classes_)):
        metrics_out['precision_' + str(lab)] = prec[i]
        metrics_out['recall_' + str(lab)] = rec[i]
        metrics_out['f1_' + str(lab)] = f1[i]
    print(classification_report(labels, preds, target_names=MLB.classes_))
    return metrics_out
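
# Shape of the returned dict (class names depend on how MLB was fit), e.g.:
#   {'macro_f1': ..., 'micro_f1': ..., 'weighted_f1': ...,
#    'precision_HOUSING': ..., 'recall_HOUSING': ..., 'f1_HOUSING': ..., ...}
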
def train_hf(config, dataset):
    # Define the TrainingArguments and Trainer objects.
    # Initialize the tokenizer; inputs are padded/truncated to config["sequence_length"].
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)

    def tokenize(batch):
        return tokenizer(batch['text'], padding='max_length', truncation=True, return_tensors="pt", max_length=config["sequence_length"])

    tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
    training_args = TrainingArguments(
        output_dir=args.logdir,
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        learning_rate=config["learning_rate"],
        num_train_epochs=config["epochs"],
        disable_tqdm=False,
        bf16=True,  # bfloat16 training (requires supporting hardware)
        optim='adamw_hf',
        logging_dir=f"{args.logdir}/logs",
        overwrite_output_dir=True,
        evaluation_strategy='epoch',
        weight_decay=config["weight_decay"],
        save_strategy='epoch',
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        seed=SEED_VAL,
        gradient_accumulation_steps=config["gradient_accumulation_steps"]
        )

    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=args.model,
        num_labels=len(dataset['train']['labels'][0]),
        attention_probs_dropout_prob=config["hidden_dropout_prob"],
        hidden_dropout_prob=config["hidden_dropout_prob"]
        )

    # clws = torch.tensor([config["class_weight0"], config["class_weight1"]], dtype=torch.float).to(DEVICE)
    trainer = BCETrainer(
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=tokenized_dataset['train'],
        eval_dataset=tokenized_dataset['dev'],
        model=model,
        compute_metrics=compute_metrics,
        )

    # Train the model and return the evaluation
    trainer.train()
    eval_result = trainer.evaluate()
    if args.ray:
        # A positional dict passed to tune.report() is stored under Ray's
        # '_metric' key, hence '_metric/eval_macro_f1' in the Tune code below.
        tune.report(eval_result)
    else:
        return eval_result


def main(args):
    train_data = pd.read_csv('./data/train_sents.csv')
    dev_data = pd.read_csv('./data/dev_sents.csv')

    train_data.fillna(value={'text': ''}, inplace=True)
    dev_data.fillna(value={'text': ''}, inplace=True)

    dev_text = dev_data['text'].tolist()
    dev_labels = dev_data.apply(generate_label_list, axis=1).tolist()

    train_data['LABEL'] = train_data.apply(generate_label_list, axis=1).tolist()

    if args.undersample:
        train_data = undersample(train_data, label='LABEL', keep_percent=args.undersample)
    train_text = train_data['text'].tolist()
    train_labels = train_data['LABEL'].tolist()

    # Optionally extend the train set with synthetic examples
    if args.synth_data:
        synthetic_data = pd.read_csv(args.synth_data)
        if args.adverse:
            synthetic_data = synthetic_data[synthetic_data['adverse'] == 'adverse']
        synthetic_data.reset_index(inplace=True, drop=True)

        # One-hot the synthetic label column so generate_label_list can be reused
        binary_synthetic = pd.get_dummies(synthetic_data['label'])
        binary_synthetic['text'] = synthetic_data['text']
        synth_labels = binary_synthetic.apply(generate_label_list, axis=1).tolist()
        synth_text = synthetic_data['text'].tolist()

        train_text.extend(synth_text)
        train_labels.extend(synth_labels)

    # Binarize the comma-separated label strings into multi-hot vectors
    train_labels = [labs.split(',') for labs in train_labels]
    train_labs_mlb = MLB.fit_transform(train_labels)
    train_labs_mlb = [ar.tolist() for ar in train_labs_mlb]

    dev_labels = [labs.split(',') for labs in dev_labels]
    dev_labs_mlb = MLB.transform(dev_labels)
    dev_labs_mlb = [ar.tolist() for ar in dev_labs_mlb]

    train_t5 = pd.DataFrame({'text': train_text, 'labels': train_labs_mlb})
    dev_t5 = pd.DataFrame({'text': dev_text, 'labels': dev_labs_mlb})

    train_dataset = Dataset.from_pandas(train_t5)
    dev_dataset = Dataset.from_pandas(dev_t5)

    dataset = DatasetDict()
    dataset['train'] = train_dataset
    dataset['dev'] = dev_dataset

    seq_length_search = [int(x) for x in args.seqlens.split(',')]
    batch_size_search = [int(x) for x in args.batches.split(',')]

    params_dict = {
            'model': args.model,
            'undersample_bool': args.undersample
            }

    if args.ray:
        if args.undersample:
            usample = args.undersample
        else:
            usample = 1
        config_space = {
            "learning_rate": tune.loguniform(1e-5, 1e-3),
            "batch_size": tune.choice(batch_size_search),
            "hidden_dropout_prob": tune.uniform(0.1, 0.5),
            "undersample": usample,
            "weight_decay": tune.loguniform(1e-8, 1e-5),
            "sequence_length": tune.choice(seq_length_search),
            "gradient_accumulation_steps": 3,
            "epochs": args.epochs
            }

        scheduler = ASHAScheduler(
            metric="_metric/eval_macro_f1",
            mode="max",
            grace_period=1,
            reduction_factor=2
            )

        met_cols = ["training_iteration", "macro_f1", "micro_f1", "precision", "recall"]
        for i in range(len(train_labs_mlb[0])):
            met_cols.append('precision_' + str(i))
            met_cols.append('recall_' + str(i))
            met_cols.append('f1_' + str(i))

        reporter = CLIReporter(
            parameter_columns=list(config_space.keys()),
            metric_columns=met_cols,
        )
        result = tune.run(
            partial(train_hf, dataset=dataset),
            config=config_space,
            num_samples=args.num_trials,
            resources_per_trial={"gpu": 1},
            scheduler=scheduler,
            progress_reporter=reporter,
            local_dir="./to_be_deleted_rayArtifact",
            name='empty_folders',
            log_to_file=False,
            )

        best_trial = result.get_best_trial(metric='_metric/eval_macro_f1', mode='max', scope="all")
        config_dict = best_trial.config
        dev_eval_dict = best_trial.last_result['_metric']
        output_dict = {**params_dict, **config_dict, **dev_eval_dict}

        # Append the best trial's config and dev metrics to the evaluation CSV
        outpath = pathlib.Path(args.evaldir, 'multi_BERT_ray.csv')
        print(output_dict)
        if os.path.isfile(outpath):
            indf = pd.read_csv(outpath)
            outdf = pd.concat([indf, pd.DataFrame([output_dict])], ignore_index=True)
        else:
            outdf = pd.DataFrame([output_dict])
        outdf.to_csv(outpath, index=False)
    else:
        # Fixed hyperparameters for a single (non-Ray) training run
        config_space = {
            "learning_rate": 5e-5,
            "batch_size": 32,
            "hidden_dropout_prob": 0.1,
            "undersample": 1.0,
            "weight_decay": 2e-8,
            "sequence_length": 100,
            "gradient_accumulation_steps": 3,
            "epochs": 10
        }

        dev_eval_dict = train_hf(config_space, dataset)
        output_dict = {**params_dict, **config_space, **dev_eval_dict}
        outpath = pathlib.Path(args.evaldir, 'multi_BERT_noray.csv')
        print(output_dict)
        if os.path.isfile(outpath):
            indf = pd.read_csv(outpath)
            outdf = pd.concat([indf, pd.DataFrame([output_dict])], ignore_index=True)
        else:
            outdf = pd.DataFrame([output_dict])
        outdf.to_csv(outpath, index=False)


if __name__ == '__main__':
    main(args)