|
findings_classifier/chexpert_train.py
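"""Training and inference script for the CheXpert findings classifier.

Trains a multi-label classifier (ChexpertClassifier wrapped in a LightningModule)
with an optionally class-weighted BCE loss, logs metrics to Weights & Biases, and,
when run without --train, writes the predicted findings per image to JSON.
Written against the pytorch_lightning 1.x API (*_epoch_end hooks, Trainer(gpus=...)).
"""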
|
|
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import argparse
import json
from collections import defaultdict

import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.metrics import accuracy_score, classification_report, jaccard_score, roc_auc_score
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
from transformers import AdamW

from findings_classifier.chexpert_dataset import Chexpert_Dataset
from findings_classifier.chexpert_model import ChexpertClassifier
from local_config import WANDB_ENTITY

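# LightningModule for multi-label findings prediction: BCEWithLogitsLoss (optionally
# weighted per class via pos_weight), sigmoid outputs thresholded at 0.5, and
# macro-averaged precision/recall/F1 plus sample-wise Jaccard and AUROC computed
# per batch with scikit-learn.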
class LitIGClassifier(pl.LightningModule):
    def __init__(self, num_classes, class_names, class_weights=None, learning_rate=1e-5):
        super().__init__()

        # Model
        self.model = ChexpertClassifier(num_classes)

        # Loss with optional per-class positive weights
        if class_weights is None:
            self.criterion = BCEWithLogitsLoss()
        else:
            self.criterion = BCEWithLogitsLoss(pos_weight=class_weights)

        # Learning rate
        self.learning_rate = learning_rate
        self.class_names = class_names

    def forward(self, x):
        return self.model(x)

    def step(self, batch, batch_idx):
        x, y = batch['image'].to(self.device), batch['labels'].to(self.device)
        logits = self(x)
        loss = self.criterion(logits, y)

        # Apply sigmoid to get probabilities
        preds_probs = torch.sigmoid(logits)

        # Get predictions as boolean values
        preds = preds_probs > 0.5

        # Calculate the sample-wise Jaccard index
        jaccard = jaccard_score(y.cpu().numpy(), preds.detach().cpu().numpy(), average='samples')

        class_report = classification_report(y.cpu().numpy(), preds.detach().cpu().numpy(), output_dict=True)
        # scores = class_report['micro avg']
        scores = class_report['macro avg']
        metrics_per_label = {label: metrics for label, metrics in class_report.items() if label.isdigit()}

        f1 = scores['f1-score']
        rec = scores['recall']
        prec = scores['precision']
        acc = accuracy_score(y.cpu().numpy().flatten(), preds.detach().cpu().numpy().flatten())
        try:
            auc = roc_auc_score(y.cpu().numpy().flatten(), preds_probs.detach().cpu().numpy().flatten())
        except Exception:
            # AUROC is undefined if the batch contains only one class
            auc = 0.

        return loss, acc, f1, rec, prec, jaccard, auc, metrics_per_label

    def training_step(self, batch, batch_idx):
        loss, acc, f1, rec, prec, jaccard, auc, _ = self.step(batch, batch_idx)
        train_stats = {'loss': loss, 'train_acc': acc, 'train_f1': f1, 'train_rec': rec, 'train_prec': prec,
                       'train_jaccard': jaccard, 'train_auc': auc}
        wandb_run.log(train_stats)
        return train_stats

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = np.mean([x['train_acc'] for x in outputs])
        avg_f1 = np.mean([x['train_f1'] for x in outputs])
        avg_rec = np.mean([x['train_rec'] for x in outputs])
        avg_prec = np.mean([x['train_prec'] for x in outputs])
        avg_jaccard = np.mean([x['train_jaccard'] for x in outputs])
        avg_auc = np.mean([x['train_auc'] for x in outputs])
        wandb_run.log({'epoch_train_loss': avg_loss, 'epoch_train_acc': avg_acc, 'epoch_train_f1': avg_f1,
                       'epoch_train_rec': avg_rec, 'epoch_train_prec': avg_prec,
                       'epoch_train_jaccard': avg_jaccard, 'epoch_train_auc': avg_auc})

    def validation_step(self, batch, batch_idx):
        loss, acc, f1, rec, prec, jaccard, auc, metrics_per_label = self.step(batch, batch_idx)
        # Log F1 for the checkpoint callback
        self.log('val_f1', f1)
        return {'val_loss': loss, 'val_acc': acc, 'val_f1': f1, 'val_rec': rec, 'val_prec': prec,
                'val_jaccard': jaccard, 'val_auc': auc}, metrics_per_label

    def validation_epoch_end(self, outputs):
        outputs, per_label_metrics_outputs = zip(*outputs)
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = np.mean([x['val_acc'] for x in outputs])
        avg_f1 = np.mean([x['val_f1'] for x in outputs])
        avg_rec = np.mean([x['val_rec'] for x in outputs])
        avg_prec = np.mean([x['val_prec'] for x in outputs])
        avg_jaccard = np.mean([x['val_jaccard'] for x in outputs])
        avg_auc = np.mean([x['val_auc'] for x in outputs])

        # Accumulate per-label metrics over all validation batches
        per_label_metrics = defaultdict(lambda: defaultdict(float))
        label_counts = defaultdict(int)
        for metrics_per_label in per_label_metrics_outputs:
            for label, metrics in metrics_per_label.items():
                label_name = self.class_names[int(label)]
                per_label_metrics[label_name]['precision'] += metrics['precision']
                per_label_metrics[label_name]['recall'] += metrics['recall']
                per_label_metrics[label_name]['f1-score'] += metrics['f1-score']
                per_label_metrics[label_name]['support'] += metrics['support']
                label_counts[label_name] += 1

        # Average the metrics
        for label, metrics in per_label_metrics.items():
            for metric_name in ['precision', 'recall', 'f1-score']:
                if metrics['support'] > 0:
                    per_label_metrics[label][metric_name] /= label_counts[label]

        val_stats = {'val_loss': avg_loss, 'val_acc': avg_acc, 'val_f1': avg_f1, 'val_rec': avg_rec,
                     'val_prec': avg_prec, 'val_jaccard': avg_jaccard, 'val_auc': avg_auc}
        wandb_run.log(val_stats)

    def test_step(self, batch, batch_idx):
        loss, acc, f1, rec, prec, jaccard, auc, _ = self.step(batch, batch_idx)
        return {'test_loss': loss, 'test_acc': acc, 'test_f1': f1, 'test_rec': rec, 'test_prec': prec,
                'test_jaccard': jaccard, 'test_auc': auc}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = np.mean([x['test_acc'] for x in outputs])
        avg_f1 = np.mean([x['test_f1'] for x in outputs])
        avg_rec = np.mean([x['test_rec'] for x in outputs])
        avg_prec = np.mean([x['test_prec'] for x in outputs])
        avg_jaccard = np.mean([x['test_jaccard'] for x in outputs])
        avg_auc = np.mean([x['test_auc'] for x in outputs])

        test_stats = {'test_loss': avg_loss, 'test_acc': avg_acc, 'test_f1': avg_f1, 'test_rec': avg_rec,
                      'test_prec': avg_prec, 'test_jaccard': avg_jaccard, 'test_auc': avg_auc}
        wandb_run.log(test_stats)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

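# Inference helper: loads a trained checkpoint, runs the classifier in fp16 on the GPU,
# and stores the list of predicted findings per dicom_id as JSON. Note that the
# checkpoint path is hardcoded to a specific training run.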
def save_preds(dataloader, split):
    # Load checkpoint
    ckpt_path = "findings_classifier/checkpoints/chexpert_train/ChexpertClassifier-epoch=06-val_f1=0.36.ckpt"
    model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=num_classes,
                                                 class_weights=val_dataset.get_class_weights(),
                                                 class_names=class_names, learning_rate=args.lr)
    model.eval()
    model.cuda()
    model.half()
    class_names_np = np.asarray(class_names)

    # Get predictions for every dicom_id
    structured_preds = {}
    for batch in tqdm(dataloader):
        dicom_ids = batch['dicom_id']
        logits = model(batch['image'].half().cuda())
        preds_probs = torch.sigmoid(logits)
        preds = preds_probs > 0.5

        # Iterate over each image in the batch
        for dicom_id, pred in zip(dicom_ids, preds.detach().cpu()):
            # Keep all predicted positive labels
            findings = class_names_np[pred].tolist()
            structured_preds[dicom_id] = findings

    # Save predictions
    with open(f"findings_classifier/predictions/structured_preds_chexpert_log_weighting_macro_{split}.json", "w") as f:
        json.dump(structured_preds, f, indent=4)

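# Entry point: with --train, fit the classifier and keep the checkpoint with the best
# val_f1; without --train, dump structured predictions for the train/val/test splits.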
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, default="debug")
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--epochs", type=int, default=6)
    parser.add_argument("--loss_weighting", type=str, default="log", choices=["lin", "log", "none"])
    parser.add_argument("--truncate", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=12)
    parser.add_argument("--use_augs", action="store_true", default=False)
    parser.add_argument("--train", action="store_true", default=False)
    args = parser.parse_args()

    TRAIN = args.train

    # fix all seeds
    pl.seed_everything(42, workers=True)

    # Create DataLoaders
    train_dataset = Chexpert_Dataset(split='train', truncate=args.truncate, loss_weighting=args.loss_weighting, use_augs=args.use_augs)
    val_dataset = Chexpert_Dataset(split='validate', truncate=args.truncate)
    test_dataset = Chexpert_Dataset(split='test')

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    # Number of classes for IGClassifier
    num_classes = len(train_dataset.chexpert_cols)
    class_names = train_dataset.chexpert_cols

    if TRAIN:
        class_weights = torch.tensor(train_dataset.get_class_weights(), dtype=torch.float32)
        # Define the model
        lit_model = LitIGClassifier(num_classes, class_names=class_names, class_weights=class_weights, learning_rate=args.lr)
        print(summary(lit_model))

        # WandB logger
        wandb_run = wandb.init(
            project="ChexpertClassifier",
            entity=WANDB_ENTITY,
            name=args.run_name
        )

        # checkpoint callback
        checkpoint_callback = ModelCheckpoint(
            monitor='val_f1',
            dirpath=f'findings_classifier/checkpoints/{args.run_name}',
            filename='ChexpertClassifier-{epoch:02d}-{val_f1:.2f}',
            save_top_k=1,
            save_last=True,
            mode='max',
        )
        # Train the model
        trainer = pl.Trainer(max_epochs=args.epochs, gpus=1, callbacks=[checkpoint_callback], benchmark=False, deterministic=True, precision=16)
        trainer.fit(lit_model, train_dataloader, val_dataloader)

        # Test the model
        # trainer.validate(lit_model, val_dataloader, ckpt_path="checkpoints_IGCLassifier/lr_5e-5_to0_log_weighting_patches_augs_imgemb/IGClassifier-epoch=09-val_f1=0.65.ckpt")
    else:
        save_preds(train_dataloader, "train")
        save_preds(val_dataloader, "val")
        save_preds(test_dataloader, "test")
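# Example invocations (a sketch, assuming the repository root is on PYTHONPATH and that
# local_config.WANDB_ENTITY and the CheXpert label files used by Chexpert_Dataset are set up):
#   python -m findings_classifier.chexpert_train --run_name chexpert_train --train --epochs 6 --lr 5e-5 --loss_weighting log
#   python -m findings_classifier.chexpert_train --run_name chexpert_train
# The second call skips training and writes
# findings_classifier/predictions/structured_preds_chexpert_log_weighting_macro_<split>.json
# for the train, val and test splits.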