Switch to unified view

a b/findings_classifier/chexpert_train.py
1
import os
2
3
# os.environ["CUDA_VISIBLE_DEVICES"] = "6"
4
import argparse
5
import json
6
from collections import defaultdict
7
8
import numpy as np
9
import pytorch_lightning as pl
10
import torch
11
import wandb
12
from pytorch_lightning.callbacks import ModelCheckpoint
13
from sklearn.metrics import accuracy_score, classification_report, jaccard_score, roc_auc_score
14
from torch.nn import BCEWithLogitsLoss
15
from torch.utils.data import DataLoader
16
from torchinfo import summary
17
from tqdm import tqdm
18
from transformers import AdamW
19
20
from findings_classifier.chexpert_dataset import Chexpert_Dataset
21
from findings_classifier.chexpert_model import ChexpertClassifier
22
from local_config import WANDB_ENTITY
23
24
25
class LitIGClassifier(pl.LightningModule):
    """LightningModule wrapping the project-local ChexpertClassifier for
    multi-label finding classification.

    Step/epoch metrics are pushed to the module-level ``wandb_run`` (created
    in the ``__main__`` block); ``val_f1`` is additionally logged through
    Lightning's ``self.log`` so the ModelCheckpoint callback can monitor it.
    Uses the legacy (PL 1.x) ``*_epoch_end`` hooks.
    """

    def __init__(self, num_classes, class_names, class_weights=None, learning_rate=1e-5):
        super().__init__()

        # Underlying image classifier (project-local).
        self.model = ChexpertClassifier(num_classes)

        # Multi-label loss; optional per-class pos_weight counters class imbalance.
        if class_weights is None:
            self.criterion = BCEWithLogitsLoss()
        else:
            self.criterion = BCEWithLogitsLoss(pos_weight=class_weights)

        self.learning_rate = learning_rate
        self.class_names = class_names

    def forward(self, x):
        return self.model(x)

    def step(self, batch, batch_idx):
        """Shared train/val/test step.

        Returns:
            (loss, acc, f1, rec, prec, jaccard, auc, metrics_per_label) where
            the scalar metrics are macro-averaged over labels (acc/auc over the
            flattened label matrix) and ``metrics_per_label`` maps stringified
            label indices to their classification_report entries.
        """
        x, y = batch['image'].to(self.device), batch['labels'].to(self.device)
        logits = self(x)
        loss = self.criterion(logits, y)

        # Sigmoid probabilities, thresholded at 0.5 for hard predictions.
        preds_probs = torch.sigmoid(logits)
        preds = preds_probs > 0.5

        # Convert to numpy once; reused by every sklearn metric below.
        y_np = y.cpu().numpy()
        preds_np = preds.detach().cpu().numpy()

        # Sample-averaged Jaccard index (multi-label IoU). zero_division=0
        # matches sklearn's default value while silencing the warnings.
        jaccard = jaccard_score(y_np, preds_np, average='samples', zero_division=0)

        class_report = classification_report(y_np, preds_np, output_dict=True, zero_division=0)
        scores = class_report['macro avg']
        # Per-label entries are keyed by the stringified label index.
        metrics_per_label = {label: metrics for label, metrics in class_report.items() if label.isdigit()}

        f1 = scores['f1-score']
        rec = scores['recall']
        prec = scores['precision']
        acc = accuracy_score(y_np.flatten(), preds_np.flatten())
        try:
            auc = roc_auc_score(y_np.flatten(), preds_probs.detach().cpu().numpy().flatten())
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present
            # in the flattened targets (e.g. an all-negative batch).
            auc = 0.

        return loss, acc, f1, rec, prec, jaccard, auc, metrics_per_label

    @staticmethod
    def _epoch_averages(outputs, loss_key, metric_keys):
        """Average step-output dicts over an epoch.

        The loss is averaged as a tensor (torch.stack + mean); the remaining
        scalar metrics are averaged with numpy. Returns a dict keyed like the
        step outputs.
        """
        stats = {loss_key: torch.stack([o[loss_key] for o in outputs]).mean()}
        for key in metric_keys:
            stats[key] = np.mean([o[key] for o in outputs])
        return stats

    def training_step(self, batch, batch_idx):
        loss, acc, f1, rec, prec, jaccard, auc, _ = self.step(batch, batch_idx)
        train_stats = {'loss': loss, 'train_acc': acc, 'train_f1': f1, 'train_rec': rec,
                       'train_prec': prec, 'train_jaccard': jaccard, 'train_auc': auc}
        wandb_run.log(train_stats)
        return train_stats

    def training_epoch_end(self, outputs):
        """Log epoch-level training averages under 'epoch_train_*' keys."""
        avgs = self._epoch_averages(
            outputs, 'loss',
            ['train_acc', 'train_f1', 'train_rec', 'train_prec', 'train_jaccard', 'train_auc'])
        wandb_run.log({'epoch_train_loss': avgs['loss'],
                       'epoch_train_acc': avgs['train_acc'],
                       'epoch_train_f1': avgs['train_f1'],
                       'epoch_train_rec': avgs['train_rec'],
                       'epoch_train_prec': avgs['train_prec'],
                       'epoch_train_jaccard': avgs['train_jaccard'],
                       'epoch_train_auc': avgs['train_auc']})

    def validation_step(self, batch, batch_idx):
        loss, acc, f1, rec, prec, jaccard, auc, metrics_per_label = self.step(batch, batch_idx)
        # log f1 through Lightning so the checkpoint callback can monitor it
        self.log('val_f1', f1)
        return {'val_loss': loss, 'val_acc': acc, 'val_f1': f1, 'val_rec': rec, 'val_prec': prec,
                'val_jaccard': jaccard, 'val_auc': auc}, metrics_per_label

    def validation_epoch_end(self, outputs):
        """Aggregate validation metrics; also averages per-label scores.

        NOTE(review): the averaged per-label metrics are computed but never
        logged or returned — kept for parity with the original behavior.
        """
        outputs, per_label_metrics_outputs = zip(*outputs)
        val_stats = self._epoch_averages(
            outputs, 'val_loss',
            ['val_acc', 'val_f1', 'val_rec', 'val_prec', 'val_jaccard', 'val_auc'])

        # Accumulate per-label metrics across batches; support is summed,
        # the other metrics are later divided by the batch count per label.
        per_label_metrics = defaultdict(lambda: defaultdict(float))
        label_counts = defaultdict(int)
        for metrics_per_label in per_label_metrics_outputs:
            for label, metrics in metrics_per_label.items():
                label_name = self.class_names[int(label)]
                per_label_metrics[label_name]['precision'] += metrics['precision']
                per_label_metrics[label_name]['recall'] += metrics['recall']
                per_label_metrics[label_name]['f1-score'] += metrics['f1-score']
                per_label_metrics[label_name]['support'] += metrics['support']
                label_counts[label_name] += 1

        # Average the accumulated metrics; labels with zero total support are
        # left untouched (their sums are zero anyway).
        for label, metrics in per_label_metrics.items():
            if metrics['support'] > 0:
                for metric_name in ('precision', 'recall', 'f1-score'):
                    per_label_metrics[label][metric_name] /= label_counts[label]

        wandb_run.log(val_stats)

    def test_step(self, batch, batch_idx):
        loss, acc, f1, rec, prec, jaccard, auc, _ = self.step(batch, batch_idx)
        return {'test_loss': loss, 'test_acc': acc, 'test_f1': f1, 'test_rec': rec,
                'test_prec': prec, 'test_jaccard': jaccard, 'test_auc': auc}

    def test_epoch_end(self, outputs):
        """Log epoch-level test averages."""
        test_stats = self._epoch_averages(
            outputs, 'test_loss',
            ['test_acc', 'test_f1', 'test_rec', 'test_prec', 'test_jaccard', 'test_auc'])
        wandb_run.log(test_stats)

    def configure_optimizers(self):
        # transformers' AdamW (deprecated upstream but matching the file's imports).
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
151
152
153
def save_preds(dataloader, split,
               ckpt_path="findings_classifier/checkpoints/chexpert_train/ChexpertClassifier-epoch=06-val_f1=0.36.ckpt"):
    """Run the trained classifier over ``dataloader`` and dump the positive
    finding names per dicom_id to a JSON file.

    Relies on module-level ``num_classes``, ``class_names``, ``val_dataset``
    and ``args`` (all bound in the ``__main__`` block). Requires CUDA.

    Args:
        dataloader: yields batches with 'image' and 'dicom_id' entries.
        split: tag used in the output filename ("train"/"val"/"test").
        ckpt_path: checkpoint to load; defaults to the previously hard-coded path.
    """
    model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=num_classes,
                                                 class_weights=val_dataset.get_class_weights(),
                                                 class_names=class_names, learning_rate=args.lr)
    model.eval()
    model.cuda()
    model.half()  # fp16 inference to match the half-precision inputs below
    class_names_np = np.asarray(class_names)

    # get predictions for every dicom id in the dataloader
    structured_preds = {}
    # no_grad: inference only — skips autograd bookkeeping and saves memory.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            dicom_ids = batch['dicom_id']
            logits = model(batch['image'].half().cuda())
            preds = torch.sigmoid(logits) > 0.5

            for dicom_id, pred in zip(dicom_ids, preds.cpu()):
                # Boolean mask selects the names of all positive findings.
                findings = class_names_np[pred.numpy()].tolist()
                structured_preds[dicom_id] = findings

    # save predictions
    out_path = f"findings_classifier/predictions/structured_preds_chexpert_log_weighting_macro_{split}.json"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(structured_preds, f, indent=4)
