# t5_predict.py
1
from datasets import Dataset, DatasetDict
2
import torch
3
from random import randrange, sample
4
from transformers import DataCollatorForSeq2Seq, T5ForConditionalGeneration
5
import pandas as pd
6
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType, PeftModel, PeftConfig
8
from transformers import DataCollatorForSeq2Seq
9
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
10
from sklearn.preprocessing import MultiLabelBinarizer
11
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_fscore_support
12
import json
13
import argparse
14
import tqdm
15
import numpy as np
16
import random
17
import os
18
19
# Fix every RNG the pipeline touches so generation/eval runs are reproducible.
SEED_VAL = 42
random.seed(SEED_VAL)
np.random.seed(SEED_VAL)
torch.manual_seed(SEED_VAL)
torch.cuda.manual_seed_all(SEED_VAL)

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path to trained t5 model. FORMAT: model_task_augmentationBool_undersampleValue_syntheticDataPath')
parser.add_argument('--output_path', type=str, help='path to json store metrics')
parser.add_argument('--error_file', type=str, help='path to synthetic data file')
# Default matches the value previously hard-coded at the predict() call site.
parser.add_argument('--batch_size', type=int, default=6, help='prediction batches')
parser.add_argument('--adverse', action='store_true', help='only add adverse labels')
parser.add_argument('--test', action='store_true', help='eval on test set')
args = parser.parse_args()
34
if args.adverse:
    # Restricted set: only "adverse" SDOH labels (drops protective ones such
    # as RELATIONSHIP_married, EMPLOYMENT_employed, SUPPORT_plus).
    LABELS = {'TRANSPORTATION_distance', 'TRANSPORTATION_resource',
        'TRANSPORTATION_other', 'HOUSING_poor', 'HOUSING_undomiciled', 'HOUSING_other',
        'RELATIONSHIP_divorced', 'RELATIONSHIP_widowed', 'RELATIONSHIP_single',
        'PARENT', 'EMPLOYMENT_underemployed', 'EMPLOYMENT_unemployed',
        'EMPLOYMENT_disability', 'EMPLOYMENT_retired',
        'EMPLOYMENT_student', 'SUPPORT_minus'}
else:
    # Full fine-grained (narrow) label set.
    LABELS = {'TRANSPORTATION_distance', 'TRANSPORTATION_resource',
        'TRANSPORTATION_other', 'HOUSING_poor', 'HOUSING_undomiciled',
        'HOUSING_other', 'RELATIONSHIP_married', 'RELATIONSHIP_partnered',
        'RELATIONSHIP_divorced', 'RELATIONSHIP_widowed', 'RELATIONSHIP_single',
        'PARENT', 'EMPLOYMENT_employed', 'EMPLOYMENT_underemployed',
        'EMPLOYMENT_unemployed', 'EMPLOYMENT_disability', 'EMPLOYMENT_retired',
        'EMPLOYMENT_student', 'SUPPORT_plus', 'SUPPORT_minus'}

# Broad categories are the prefix before the underscore,
# e.g. 'HOUSING_poor' -> 'HOUSING'; '<NO_SDOH>' marks "no label present".
BROAD_LABELS = {lab.split('_')[0] for lab in LABELS}
BROAD_LABELS.add('<NO_SDOH>')

LABEL_BROAD_NARROW = LABELS.union(BROAD_LABELS)
TOKENIZER = AutoTokenizer.from_pretrained(args.model_path)
MAX_S_LEN = 100  # max tokenized source (input) length
MAX_T_LEN = 40   # max tokenized target (label string) length
def generate_label_list(row: pd.Series) -> str:
    """
    Build the comma-separated broad-label target string for one sentence row.

    Args:
        row (pd.Series): A row of the sentence DataFrame; columns named after
            narrow labels (e.g. 'HOUSING_poor') hold 0/1 indicators.

    Returns:
        str: Comma-joined broad categories (e.g. 'HOUSING,PARENT'), or
        '<NO_SDOH>' when no known label column is set.

    Examples:
        >>> df = pd.DataFrame({'HOUSING_poor': [1], 'PARENT': [1], 'SUPPORT_minus': [0]})
        >>> sorted(generate_label_list(df.iloc[0]).split(','))
        ['HOUSING', 'PARENT']

        >>> df = pd.DataFrame({'HOUSING_poor': [0], 'PARENT': [0]})
        >>> generate_label_list(df.iloc[0])
        '<NO_SDOH>'
    """
    labels = set()
    for col_name, value in row.items():
        # Only count columns that are known narrow labels and are switched on;
        # keep just the broad prefix before the underscore.
        if col_name in LABELS and value == 1:
            labels.add(col_name.split('_')[0])
    if len(labels) == 0:
        labels.add('<NO_SDOH>')
    return ','.join(list(labels))
