bert_train_predict.py
import argparse
import os
import pathlib
import random
from collections import Counter
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch import nn

import transformers
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, BertForSequenceClassification

import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import PopulationBasedTraining, ASHAScheduler
from ray.tune.search.bayesopt import BayesOptSearch
from ray.tune.search.hyperopt import HyperOptSearch

from datasets import Dataset, load_dataset, DatasetDict, concatenate_datasets
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_fscore_support
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.utils import class_weight

from utils import grade_preproc, group_labels, undersample_dataset, data_split

# Disable Ray Tune's auto callback loggers; Tune will still create folders and JSON
# files for experiment state. They're not big, but should be deleted: ./to_be_deleted_rayArtifact
os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1"

parser = argparse.ArgumentParser()
parser.add_argument('--logdir', type=str, help='Path to the directory for temporarily storing checkpoints')
parser.add_argument('--evaldir', type=str, help='Path to the directory for storing model evaluation results')
parser.add_argument('--num_trials', type=int, help='Number of hyperparameter trials', default=5)
parser.add_argument('--seqlens', type=str, help='Comma-separated sequence lengths for the Ray search space', default='20,35,50')
parser.add_argument('--batches', type=str, help='Comma-separated batch sizes for the Ray search space', default='32,64,128')
parser.add_argument('--model', type=str, help='Model to use for classification (e.g. BERT, RoBERTa, BioBERT)', default='bert-base-uncased')
parser.add_argument('--synth_data', type=str, help='Path to a synthetic data file', default='')
parser.add_argument('--undersample', type=float, default=0.0, help='Undersample the majority class in the train set by this proportion, e.g. 0.2 keeps 20 percent of majority-class rows')
parser.add_argument('--ray', action='store_true', help='Tune hyperparameters with Ray Tune')
parser.add_argument('--adverse', action='store_true', help='Use only adverse labels (and adverse synthetic data)')
parser.add_argument('--epochs', type=int, default=5)

args = parser.parse_args()
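# Example invocation (paths illustrative):
#   python bert_train_predict.py --logdir ./ckpts --evaldir ./evals --ray --num_trials 10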

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEED_VAL = 42
random.seed(SEED_VAL)
np.random.seed(SEED_VAL)
torch.manual_seed(SEED_VAL)
torch.cuda.manual_seed_all(SEED_VAL)

MLB = MultiLabelBinarizer()

if args.adverse:
    LABELS = {'TRANSPORTATION_distance', 'TRANSPORTATION_resource',
              'TRANSPORTATION_other', 'HOUSING_poor', 'HOUSING_undomiciled', 'HOUSING_other',
              'RELATIONSHIP_divorced', 'RELATIONSHIP_widowed', 'RELATIONSHIP_single',
              'PARENT', 'EMPLOYMENT_underemployed', 'EMPLOYMENT_unemployed',
              'EMPLOYMENT_disability', 'SUPPORT_minus'}
else:
    LABELS = {'TRANSPORTATION_distance', 'TRANSPORTATION_resource',
              'TRANSPORTATION_other', 'HOUSING_poor', 'HOUSING_undomiciled',
              'HOUSING_other', 'RELATIONSHIP_married', 'RELATIONSHIP_partnered',
              'RELATIONSHIP_divorced', 'RELATIONSHIP_widowed', 'RELATIONSHIP_single',
              'PARENT', 'EMPLOYMENT_employed', 'EMPLOYMENT_underemployed',
              'EMPLOYMENT_unemployed', 'EMPLOYMENT_disability', 'EMPLOYMENT_retired',
              'EMPLOYMENT_student', 'SUPPORT_plus', 'SUPPORT_minus'}

BROAD_LABELS = {lab.split('_')[0] for lab in LABELS}
BROAD_LABELS.add('<NO_SDOH>')

LABEL_BROAD_NARROW = LABELS.union(BROAD_LABELS)
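# BROAD_LABELS collapses each fine-grained label to its category prefix
# (e.g. 'HOUSING_poor' -> 'HOUSING'); '<NO_SDOH>' marks sentences with no SDOH mention.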

if args.ray:
    ray.init(log_to_driver=False)


class BCETrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
        # transformers versions pass to compute_loss
        labels = inputs.get("labels").to(DEVICE)  # multi-hot vectors, e.g. [0, 1, 0, 1, 0, 0]
        # Forward pass; include the attention mask so padding tokens are ignored
        outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs.get('attention_mask'))
        logits = outputs.get("logits")
        # Binary cross-entropy treats each label as an independent yes/no decision
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits, labels.float())
        return (loss, outputs) if return_outputs else loss
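
# BCEWithLogitsLoss applies a per-label sigmoid inside the loss, so BCETrainer suits
# multi-label classification where one sentence can carry several SDOH categories;
# a plain CrossEntropyLoss would force exactly one class per example.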


def undersample(df, label, keep_percent):
    """
    Undersamples the majority class in a Pandas dataframe to balance the classes.

    Parameters:
        df (pandas.DataFrame): The dataframe to undersample.
        label (str): The name of the column holding the class labels.
        keep_percent (float): The percentage of the majority class to keep.

    Returns:
        pandas.DataFrame: The undersampled dataframe.
    """
    # Find the majority class based on the labels column
    counts = df[label].value_counts()
    majority_class = counts.idxmax()

    # Get the indices of rows in the majority class
    majority_indices = df[df[label] == majority_class].index

    # Calculate the number of majority class rows to keep
    num_majority_keep = int(keep_percent * counts[majority_class])

    # Get a random subset of the majority class rows to keep
    majority_keep_indices = np.random.choice(majority_indices, num_majority_keep, replace=False)

    # Get the indices of rows in the minority class
    minority_indices = df[df[label] != majority_class].index

    # Combine the majority class subset and the minority class rows
    undersampled_indices = np.concatenate([majority_keep_indices, minority_indices])

    # Return the undersampled dataframe
    return df.loc[undersampled_indices]
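
# Usage sketch (illustrative values): keep 20% of the majority class
#   train_data = undersample(train_data, label='LABEL', keep_percent=0.2)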


def generate_label_list(row: pd.Series) -> str:
    """
    Generate a label list based on the given row from a Pandas DataFrame.

    Only columns whose names appear in LABELS are considered; each matching
    column with value 1 contributes its category prefix.

    Args:
        row (pd.Series): A row from a Pandas DataFrame.

    Returns:
        str: A comma-separated string of labels extracted from the row
            (output order may vary, since labels are collected in a set).

    Examples:
        >>> df = pd.DataFrame({'HOUSING_poor': [1], 'EMPLOYMENT_unemployed': [0]})
        >>> generate_label_list(df.iloc[0])
        'HOUSING'

        >>> df = pd.DataFrame({'HOUSING_poor': [0], 'EMPLOYMENT_unemployed': [0]})
        >>> generate_label_list(df.iloc[0])
        '<NO_SDOH>'
    """
    labels = set()
    for col_name, value in row.items():
        if col_name in LABELS and value == 1:
            labels.add(col_name.split('_')[0])
    if len(labels) == 0:
        labels.add('<NO_SDOH>')
    return ','.join(list(labels))


def compute_metrics(pred):
    """
    Calculate evaluation metrics for multi-label predictions.
    """
    labels = pred.label_ids
    logits = torch.tensor(pred.predictions)
    act = nn.Sigmoid()
    probs = act(logits)
    # A label is predicted when its sigmoid probability reaches 0.5
    preds = (probs >= 0.5).int()

    prec, rec, f1, _ = precision_recall_fscore_support(labels, preds)
    micro_f1 = precision_recall_fscore_support(labels, preds, average='micro')[2]
    weight_f1 = precision_recall_fscore_support(labels, preds, average='weighted')[2]
    macro_f1 = precision_recall_fscore_support(labels, preds, average='macro')[2]

    metrics_out = {'macro_f1': macro_f1, 'micro_f1': micro_f1, 'weighted_f1': weight_f1}
    for i, lab in enumerate(list(MLB.classes_)):
        metrics_out['precision_' + str(lab)] = prec[i]
        metrics_out['recall_' + str(lab)] = rec[i]
        metrics_out['f1_' + str(lab)] = f1[i]
    print(classification_report(labels, preds, target_names=MLB.classes_))
    return metrics_out
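
# macro_f1 averages per-label F1 scores equally (rare labels count as much as common
# ones); micro_f1 pools all label decisions before computing F1, so frequent labels dominate.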


def train_hf(config, dataset):
    # Define the Trainer and TrainingArguments objects.
    # Initialize the tokenizer with the sequence_length parameter.
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)

    def tokenize(batch):
        return tokenizer(batch['text'], padding='max_length', truncation=True,
                         return_tensors="pt", max_length=config["sequence_length"])

    tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
    training_args = TrainingArguments(
        output_dir=args.logdir,
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        learning_rate=config["learning_rate"],
        num_train_epochs=config["epochs"],
        disable_tqdm=False,
        bf16=True,  # bfloat16 training
        optim='adamw_hf',
        logging_dir=f"{args.logdir}/logs",
        overwrite_output_dir=True,
        evaluation_strategy='epoch',
        weight_decay=config["weight_decay"],
        save_strategy='epoch',
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        seed=SEED_VAL,
        gradient_accumulation_steps=config["gradient_accumulation_steps"]
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=args.model,
        num_labels=len(dataset['train']['labels'][0]),
        attention_probs_dropout_prob=config["hidden_dropout_prob"],
        hidden_dropout_prob=config["hidden_dropout_prob"]
    )

    trainer = BCETrainer(
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=tokenized_dataset['train'],
        eval_dataset=tokenized_dataset['dev'],
        model=model,
        compute_metrics=compute_metrics,
    )

    # Train the model and return the evaluation
    trainer.train()
    eval_result = trainer.evaluate()
    if args.ray:
        tune.report(eval_result)
    else:
        return eval_result
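
# Note: passing the metrics dict positionally to tune.report() makes Ray nest it
# under the '_metric' key, which is why the tuning code below reads
# '_metric/eval_macro_f1'.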


def main(args):
    train_data = pd.read_csv('./data/train_sents.csv')
    dev_data = pd.read_csv('./data/dev_sents.csv')

    train_data.fillna(value={'text': ''}, inplace=True)
    dev_data.fillna(value={'text': ''}, inplace=True)

    dev_text = dev_data['text'].tolist()
    dev_labels = dev_data.apply(generate_label_list, axis=1).tolist()

    train_data['LABEL'] = train_data.apply(generate_label_list, axis=1).tolist()

    if args.undersample:
        train_data = undersample(train_data, label='LABEL', keep_percent=args.undersample)
    train_text = train_data['text'].tolist()
    train_labels = train_data['LABEL'].tolist()

    if args.synth_data:
        synthetic_data = pd.read_csv(args.synth_data)
        if args.adverse:
            synthetic_data = synthetic_data[synthetic_data['adverse'] == 'adverse']
            synthetic_data.reset_index(inplace=True, drop=True)

        binary_synthetic = pd.get_dummies(synthetic_data['label'])
        binary_synthetic['text'] = synthetic_data['text']
        synth_labels = binary_synthetic.apply(generate_label_list, axis=1).tolist()
        synth_text = synthetic_data['text'].tolist()

        train_text.extend(synth_text)
        train_labels.extend(synth_labels)

    train_labels = [labs.split(',') for labs in train_labels]
    train_labs_mlb = MLB.fit_transform(train_labels)
    train_labs_mlb = [ar.tolist() for ar in train_labs_mlb]

    dev_labels = [labs.split(',') for labs in dev_labels]
    dev_labs_mlb = MLB.transform(dev_labels)
    dev_labs_mlb = [ar.tolist() for ar in dev_labs_mlb]
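
    # MLB maps each list of label names to a fixed-order multi-hot vector
    # (e.g. ['HOUSING'] -> [0, ..., 1, ..., 0]); it is fit on train and reused on dev.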

    train_t5 = pd.DataFrame({'text': train_text, 'labels': train_labs_mlb})
    dev_t5 = pd.DataFrame({'text': dev_text, 'labels': dev_labs_mlb})

    train_dataset = Dataset.from_pandas(train_t5)
    dev_dataset = Dataset.from_pandas(dev_t5)

    dataset = DatasetDict()
    dataset['train'] = train_dataset
    dataset['dev'] = dev_dataset

    seq_length_search = [int(x) for x in args.seqlens.split(',')]
    batch_size_search = [int(x) for x in args.batches.split(',')]

    params_dict = {
        'model': args.model,
        'undersample_bool': args.undersample
    }

    if args.ray:
        if args.undersample:
            usample = args.undersample
        else:
            usample = 1
        config_space = {
            "learning_rate": tune.loguniform(1e-5, 1e-3),
            "batch_size": tune.choice(batch_size_search),
            "hidden_dropout_prob": tune.uniform(0.1, 0.5),
            "undersample": usample,
            "weight_decay": tune.loguniform(1e-8, 1e-5),
            "sequence_length": tune.choice(seq_length_search),
            "gradient_accumulation_steps": 3,
            "epochs": args.epochs
        }
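
        # ASHA stops the weakest trials early at successive rungs (reduction_factor=2
        # keeps roughly the top half at each rung) while promising trials keep training.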

        scheduler = ASHAScheduler(
            metric="_metric/eval_macro_f1",
            mode="max",
            grace_period=1,
            reduction_factor=2
        )

        met_cols = ["training_iteration", "macro_f1", "micro_f1", "precision", "recall"]
        # Use the fitted MLB class names so the reporter columns match the keys
        # produced by compute_metrics (e.g. 'f1_HOUSING' rather than 'f1_0')
        for lab in MLB.classes_:
            met_cols.append('precision_' + str(lab))
            met_cols.append('recall_' + str(lab))
            met_cols.append('f1_' + str(lab))

        reporter = CLIReporter(
            parameter_columns=list(config_space.keys()),
            metric_columns=met_cols,
        )
        result = tune.run(
            partial(train_hf, dataset=dataset),
            config=config_space,
            num_samples=args.num_trials,
            resources_per_trial={"gpu": 1},
            scheduler=scheduler,
            progress_reporter=reporter,
            local_dir="./to_be_deleted_rayArtifact",
            name='empty_folders',
            log_to_file=False,
        )
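
        # scope="all" selects the best reported score from any iteration of any
        # trial, not just each trial's final result.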

        best_trial = result.get_best_trial(metric='_metric/eval_macro_f1', mode='max', scope="all")
        config_dict = best_trial.config
        dev_eval_dict = best_trial.last_result['_metric']
        output_dict = {**params_dict, **config_dict, **dev_eval_dict}

        outpath = pathlib.Path(args.evaldir).joinpath('multi_BERT_ray.csv')
        print(output_dict)
        if os.path.isfile(outpath):
            indf = pd.read_csv(outpath)
            outdf = pd.concat([indf, pd.DataFrame([output_dict])], ignore_index=True)
        else:
            outdf = pd.DataFrame([output_dict])
        outdf.to_csv(outpath, index=False)
    else:
        config_space = {
            "learning_rate": 5e-5,
            "batch_size": 32,
            "hidden_dropout_prob": 0.1,
            "undersample": 1.0,
            "weight_decay": 2e-8,
            "sequence_length": 100,
            "gradient_accumulation_steps": 3,
            "epochs": 10
        }

        dev_eval_dict = train_hf(config_space, dataset)
        output_dict = {**params_dict, **config_space, **dev_eval_dict}
        outpath = pathlib.Path(args.evaldir).joinpath('multi_BERT_noray.csv')
        print(output_dict)
        if os.path.isfile(outpath):
            indf = pd.read_csv(outpath)
            outdf = pd.concat([indf, pd.DataFrame([output_dict])], ignore_index=True)
        else:
            outdf = pd.DataFrame([output_dict])
        outdf.to_csv(outpath, index=False)


if __name__ == '__main__':
    main(args)