180
181
182
if __name__ == '__main__':
    # CLI: training hyper-parameters and mode selection. With --train the
    # model is fitted; without it, predictions are dumped for all splits.
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, default="debug")
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--epochs", type=int, default=6)
    parser.add_argument("--loss_weighting", type=str, default="log", choices=["lin", "log", "none"])
    parser.add_argument("--truncate", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=12)
    parser.add_argument("--use_augs", action="store_true", default=False)
    parser.add_argument("--train", action="store_true", default=False)
    args = parser.parse_args()

    TRAIN = args.train

    # fix all seeds
    pl.seed_everything(42, workers=True)

    # Create DataLoaders
    # NOTE(review): `args`, `val_dataset`, `num_classes` and `class_names` are
    # read as module-level globals by save_preds (and `wandb_run` by the
    # LitIGClassifier methods) — keep these names stable.
    train_dataset = Chexpert_Dataset(split='train', truncate=args.truncate, loss_weighting=args.loss_weighting, use_augs=args.use_augs)
    val_dataset = Chexpert_Dataset(split='validate', truncate=args.truncate)
    test_dataset = Chexpert_Dataset(split='test')

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    # Number of classes for IGClassifier
    num_classes = len(train_dataset.chexpert_cols)
    class_names = train_dataset.chexpert_cols

    if TRAIN:
        # Per-class positive weights (scheme chosen via --loss_weighting).
        class_weights = torch.tensor(train_dataset.get_class_weights(), dtype=torch.float32)
        # Define the model
        lit_model = LitIGClassifier(num_classes, class_names=class_names, class_weights=class_weights, learning_rate=args.lr)
        print(summary(lit_model))

        # WandB logger — module-level global consumed by the LightningModule hooks.
        wandb_run = wandb.init(
            project="ChexpertClassifier",
            entity= WANDB_ENTITY,
            name=args.run_name
        )

        # checkpoint callback: keep the best epoch by val_f1 (plus the last).
        checkpoint_callback = ModelCheckpoint(
            monitor='val_f1',
            dirpath=f'findings_classifier/checkpoints/{args.run_name}',
            filename='ChexpertClassifier-{epoch:02d}-{val_f1:.2f}',
            save_top_k=1,
            save_last=True,
            mode='max',
        )
        # Train the model (legacy PL 1.x `gpus=` API; fp16 mixed precision).
        trainer = pl.Trainer(max_epochs=args.epochs, gpus=1, callbacks=[checkpoint_callback], benchmark=False, deterministic=True, precision=16)
        trainer.fit(lit_model, train_dataloader, val_dataloader)

        # Test the model
        # trainer.validate(lit_model, val_dataloader, ckpt_path="checkpoints_IGCLassifier/lr_5e-5_to0_log_weighting_patches_augs_imgemb/IGClassifier-epoch=09-val_f1=0.65.ckpt")
    else:
        # Inference mode: write structured predictions for every split.
        save_preds(train_dataloader, "train")
        save_preds(val_dataloader, "val")
        save_preds(test_dataloader, "test")