def postprocess_function(preds):
    """
    Repair truncated/garbled generated labels.

    Generation is capped at a few new tokens, so label names can come out
    truncated (e.g. 'EMPLO' for 'EMPLOYMENT'). Each prediction is a
    comma-separated string; every piece with a known truncation is mapped back
    to its canonical label, unknown pieces pass through unchanged.

    Args:
        preds (list): Raw decoded prediction strings (may contain NaN).

    Returns:
        list: Prediction strings with truncated labels restored.

    Examples:
        >>> postprocess_function(['REL', 'EMPLO', 'HOUS', 'UNKNOWN'])
        ['RELATIONSHIP', 'EMPLOYMENT', 'HOUSING', 'UNKNOWN']
        >>> postprocess_function(['NO_SD', float('nan'), 'SUPP'])
        ['<NO_SDOH>', '<NO_SDOH>', 'SUPPORT']
    """
    lab_fixed_dict = {
        'REL': 'RELATIONSHIP',
        'RELAT': 'RELATIONSHIP',
        'EMP': 'EMPLOYMENT',
        'EMPLO': 'EMPLOYMENT',
        'SUPP': 'SUPPORT',
        'HOUS': 'HOUSING',
        'PAREN': 'PARENT',
        'TRANSPORT': 'TRANSPORTATION',
        'NO_SD': '<NO_SDOH>',
        # BUG FIX: predictions are stringified below, so missing values arrive
        # as the literal string 'nan' — the bare np.nan key could never match.
        'nan': '<NO_SDOH>',
        np.nan: '<NO_SDOH>',
        'NO_SDOH>': '<NO_SDOH>',
        '<NO_SDOH': '<NO_SDOH>',
    }

    new_preds = []
    for pred in preds:
        # str() normalizes NaN / non-string entries before splitting.
        pred = str(pred)
        pred_ls = [lab_fixed_dict.get(pp, pp) for pp in pred.split(',')]
        new_preds.append(','.join(pred_ls))

    return new_preds
