--- a
+++ b/CheXbert/src/label.py
@@ -0,0 +1,145 @@
+import os
+import sys
+import argparse
+import torch
+import torch.nn as nn
+import pandas as pd
+import numpy as np
+from tqdm import tqdm
+from collections import OrderedDict
+
+# Local imports
+import utils
+from models.bert_labeler import bert_labeler
+from bert_tokenizer import tokenize
+from transformers import BertTokenizer
+from datasets.unlabeled_dataset import UnlabeledDataset
+from constants import *
+
+
+def collate_fn_no_labels(sample_list):
+    """Pad the samples' 'imp' tensors with PAD_IDX into one batch; list their 'len' values."""
+    padded = torch.nn.utils.rnn.pad_sequence(
+        [sample['imp'] for sample in sample_list],
+        batch_first=True, padding_value=PAD_IDX,
+    )
+    lengths = [sample['len'] for sample in sample_list]
+    return {'imp': padded, 'len': lengths}
+
+
+def load_unlabeled_data(csv_path, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False):
+    """Build a DataLoader over UnlabeledDataset(csv_path) that batches with collate_fn_no_labels."""
+    return torch.utils.data.DataLoader(
+        UnlabeledDataset(csv_path),
+        batch_size=batch_size, shuffle=shuffle,
+        num_workers=num_workers, collate_fn=collate_fn_no_labels,
+    )
+
+
+def label(checkpoint_path, csv_path):
+    """Predict a class index per condition for every report in a csv.
+
+    Args:
+        checkpoint_path: Path to a torch checkpoint with a 'model_state_dict' entry.
+        csv_path: Path to the csv handed to load_unlabeled_data.
+
+    Returns:
+        One list per entry of CONDITIONS, each holding the argmax class index
+        of the matching model output for every report, in loader order.
+    """
+    ld = load_unlabeled_data(csv_path)
+    model = bert_labeler()
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.device_count() > 0:
+        print("Using", torch.cuda.device_count(), "GPUs!")
+        model = nn.DataParallel(model).to(device)
+        checkpoint = torch.load(checkpoint_path)
+        model.load_state_dict(checkpoint['model_state_dict'])
+    else:
+        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+        # nn.DataParallel saves keys as 'module.<name>'; strip that prefix only
+        # where present so checkpoints saved without the wrapper load too.
+        prefix = 'module.'
+        model.load_state_dict(OrderedDict(
+            (k[len(prefix):] if k.startswith(prefix) else k, v)
+            for k, v in checkpoint['model_state_dict'].items()
+        ))
+
+    was_training = model.training
+    model.eval()
+    y_pred = [[] for _ in range(len(CONDITIONS))]
+
+    print("\nBegin report impression labeling. The progress bar counts the # of batches completed:")
+    print("The batch size is %d" % BATCH_SIZE)
+    with torch.no_grad():
+        for data in tqdm(ld):
+            batch = data['imp'].to(device)
+            attn_mask = utils.generate_attention_masks(batch, data['len'], device)
+            out = model(batch, attn_mask)
+            for j, head_out in enumerate(out):
+                y_pred[j].append(head_out.argmax(dim=1))
+
+    if was_training:  # put the model back in the mode it started in
+        model.train()
+
+    return [torch.cat(preds, dim=0).tolist() for preds in y_pred]
+
+
+def save_preds(y_pred, csv_path, out_path):
+    """Write the predictions, with image ids and report texts, to a csv.
+
+    Args:
+        y_pred: Class indices shaped (conditions, reports), e.g. the return value of label().
+        csv_path: Input csv; must have 'image_id' and 'Report Impression' columns.
+        out_path: Directory that receives 'labeled_reports_with_images.csv'.
+
+    Raises:
+        ValueError: If 'image_id' is missing or the csv row count differs from the predictions.
+    """
+    original_dataset = pd.read_csv(csv_path)
+    if 'image_id' not in original_dataset.columns:
+        raise ValueError("The input CSV must have an 'image_id' column.")
+
+    df = pd.DataFrame(np.array(y_pred).T, columns=CONDITIONS)
+    if len(df) != len(original_dataset):
+        raise ValueError("Row count of the input CSV does not match the number of labeled reports.")
+
+    # Recode while the frame holds only labels: done after attaching the other
+    # columns, these replacements would also rewrite any image_id equal to
+    # 0, 2 or 3. Order matters: 0 must become NaN before 2 becomes 0.
+    df.replace(0, np.nan, inplace=True)  # Blank class is NaN
+    df.replace(3, -1, inplace=True)      # Uncertain class is -1
+    df.replace(2, 0, inplace=True)       # Negative class is 0
+
+    df.insert(0, 'image_id', original_dataset['image_id'].to_numpy())
+    df.insert(1, 'Report Impression', original_dataset['Report Impression'].to_numpy())
+
+    # Save output CSV
+    output_file = os.path.join(out_path, 'labeled_reports_with_images.csv')
+    df.to_csv(output_file, index=False)
+    print(f"Labeled reports saved with image file names: {output_file}")
+
+
+if __name__ == '__main__':
+    # Script entry point: parse the three required paths from the command line.
+    parser = argparse.ArgumentParser(description='Label a csv file containing radiology reports')
+    path_opt = dict(type=str, nargs='?', required=True)  # settings shared by every option
+    parser.add_argument('-d', '--data', **path_opt,
+                        help='path to csv containing reports. The reports should be \
+                              under the \"Report Impression\" column')
+    parser.add_argument('-o', '--output_dir', **path_opt,
+                        help='path to intended output folder')
+    parser.add_argument('-c', '--checkpoint', **path_opt,
+                        help='path to the pytorch checkpoint')
+    cli = parser.parse_args()
+
+    # Label every report, then write the predictions next to the input rows.
+    predictions = label(cli.checkpoint, cli.data)
+    save_preds(predictions, cli.data, cli.output_dir)
+
+
+# Sample usage. First download the CheXbert checkpoint (Chexbert.pth) from
+# https://stanfordmedicine.app.box.com/s/c3stck6w6dol3h36grdc97xoydzxd7w9
+#
+# Kept as comments: in an un-prefixed string literal the "\U" and "\N" in these Windows paths are a SyntaxError.
+# python label.py -d="C:\Anand\Projects_GWU\NLP_Project\FinalProject-Group6\Data\final_cleaned.csv" -o="C:\Anand\Projects_GWU\NLP_Project\FinalProject-Group6\Results" -c="C:\Users\anand\Downloads\chexbert.pth"
\ No newline at end of file