Diff of /train.py [000000] .. [27805f]

b/train.py
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import logging
import wandb
from pathlib import Path
from typing import Dict, Any
from torch.cuda.amp import autocast, GradScaler
from datetime import datetime

from data2 import data_processing
from alignment_model import ImageTextAlignmentModel
from report_generator import MedicalReportGenerator
from biovil_t.pretrained import get_biovil_t_image_encoder  # Ensure this import path is correct
from rouge_score import rouge_scorer

def train_epoch(image_encoder, alignment_model, report_generator, train_loader,
                contrastive_loss, alignment_optimizer, generator_optimizer,
                alignment_scheduler, generator_scheduler, scaler, device,
                gradient_accumulation_steps, max_grad_norm, epoch):
    """Train the alignment model and report generator for one epoch; returns average losses."""
    alignment_model.train()
    report_generator.train()
    image_encoder.eval()  # The image encoder stays frozen and only provides embeddings

    # Metrics tracking
    total_train_loss = 0.0
    total_align_loss = 0.0
    total_gen_loss = 0.0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')

    for batch_idx, (images, findings_texts, findings_lists) in enumerate(progress_bar):
        images = images.to(device)
        batch_size = images.size(0)
        total_samples += batch_size

        # Get image embeddings
        with torch.no_grad():
            image_embeddings = image_encoder(images).img_embedding

        # Create prompts using findings_lists (for generation)
        batch_prompts = [
            f"Findings: {', '.join(findings) if findings else 'No Findings'}."
            for findings in findings_lists
        ]

        # Use findings_texts (actual findings) for alignment
        actual_findings = findings_texts

        # Mixed precision training
        with autocast():
            # Alignment phase
            projected_image, projected_text = alignment_model(image_embeddings, actual_findings)

            # Contrastive loss: a target of +1 pulls matched image/text embeddings together
            labels = torch.ones(batch_size).to(device)
            align_loss = contrastive_loss(projected_image, projected_text, labels)
            align_loss = align_loss / gradient_accumulation_steps

        # Scale and accumulate alignment gradients
        scaler.scale(align_loss).backward()

        # Generation phase

        # Tokenize the prompts
        prompt_encoding = report_generator.tokenizer(
            batch_prompts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512
        ).to(device)

        # Tokenize target texts (actual findings)
        target_encoding = report_generator.tokenizer(
            actual_findings,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512
        ).to(device)

        with autocast():
            gen_loss, _ = report_generator(
                image_embeddings=image_embeddings.detach(),
                prompt_input_ids=prompt_encoding['input_ids'],
                target_ids=target_encoding['input_ids']
            )
            gen_loss = gen_loss / gradient_accumulation_steps

        # Scale and accumulate generator gradients
        scaler.scale(gen_loss).backward()

        # Update metrics (undo the accumulation scaling so the totals are per sample)
        total_align_loss += align_loss.item() * gradient_accumulation_steps * batch_size
        total_gen_loss += gen_loss.item() * gradient_accumulation_steps * batch_size
        total_train_loss += (align_loss.item() + gen_loss.item()) * gradient_accumulation_steps * batch_size

        # Step optimizers and schedulers
        if (batch_idx + 1) % gradient_accumulation_steps == 0:
            # Unscale gradients
            scaler.unscale_(alignment_optimizer)
            scaler.unscale_(generator_optimizer)

            # Clip gradients
            torch.nn.utils.clip_grad_norm_(
                alignment_model.parameters(), max_grad_norm
            )
            torch.nn.utils.clip_grad_norm_(
                report_generator.parameters(), max_grad_norm
            )

            # Step optimizers
            scaler.step(alignment_optimizer)
            scaler.step(generator_optimizer)
            scaler.update()

            # Zero gradients
            alignment_optimizer.zero_grad()
            generator_optimizer.zero_grad()

            # Step schedulers
            alignment_scheduler.step()
            generator_scheduler.step()

        # Update progress bar
        progress_bar.set_postfix({
            'align_loss': f"{align_loss.item():.4f}",
            'gen_loss': f"{gen_loss.item():.4f}"
        })

    epoch_align_loss = total_align_loss / total_samples
    epoch_gen_loss = total_gen_loss / total_samples
    epoch_train_loss = total_train_loss / total_samples

    return {
        'train_loss': epoch_train_loss,
        'train_align_loss': epoch_align_loss,
        'train_gen_loss': epoch_gen_loss,
    }

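# A quick sanity check of the gradient-accumulation arithmetic used in train_epoch
# (illustrative numbers only): each batch loss is divided by gradient_accumulation_steps,
# so the gradients summed over one accumulation window approximate a single update on a
# larger effective batch:
#
#     >>> batch_size, gradient_accumulation_steps = 8, 4
#     >>> batch_size * gradient_accumulation_steps  # samples contributing to one optimizer step
#     32
#
# Batches left over when len(train_loader) is not a multiple of gradient_accumulation_steps
# carry their gradients into the next accumulation window rather than triggering a step.
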
def validate_epoch(image_encoder, alignment_model, report_generator, val_loader,
                   contrastive_loss, device, epoch):
    """Evaluate on the validation set; returns average losses and ROUGE-L over generated reports."""
    alignment_model.eval()
    report_generator.eval()
    image_encoder.eval()

    # Metrics storage
    total_val_loss = 0.0
    total_align_loss = 0.0
    total_gen_loss = 0.0
    total_samples = 0
    all_generated = []
    all_references = []

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f'Validation Epoch {epoch}')

        for batch_idx, (images, findings_texts, findings_lists) in enumerate(progress_bar):
            images = images.to(device)
            batch_size = images.size(0)
            total_samples += batch_size

            # Get image embeddings
            image_embeddings = image_encoder(images).img_embedding

            # Create prompts using findings_lists
            batch_prompts = [
                f"Findings: {', '.join(findings) if findings else 'No Findings'}."
                for findings in findings_lists
            ]

            # Actual findings for alignment and reference
            actual_findings = findings_texts

            # Alignment phase
            projected_image, projected_text = alignment_model(image_embeddings, actual_findings)
            labels = torch.ones(batch_size).to(device)
            align_loss = contrastive_loss(projected_image, projected_text, labels)

            # Generation phase
            prompt_encoding = report_generator.tokenizer(
                batch_prompts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(device)

            target_encoding = report_generator.tokenizer(
                actual_findings,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(device)

            # Compute generation loss
            gen_loss, _ = report_generator(
                image_embeddings=image_embeddings,
                prompt_input_ids=prompt_encoding['input_ids'],
                target_ids=target_encoding['input_ids']
            )

            # Generate text for evaluation (no target_ids, so the generator decodes instead of returning a loss)
            generated_texts = report_generator(
                image_embeddings=image_embeddings,
                prompt_input_ids=prompt_encoding['input_ids'],
                target_ids=None
            )

            # Store the generated and reference texts for ROUGE calculation
            all_generated.extend(generated_texts)
            all_references.extend(actual_findings)

            # Update totals
            total_align_loss += align_loss.item() * batch_size
            total_gen_loss += gen_loss.item() * batch_size
            total_val_loss += (align_loss.item() + gen_loss.item()) * batch_size

            # Print sample generation
            if batch_idx % 10 == 0:
                print(f"\nSample Generation (Batch {batch_idx}):")
                print(f"Generated: {generated_texts[0]}")
                print(f"Reference: {actual_findings[0]}")
                # Also display the pathologies findings from findings_lists
                print(f"Pathologies/Findings List: {findings_lists[0]}\n")

        # Calculate overall metrics
        epoch_align_loss = total_align_loss / total_samples
        epoch_gen_loss = total_gen_loss / total_samples
        epoch_val_loss = total_val_loss / total_samples

    # Compute ROUGE-L
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    rouge_l_scores = []
    for ref, gen in zip(all_references, all_generated):
        score = scorer.score(ref, gen)['rougeL'].fmeasure
        rouge_l_scores.append(score)
    avg_rouge_l = sum(rouge_l_scores) / len(rouge_l_scores) if rouge_l_scores else 0.0

    # Display validation losses and ROUGE-L
    print(f"\nEpoch {epoch} Validation Metrics:")
    print(f"Validation Loss: {epoch_val_loss:.4f}")
    print(f"Alignment Loss: {epoch_align_loss:.4f}")
    print(f"Generation Loss: {epoch_gen_loss:.4f}")
    print(f"ROUGE-L: {avg_rouge_l:.4f}")

    return {
        'val_loss': epoch_val_loss,
        'val_align_loss': epoch_align_loss,
        'val_gen_loss': epoch_gen_loss,
        'val_rouge_l': avg_rouge_l
    }


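# Argument-order check for the ROUGE-L scoring inside validate_epoch: rouge_scorer's
# score() takes (target, prediction), which matches scorer.score(ref, gen) above. A toy
# example of the same call (the exact value depends on rouge_score's tokenization):
#
#     >>> from rouge_score import rouge_scorer
#     >>> s = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
#     >>> s.score("no acute cardiopulmonary process", "no acute process")['rougeL'].fmeasure
#     0.857...  # LCS of 3 tokens: recall 3/4, precision 3/3
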
def train_model(
        csv_with_image_paths: str,
        csv_with_labels: str,
        num_epochs: int = 30,
        batch_size: int = 8,
        train_split: float = 0.85,
        num_workers: int = 4,
        learning_rate: float = 2e-4,
        warmup_steps: int = 1000,
        gradient_accumulation_steps: int = 4,
        max_grad_norm: float = 1.0,
        use_wandb: bool = True,
        checkpoint_dir: str = "checkpoints",
        seed: int = 42
):
    """Build the models, dataloaders, optimizers, and schedulers, then run the train/validation loop."""
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize models
    image_encoder = get_biovil_t_image_encoder()
    alignment_model = ImageTextAlignmentModel(image_embedding_dim=512)
    report_generator = MedicalReportGenerator(image_embedding_dim=512)

    # Move models to device
    image_encoder = image_encoder.to(device)
    alignment_model = alignment_model.to(device)
    report_generator = report_generator.to(device)

    # Initialize wandb
    if use_wandb:
        wandb.init(
            project="medical-report-generation",
            config={
                "learning_rate": learning_rate,
                "epochs": num_epochs,
                "batch_size": batch_size,
                "warmup_steps": warmup_steps,
                "gradient_accumulation_steps": gradient_accumulation_steps,
            }
        )
        wandb.watch(models=[alignment_model, report_generator], log="all")

    # Get dataloaders; each batch yields (images, findings_texts, findings_lists)
    train_loader, val_loader = data_processing.get_dataloaders(
        csv_with_image_paths=csv_with_image_paths,
        csv_with_labels=csv_with_labels,
        batch_size=batch_size,
        train_split=train_split,
        num_workers=num_workers,
        seed=seed,
    )

    # Initialize optimizers
    alignment_optimizer = AdamW(
        alignment_model.parameters(),
        lr=learning_rate,
        weight_decay=0.01
    )
    generator_optimizer = AdamW([
        {'params': report_generator.model.parameters(), 'lr': learning_rate},
        {'params': report_generator.image_projection.parameters(), 'lr': learning_rate * 10}
    ])

    # Initialize schedulers (num_training_steps counts optimizer steps, not batches)
    num_training_steps = len(train_loader) * num_epochs // gradient_accumulation_steps
    alignment_scheduler = get_linear_schedule_with_warmup(
        alignment_optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps
    )
    generator_scheduler = get_linear_schedule_with_warmup(
        generator_optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps
    )

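    # Rough scale check for the linear schedules above (example numbers, not this run's
    # actual loader size): with 1,000 training batches per epoch, num_epochs=30 and
    # gradient_accumulation_steps=4 give 1000 * 30 // 4 = 7,500 optimizer steps, so
    # warmup_steps=1000 covers roughly the first 13% of training.
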
    # Initialize loss function and scaler
    # CosineEmbeddingLoss with a target of +1 is 1 - cosine_similarity for each image/text pair
    contrastive_loss = nn.CosineEmbeddingLoss()
    scaler = GradScaler()

    # Create checkpoint directory
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Training phase
        train_metrics = train_epoch(
            image_encoder=image_encoder,
            alignment_model=alignment_model,
            report_generator=report_generator,
            train_loader=train_loader,
            contrastive_loss=contrastive_loss,
            alignment_optimizer=alignment_optimizer,
            generator_optimizer=generator_optimizer,
            alignment_scheduler=alignment_scheduler,
            generator_scheduler=generator_scheduler,
            scaler=scaler,
            device=device,
            gradient_accumulation_steps=gradient_accumulation_steps,
            max_grad_norm=max_grad_norm,
            epoch=epoch + 1
        )

        # Validation phase
        val_metrics = validate_epoch(
            image_encoder=image_encoder,
            alignment_model=alignment_model,
            report_generator=report_generator,
            val_loader=val_loader,
            contrastive_loss=contrastive_loss,
            device=device,
            epoch=epoch + 1
        )

        # Display training and validation losses
        print(f"\nEpoch {epoch + 1} Training Loss: {train_metrics['train_loss']:.4f}")
        print(f"Epoch {epoch + 1} Validation Loss: {val_metrics['val_loss']:.4f}")
        print(f"Alignment Loss - Train: {train_metrics['train_align_loss']:.4f}, Val: {val_metrics['val_align_loss']:.4f}")
        print(f"Generation Loss - Train: {train_metrics['train_gen_loss']:.4f}, Val: {val_metrics['val_gen_loss']:.4f}")
        print(f"ROUGE-L (Val): {val_metrics['val_rouge_l']:.4f}")

        # Log metrics to wandb
        if use_wandb:
            wandb.log({**train_metrics, **val_metrics})

        # Save model checkpoint after each epoch
        checkpoint_save_path = checkpoint_dir / f"model_epoch_{epoch+1}.pt"
        torch.save({
            'epoch': epoch + 1,
            'image_encoder_state_dict': image_encoder.state_dict(),
            'alignment_model_state_dict': alignment_model.state_dict(),
            'report_generator_state_dict': report_generator.state_dict(),
            'alignment_optimizer_state_dict': alignment_optimizer.state_dict(),
            'generator_optimizer_state_dict': generator_optimizer.state_dict(),
            'alignment_scheduler_state_dict': alignment_scheduler.state_dict(),
            'generator_scheduler_state_dict': generator_scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'config': {
                'learning_rate': learning_rate,
                'batch_size': batch_size,
                'gradient_accumulation_steps': gradient_accumulation_steps,
                'max_grad_norm': max_grad_norm,
            }
        }, checkpoint_save_path)
        logging.info(f"Saved checkpoint: {checkpoint_save_path}")

    if use_wandb:
        wandb.finish()


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Path to your CSV files
    csv_with_image_paths = "/home/ubuntu/NLP/NLP_Project/Temp_3_NLP/Data/final.csv"
    csv_with_labels = "/home/ubuntu/NLP/NLP_Project/Temp_3_NLP/Data/labeled_reports_with_images.csv"

    # Training configuration
    config = {
        'num_epochs': 30,
        'batch_size': 8,
        'learning_rate': 1e-4,
        'warmup_steps': 1000,
        'gradient_accumulation_steps': 4,
        'use_wandb': True,
        'checkpoint_dir': 'checkpoints',
        'seed': 42
    }

    # Start training
    train_model(
        csv_with_image_paths=csv_with_image_paths,
        csv_with_labels=csv_with_labels,
        **config
    )
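
# To resume or evaluate from one of the checkpoints saved above, a minimal reload sketch
# based on the keys written by torch.save (assuming the same constructors; optimizer,
# scheduler, and scaler state would be restored the same way) could look like:
#
#     checkpoint = torch.load("checkpoints/model_epoch_1.pt", map_location="cpu")
#     alignment_model = ImageTextAlignmentModel(image_embedding_dim=512)
#     alignment_model.load_state_dict(checkpoint['alignment_model_state_dict'])
#     report_generator = MedicalReportGenerator(image_embedding_dim=512)
#     report_generator.load_state_dict(checkpoint['report_generator_state_dict'])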