def preprocess_function(sample, padding="max_length"):
    """
    Tokenize a batch of examples for T5 seq2seq training/eval.

    Args:
        sample (dict): Batch with 'text' (source sentences) and
            'SDOHlabels' (target label strings).
        padding (str): Padding strategy passed to the tokenizer.

    Returns:
        dict: Tokenized model inputs with a 'labels' field attached.
    """
    # T5 is text-to-text; prefix the task instruction to each input.
    inputs = ["summarize: " + item for item in sample["text"]]
    model_inputs = TOKENIZER(inputs, max_length=MAX_S_LEN, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument.
    labels = TOKENIZER(text_target=sample["SDOHlabels"], max_length=MAX_T_LEN, padding=padding, truncation=True)

    # When padding to max length, replace pad token ids in the labels with
    # -100 so the loss ignores padded positions.
    if padding == "max_length":
        labels["input_ids"] = [
            [(tok if tok != TOKENIZER.pad_token_id else -100) for tok in label]
            for label in labels["input_ids"]
        ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
def normal_eval(preds, gold):
    """
    Score predicted broad-label strings against gold labels.

    Args:
        preds (list): Comma-separated predicted label strings.
        gold (list): Comma-separated gold label strings.

    Returns:
        dict: macro/micro/weighted F1 plus per-class precision/recall/F1
        (keys like 'precision_HOUSING').
    """
    pred_temp = [p.split(',') for p in preds]
    gold_list = [g.split(',') for g in gold]

    # Drop any predicted token that is not a known broad label.
    pred_list = [[p for p in labs if p in BROAD_LABELS] for labs in pred_temp]

    # One-hot encode; fitting on gold means the class set comes from the
    # gold labels only.
    mlb = MultiLabelBinarizer()
    oh_gold = mlb.fit_transform(gold_list)
    oh_pred = mlb.transform(pred_list)

    prec, rec, f1, _ = precision_recall_fscore_support(oh_gold, oh_pred)
    micro_f1 = precision_recall_fscore_support(oh_gold, oh_pred, average='micro')[2]
    weight_f1 = precision_recall_fscore_support(oh_gold, oh_pred, average='weighted')[2]
    macro_f1 = precision_recall_fscore_support(oh_gold, oh_pred, average='macro')[2]

    metrics_out = {'macro_f1': macro_f1, 'micro_f1': micro_f1, 'weighted_f1': weight_f1}
    for i, lab in enumerate(list(mlb.classes_)):
        metrics_out['precision_' + str(lab)] = prec[i]
        metrics_out['recall_' + str(lab)] = rec[i]
        metrics_out['f1_' + str(lab)] = f1[i]
    print(classification_report(oh_gold, oh_pred, target_names=mlb.classes_))
    return metrics_out
def predict(dataset, model, batch_size):
    """
    Generate label predictions for the 'dev' split of `dataset`.

    Args:
        dataset (DatasetDict): Must contain a 'dev' split with 'text' and
            'SDOHlabels' columns.
        model: Seq2seq model on GPU used for generation.
        batch_size (int): Number of examples per generation batch.

    Returns:
        tuple: (predictions, references) — decoded prediction strings and
        the corresponding gold label strings.
    """
    predictions, references = [], []

    # Iterate over the dev split in fixed-size batches.
    for start in tqdm.tqdm(range(0, len(dataset["dev"]), batch_size)):
        batch = dataset['dev'][start:start + batch_size]

        # Tokenize and move input ids to the GPU.
        input_ids = TOKENIZER(batch["text"], return_tensors="pt", truncation=True,
                              padding="max_length").input_ids.cuda()

        # Beam search; do_sample=False keeps generation deterministic
        # (top_p is inert without sampling).
        outputs = model.generate(input_ids=input_ids, do_sample=False, top_p=0.9,
                                 max_new_tokens=5, num_beams=4)

        # Decode generated ids back to label strings.
        decoded = TOKENIZER.batch_decode(outputs.detach().cpu().numpy(),
                                         skip_special_tokens=True)

        predictions.extend(decoded)
        references.extend(batch["SDOHlabels"])

    return predictions, references
if __name__ == '__main__':
    # Load the evaluation split: sentence-level test or dev CSV.
    if args.test:
        dev_data = pd.read_csv('../data/test_sents.csv')
    else:
        dev_data = pd.read_csv('../data/dev_sents.csv')

    dev_data.fillna(value={'text': ''}, inplace=True)

    # Build a two-column frame: raw text + comma-joined gold label string,
    # then wrap it as a DatasetDict with a single 'dev' split.
    dev_text = dev_data['text'].tolist()
    dev_labels = dev_data.apply(generate_label_list, axis=1).tolist()
    dev_t5 = pd.DataFrame({'text': dev_text, 'SDOHlabels': dev_labels})
    dev_dataset = Dataset.from_pandas(dev_t5)
    dataset = DatasetDict()
    dataset['dev'] = dev_dataset

    config = PeftConfig.from_pretrained(args.model_path)
    # Load the 8-bit base model, then attach the trained LoRA adapter.
    reloaded_model = AutoModelForSeq2SeqLM.from_pretrained(
        config.base_model_name_or_path, load_in_8bit=True, device_map={"": 0})
    reloaded_model = PeftModel.from_pretrained(reloaded_model, args.model_path,
                                               device_map={"": 0})
    reloaded_model.eval()

    # BUG FIX: honor the parsed --batch_size instead of a hard-coded 6;
    # fall back to the historical default when the flag is absent.
    batch_size = args.batch_size if args.batch_size else 6
    predictions, references = predict(dataset, reloaded_model, batch_size)

    # Persist raw gold/pred pairs for error analysis.
    df = pd.DataFrame({'gold': references, 'pred': predictions})
    df.to_csv(args.error_file, index=False)

    # The model path encodes the run configuration:
    # model_task_trainData_undersample_syntheticData (see --model_path help).
    params = args.model_path.split('_')
    param_dict = {'model': params[0], 'task': params[1], 'train_data': params[2],
                  'undersample': params[3], 'synthetic_data': params[4]}

    metrics = normal_eval(predictions, references)
    print('=' * 30 + 'POST PROCESSED' + '=' * 30)
    processed_predictions = postprocess_function(predictions)
    processed_metrics = normal_eval(processed_predictions, references)
    output_dict = {**param_dict, **processed_metrics}
    # NOTE(review): results are appended to the *dev* results file even when
    # --test is set — confirm this is intended.
    if os.path.isfile('./processed_results_dev.csv'):
        indf = pd.read_csv('./processed_results_dev.csv')
        outdf = pd.concat([indf, pd.DataFrame([output_dict])], ignore_index=True)
    else:
        outdf = pd.DataFrame([output_dict])
    outdf.to_csv('./processed_results_dev.csv', index=False)

    # Only the unprocessed metrics go to the JSON output path (as before).
    with open(args.output_path, 'w') as j:
        json.dump(metrics, j, indent=4)