--- /dev/null
+++ b/train.py
@@ -0,0 +1,441 @@
+import torch
+import torch.nn as nn
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from transformers import get_linear_schedule_with_warmup
+from tqdm import tqdm
+import logging
+import wandb
+from pathlib import Path
+from torch.cuda.amp import autocast, GradScaler
+
+from data2 import data_processing
+from alignment_model import ImageTextAlignmentModel
+from report_generator import MedicalReportGenerator
+from biovil_t.pretrained import get_biovil_t_image_encoder  # local module providing the BioViL-T image encoder
+from rouge_score import rouge_scorer
+
+def train_epoch(image_encoder, alignment_model, report_generator, train_loader,
+                contrastive_loss, alignment_optimizer, generator_optimizer,
+                alignment_scheduler, generator_scheduler, scaler, device,
+                gradient_accumulation_steps, max_grad_norm, epoch):
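+    """Run one training epoch.
+
+    The frozen image encoder produces embeddings, the alignment model is
+    optimized with a contrastive (cosine-embedding) loss, and the report
+    generator is optimized with a generation loss on the target findings;
+    both phases use AMP with gradient accumulation.
+    """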
+    alignment_model.train()
+    report_generator.train()
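+    # The encoder is used only under torch.no_grad() below, so it stays frozen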
+    image_encoder.eval()
+
+    # Metrics tracking
+    total_train_loss = 0.0
+    total_align_loss = 0.0
+    total_gen_loss = 0.0
+    total_samples = 0
+
+    progress_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')
+
+    for batch_idx, (images, findings_texts, findings_lists) in enumerate(progress_bar):
+        images = images.to(device)
+        batch_size = images.size(0)
+        total_samples += batch_size
+
+        # Get image embeddings
+        with torch.no_grad():
+            image_embeddings = image_encoder(images).img_embedding
+
+        # Create prompts using findings_lists (for generation)
+        batch_prompts = [
+            f"Findings: {', '.join(findings) if findings else 'No Findings'}."
+            for findings in findings_lists
+        ]
+
+        # Use findings_texts (actual findings) for alignment
+        actual_findings = findings_texts
+
+        # Mixed precision training
+        with autocast():
+            # Alignment phase
+            projected_image, projected_text = alignment_model(image_embeddings, actual_findings)
+
+            # CosineEmbeddingLoss with target +1 pulls each matched
+            # image/text pair together
+            labels = torch.ones(batch_size, device=device)
+            align_loss = contrastive_loss(projected_image, projected_text, labels)
+            align_loss = align_loss / gradient_accumulation_steps
+
+        # Scale and accumulate alignment gradients; this graph is disjoint
+        # from the generation loss below, which sees detached embeddings
+        scaler.scale(align_loss).backward()
+
+        # Generation phase
+
+        # Tokenize the prompts
+        prompt_encoding = report_generator.tokenizer(
+            batch_prompts,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+            max_length=512
+        ).to(device)
+
+        # Tokenize target texts (actual findings)
+        target_encoding = report_generator.tokenizer(
+            actual_findings,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+            max_length=512
+        ).to(device)
+
+        with autocast():
+            gen_loss, _ = report_generator(
+                image_embeddings=image_embeddings.detach(),  # keep generator gradients out of the alignment graph
+                prompt_input_ids=prompt_encoding['input_ids'],
+                target_ids=target_encoding['input_ids']
+            )
+            gen_loss = gen_loss / gradient_accumulation_steps
+
+        # Scale and accumulate generator gradients
+        scaler.scale(gen_loss).backward()
+
+        # Update metrics (undo the accumulation scaling, weight by batch size)
+        total_align_loss += align_loss.item() * gradient_accumulation_steps * batch_size
+        total_gen_loss += gen_loss.item() * gradient_accumulation_steps * batch_size
+        total_train_loss += (align_loss.item() + gen_loss.item()) * gradient_accumulation_steps * batch_size
+
+        # Step optimizers and schedulers
+        if (batch_idx + 1) % gradient_accumulation_steps == 0:
+            # Unscale gradients
+            scaler.unscale_(alignment_optimizer)
+            scaler.unscale_(generator_optimizer)
+
+            # Clip gradients
+            torch.nn.utils.clip_grad_norm_(
+                alignment_model.parameters(), max_grad_norm
+            )
+            torch.nn.utils.clip_grad_norm_(
+                report_generator.parameters(), max_grad_norm
+            )
+
+            # Step optimizers
+            scaler.step(alignment_optimizer)
+            scaler.step(generator_optimizer)
+            scaler.update()
+
+            # Zero gradients
+            alignment_optimizer.zero_grad()
+            generator_optimizer.zero_grad()
+
+            # Step schedulers
+            alignment_scheduler.step()
+            generator_scheduler.step()
+
+        # Update progress bar
+        progress_bar.set_postfix({
+            'align_loss': f"{align_loss.item():.4f}",
+            'gen_loss': f"{gen_loss.item():.4f}"
+        })
+
+    # Flush any leftover accumulated gradients when the number of batches
+    # is not divisible by gradient_accumulation_steps, so they do not leak
+    # into the next epoch
+    if len(train_loader) % gradient_accumulation_steps != 0:
+        scaler.unscale_(alignment_optimizer)
+        scaler.unscale_(generator_optimizer)
+        torch.nn.utils.clip_grad_norm_(alignment_model.parameters(), max_grad_norm)
+        torch.nn.utils.clip_grad_norm_(report_generator.parameters(), max_grad_norm)
+        scaler.step(alignment_optimizer)
+        scaler.step(generator_optimizer)
+        scaler.update()
+        alignment_optimizer.zero_grad()
+        generator_optimizer.zero_grad()
+
+    epoch_align_loss = total_align_loss / total_samples
+    epoch_gen_loss = total_gen_loss / total_samples
+    epoch_train_loss = total_train_loss / total_samples
+
+    return {
+        'train_loss': epoch_train_loss,
+        'train_align_loss': epoch_align_loss,
+        'train_gen_loss': epoch_gen_loss,
+    }
+
+def validate_epoch(image_encoder, alignment_model, report_generator, val_loader,
+                   contrastive_loss, device, epoch):
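+    """Evaluate one epoch: computes alignment and generation losses plus
+    ROUGE-L between generated reports and reference findings."""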
+    alignment_model.eval()
+    report_generator.eval()
+    image_encoder.eval()
+
+    # Metrics storage
+    total_val_loss = 0.0
+    total_align_loss = 0.0
+    total_gen_loss = 0.0
+    total_samples = 0
+    all_generated = []
+    all_references = []
+
+    with torch.no_grad():
+        progress_bar = tqdm(val_loader, desc=f'Validation Epoch {epoch}')
+
+        for batch_idx, (images, findings_texts, findings_lists) in enumerate(progress_bar):
+            images = images.to(device)
+            batch_size = images.size(0)
+            total_samples += batch_size
+
+            # Get image embeddings
+            image_embeddings = image_encoder(images).img_embedding
+
+            # Create prompts using findings_lists
+            batch_prompts = [
+                f"Findings: {', '.join(findings) if findings else 'No Findings'}."
+                for findings in findings_lists
+            ]
+
+            # Actual findings for alignment and reference
+            actual_findings = findings_texts
+
+            # Alignment phase
+            projected_image, projected_text = alignment_model(image_embeddings, actual_findings)
+            labels = torch.ones(batch_size, device=device)
+            align_loss = contrastive_loss(projected_image, projected_text, labels)
+
+            # Generation phase
+            prompt_encoding = report_generator.tokenizer(
+                batch_prompts,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+                max_length=512
+            ).to(device)
+
+            target_encoding = report_generator.tokenizer(
+                actual_findings,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+                max_length=512
+            ).to(device)
+
+            # Compute generation loss
+            gen_loss, _ = report_generator(
+                image_embeddings=image_embeddings,
+                prompt_input_ids=prompt_encoding['input_ids'],
+                target_ids=target_encoding['input_ids']
+            )
+
+            # Generate text for evaluation; with target_ids=None the
+            # generator returns decoded report strings instead of a loss
+            generated_texts = report_generator(
+                image_embeddings=image_embeddings,
+                prompt_input_ids=prompt_encoding['input_ids'],
+                target_ids=None
+            )
+
+            # Store the generated and reference texts for ROUGE calculation
+            all_generated.extend(generated_texts)
+            all_references.extend(actual_findings)
+
+            # Update totals
+            total_align_loss += align_loss.item() * batch_size
+            total_gen_loss += gen_loss.item() * batch_size
+            total_val_loss += (align_loss.item() + gen_loss.item()) * batch_size
+
+            # Print sample generation
+            if batch_idx % 10 == 0:
+                print(f"\nSample Generation (Batch {batch_idx}):")
+                print(f"Generated: {generated_texts[0]}")
+                print(f"Reference: {actual_findings[0]}")
+                # Also display the pathologies findings from findings_lists
+                print(f"Pathologies/Findings List: {findings_lists[0]}\n")
+
+        # Calculate overall metrics
+        epoch_align_loss = total_align_loss / total_samples
+        epoch_gen_loss = total_gen_loss / total_samples
+        epoch_val_loss = total_val_loss / total_samples
+
+    # Compute ROUGE-L
+    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+    rouge_l_scores = []
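+    # rouge_scorer's score() takes (target, prediction)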
+    for ref, gen in zip(all_references, all_generated):
+        score = scorer.score(ref, gen)['rougeL'].fmeasure
+        rouge_l_scores.append(score)
+    avg_rouge_l = sum(rouge_l_scores) / len(rouge_l_scores) if rouge_l_scores else 0.0
+
+    # Display validation losses and ROUGE-L
+    print(f"\nEpoch {epoch} Validation Metrics:")
+    print(f"Validation Loss: {epoch_val_loss:.4f}")
+    print(f"Alignment Loss: {epoch_align_loss:.4f}")
+    print(f"Generation Loss: {epoch_gen_loss:.4f}")
+    print(f"ROUGE-L: {avg_rouge_l:.4f}")
+
+    return {
+        'val_loss': epoch_val_loss,
+        'val_align_loss': epoch_align_loss,
+        'val_gen_loss': epoch_gen_loss,
+        'val_rouge_l': avg_rouge_l
+    }
+
+
+def train_model(
+        csv_with_image_paths: str,
+        csv_with_labels: str,
+        num_epochs: int = 30,
+        batch_size: int = 8,
+        train_split: float = 0.85,
+        num_workers: int = 4,
+        learning_rate: float = 2e-4,
+        warmup_steps: int = 1000,
+        gradient_accumulation_steps: int = 4,
+        max_grad_norm: float = 1.0,
+        use_wandb: bool = True,
+        checkpoint_dir: str = "checkpoints",
+        seed: int = 42
+):
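+    """End-to-end training: builds models, dataloaders, optimizers,
+    schedulers, and the loss, then alternates training and validation
+    epochs, checkpointing after each one."""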
+    # Set device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    # Initialize models
+    image_encoder = get_biovil_t_image_encoder()
+    alignment_model = ImageTextAlignmentModel(image_embedding_dim=512)
+    report_generator = MedicalReportGenerator(image_embedding_dim=512)
+
+    # Move models to device
+    image_encoder = image_encoder.to(device)
+    alignment_model = alignment_model.to(device)
+    report_generator = report_generator.to(device)
+
+    # Initialize wandb
+    if use_wandb:
+        wandb.init(
+            project="medical-report-generation",
+            config={
+                "learning_rate": learning_rate,
+                "epochs": num_epochs,
+                "batch_size": batch_size,
+                "warmup_steps": warmup_steps,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
+            }
+        )
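+        # log="all" records both gradient and parameter histograms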
+        wandb.watch(models=[alignment_model, report_generator], log="all")
+
+    # Get dataloaders
+    train_loader, val_loader = data_processing.get_dataloaders(
+        csv_with_image_paths=csv_with_image_paths,
+        csv_with_labels=csv_with_labels,
+        batch_size=batch_size,
+        train_split=train_split,
+        num_workers=num_workers,
+        seed=seed,
+    )
+
+    # Initialize optimizers
+    alignment_optimizer = AdamW(
+        alignment_model.parameters(),
+        lr=learning_rate,
+        weight_decay=0.01
+    )
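+    # The pretrained LM backbone keeps the base LR; the randomly initialized
+    # image projection gets a 10x higher LR so it adapts faster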
+    generator_optimizer = AdamW([
+        {'params': report_generator.model.parameters(), 'lr': learning_rate},
+        {'params': report_generator.image_projection.parameters(), 'lr': learning_rate * 10}
+    ])
+
+    # Initialize schedulers: one scheduler step per optimizer update, so
+    # the horizon is total batches divided by gradient_accumulation_steps
+    num_training_steps = len(train_loader) * num_epochs // gradient_accumulation_steps
+    alignment_scheduler = get_linear_schedule_with_warmup(
+        alignment_optimizer,
+        num_warmup_steps=warmup_steps,
+        num_training_steps=num_training_steps
+    )
+    generator_scheduler = get_linear_schedule_with_warmup(
+        generator_optimizer,
+        num_warmup_steps=warmup_steps,
+        num_training_steps=num_training_steps
+    )
+
+    # CosineEmbeddingLoss serves as the contrastive objective; GradScaler
+    # drives mixed-precision training
+    contrastive_loss = nn.CosineEmbeddingLoss()
+    scaler = GradScaler()
+
+    # Create checkpoint directory
+    checkpoint_dir = Path(checkpoint_dir)
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    for epoch in range(num_epochs):
+        print(f"\nEpoch {epoch + 1}/{num_epochs}")
+
+        # Training phase
+        train_metrics = train_epoch(
+            image_encoder=image_encoder,
+            alignment_model=alignment_model,
+            report_generator=report_generator,
+            train_loader=train_loader,
+            contrastive_loss=contrastive_loss,
+            alignment_optimizer=alignment_optimizer,
+            generator_optimizer=generator_optimizer,
+            alignment_scheduler=alignment_scheduler,
+            generator_scheduler=generator_scheduler,
+            scaler=scaler,
+            device=device,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            max_grad_norm=max_grad_norm,
+            epoch=epoch + 1
+        )
+
+        # Validation phase
+        val_metrics = validate_epoch(
+            image_encoder=image_encoder,
+            alignment_model=alignment_model,
+            report_generator=report_generator,
+            val_loader=val_loader,
+            contrastive_loss=contrastive_loss,
+            device=device,
+            epoch=epoch + 1
+        )
+
+        # Display training and validation losses
+        print(f"\nEpoch {epoch + 1} Training Loss: {train_metrics['train_loss']:.4f}")
+        print(f"Epoch {epoch + 1} Validation Loss: {val_metrics['val_loss']:.4f}")
+        print(f"Alignment Loss - Train: {train_metrics['train_align_loss']:.4f}, Val: {val_metrics['val_align_loss']:.4f}")
+        print(f"Generation Loss - Train: {train_metrics['train_gen_loss']:.4f}, Val: {val_metrics['val_gen_loss']:.4f}")
+        print(f"ROUGE-L (Val): {val_metrics['val_rouge_l']:.4f}")
+
+        # Log metrics to wandb
+        if use_wandb:
+            wandb.log({**train_metrics, **val_metrics})
+
+        # Save model checkpoint after each epoch
+        checkpoint_save_path = checkpoint_dir / f"model_epoch_{epoch+1}.pt"
+        torch.save({
+            'epoch': epoch + 1,
+            'image_encoder_state_dict': image_encoder.state_dict(),
+            'alignment_model_state_dict': alignment_model.state_dict(),
+            'report_generator_state_dict': report_generator.state_dict(),
+            'alignment_optimizer_state_dict': alignment_optimizer.state_dict(),
+            'generator_optimizer_state_dict': generator_optimizer.state_dict(),
+            'alignment_scheduler_state_dict': alignment_scheduler.state_dict(),
+            'generator_scheduler_state_dict': generator_scheduler.state_dict(),
+            'scaler_state_dict': scaler.state_dict(),
+            'config': {
+                'learning_rate': learning_rate,
+                'batch_size': batch_size,
+                'gradient_accumulation_steps': gradient_accumulation_steps,
+                'max_grad_norm': max_grad_norm,
+            }
+        }, checkpoint_save_path)
+        logging.info(f"Saved checkpoint: {checkpoint_save_path}")
+
+    if use_wandb:
+        wandb.finish()
+
+
+if __name__ == "__main__":
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s'
+    )
+
+    # Path to your CSV files
+    csv_with_image_paths = "/home/ubuntu/NLP/NLP_Project/Temp_3_NLP/Data/final.csv"
+    csv_with_labels = "/home/ubuntu/NLP/NLP_Project/Temp_3_NLP/Data/labeled_reports_with_images.csv"
+
+    # Training configuration
+    config = {
+        'num_epochs': 30,
+        'batch_size': 8,
+        'learning_rate': 1e-4,
+        'warmup_steps': 1000,
+        'gradient_accumulation_steps': 4,
+        'use_wandb': True,
+        'checkpoint_dir': 'checkpoints',
+        'seed': 42
+    }
+
+    # Start training
+    train_model(
+        csv_with_image_paths=csv_with_image_paths,
+        csv_with_labels=csv_with_labels,
+        **config
+    )