--- a
+++ b/CheXbert/src/label.py
@@ -0,0 +1,145 @@
+import os
+import sys
+import argparse
+import torch
+import torch.nn as nn
+import pandas as pd
+import numpy as np
+from tqdm import tqdm
+from collections import OrderedDict
+
+# Local imports
+import utils
+from models.bert_labeler import bert_labeler
+from bert_tokenizer import tokenize
+from transformers import BertTokenizer
+from datasets.unlabeled_dataset import UnlabeledDataset
+from constants import *
+
+
+def collate_fn_no_labels(sample_list):
+    """Pad a list of tokenized impressions into a single batch tensor."""
+    tensor_list = [s['imp'] for s in sample_list]
+    batched_imp = torch.nn.utils.rnn.pad_sequence(
+        tensor_list, batch_first=True, padding_value=PAD_IDX
+    )
+    len_list = [s['len'] for s in sample_list]
+    batch = {'imp': batched_imp, 'len': len_list}
+    return batch
+
+
+def load_unlabeled_data(csv_path, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False):
+    """Create a DataLoader over the unlabeled report impressions."""
+    collate_fn = collate_fn_no_labels
+    dset = UnlabeledDataset(csv_path)
+    loader = torch.utils.data.DataLoader(
+        dset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn
+    )
+    return loader
+
+
+def label(checkpoint_path, csv_path):
+    """Run the CheXbert labeler over every report impression in csv_path."""
+    ld = load_unlabeled_data(csv_path)
+
+    model = bert_labeler()
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.device_count() > 0:  # works even if only one GPU is available
+        print("Using", torch.cuda.device_count(), "GPUs!")
+        model = nn.DataParallel(model)  # to utilize multiple GPUs
+        model = model.to(device)
+        checkpoint = torch.load(checkpoint_path)
+        model.load_state_dict(checkpoint['model_state_dict'])
+    else:
+        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+        new_state_dict = OrderedDict()
+        for k, v in checkpoint['model_state_dict'].items():
+            name = k[7:]  # strip the 'module.' prefix added by DataParallel
+            new_state_dict[name] = v
+        model.load_state_dict(new_state_dict)
+
+    was_training = model.training
+    model.eval()
+    y_pred = [[] for _ in range(len(CONDITIONS))]
+
+    print("\nBegin report impression labeling. The progress bar counts the # of batches completed:")
+    print("The batch size is %d" % BATCH_SIZE)
+    with torch.no_grad():
+        for i, data in enumerate(tqdm(ld)):
+            batch = data['imp']
+            batch = batch.to(device)
+            src_len = data['len']
+            batch_size = batch.shape[0]
+            attn_mask = utils.generate_attention_masks(batch, src_len, device)
+
+            out = model(batch, attn_mask)
+
+            # One classification head per condition; keep the argmax class per report
+            for j in range(len(out)):
+                curr_y_pred = out[j].argmax(dim=1)
+                y_pred[j].append(curr_y_pred)
+
+        for j in range(len(y_pred)):
+            y_pred[j] = torch.cat(y_pred[j], dim=0)
+
+    if was_training:
+        model.train()
+
+    y_pred = [t.tolist() for t in y_pred]
+    return y_pred
+
+
+def save_preds(y_pred, csv_path, out_path):
+    """Write the predicted labels, together with image IDs and impressions, to a CSV."""
+    y_pred = np.array(y_pred)
+    y_pred = y_pred.T  # shape: (num_reports, num_conditions)
+
+    # Read the original dataset to include image_id
+    original_dataset = pd.read_csv(csv_path)
+
+    # Ensure the 'image_id' column exists
+    if 'image_id' not in original_dataset.columns:
+        raise ValueError("The input CSV must have an 'image_id' column.")
+
+    # Extract image IDs and report impressions
+    image_ids = original_dataset['image_id']
+    reports = original_dataset['Report Impression']
+
+    # Create a DataFrame with the predictions
+    df = pd.DataFrame(y_pred, columns=CONDITIONS)
+    df['image_id'] = image_ids
+    df['Report Impression'] = reports
+
+    # Reorder columns
+    new_cols = ['image_id', 'Report Impression'] + CONDITIONS
+    df = df[new_cols]
+
+    # Replace classes with their appropriate values
+    df.replace(0, np.nan, inplace=True)  # blank class is NaN
+    df.replace(3, -1, inplace=True)      # uncertain class is -1
+    df.replace(2, 0, inplace=True)       # negative class is 0
+
+    # Save the output CSV
+    output_file = os.path.join(out_path, 'labeled_reports_with_images.csv')
+    df.to_csv(output_file, index=False)
+    print(f"Labeled reports saved with image file names: {output_file}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Label a csv file containing radiology reports')
+    parser.add_argument('-d', '--data', type=str, nargs='?', required=True,
+                        help='path to csv containing reports. The reports should be \
+                              under the "Report Impression" column')
+    parser.add_argument('-o', '--output_dir', type=str, nargs='?', required=True,
+                        help='path to intended output folder')
+    parser.add_argument('-c', '--checkpoint', type=str, nargs='?', required=True,
+                        help='path to the pytorch checkpoint')
+    args = parser.parse_args()
+    csv_path = args.data
+    out_path = args.output_dir
+    checkpoint_path = args.checkpoint
+
+    y_pred = label(checkpoint_path, csv_path)
+    save_preds(y_pred, csv_path, out_path)
+
+
+# Sample usage: you will need to download the CheXbert checkpoint (chexbert.pth) from
+# https://stanfordmedicine.app.box.com/s/c3stck6w6dol3h36grdc97xoydzxd7w9
+
+r"""
+python label.py -d="C:\Anand\Projects_GWU\NLP_Project\FinalProject-Group6\Data\final_cleaned.csv" -o="C:\Anand\Projects_GWU\NLP_Project\FinalProject-Group6\Results" -c="C:\Users\anand\Downloads\chexbert.pth"
+"""
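+# A minimal sketch of an input CSV this script accepts. The file name and the
+# example rows are hypothetical; only the 'image_id' and 'Report Impression'
+# columns are required. In the saved output, given the class mapping in
+# save_preds above (blank -> NaN, uncertain -> -1, negative -> 0), values that
+# remain 1 mark a positive mention.
+"""
+import pandas as pd
+
+pd.DataFrame({
+    'image_id': ['img_001.png', 'img_002.png'],
+    'Report Impression': ['No acute cardiopulmonary abnormality.',
+                          'Right lower lobe pneumonia.'],
+}).to_csv('final_cleaned.csv', index=False)
+"""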