Diff of /streamlit.py [000000] .. [27805f]

Switch to unified view

a b/streamlit.py
1
# app.py
2
3
import streamlit as st
4
import torch
5
from transformers import Blip2Processor, Blip2ForConditionalGeneration
6
from PIL import Image
7
from pathlib import Path
8
import logging
9
import sys
10
import os
11
import numpy as np
12
import torchvision.transforms as transforms
13
from typing import Union, List, Dict
14
15
# Import the MedicalReportGenerator from the appropriate modules with aliases
16
from report_generator_bioclip import MedicalReportGenerator as BioClipMedicalReportGenerator
17
from report_generator_concat import MedicalReportGenerator as BioViltMedicalReportGenerator
18
19
# Import the ModifiedCheXNet model class
20
from chexnet_train import ModifiedCheXNet
21
22
# Import BioVilt specific modules
23
from alignment_concat import ImageTextAlignmentModel
24
from biovil_t.pretrained import get_biovil_t_image_encoder  # Ensure this import path is correct
25
26
# Additional imports for BioVilt pipeline
27
import cv2
28
import albumentations as A
29
from albumentations.pytorch import ToTensorV2
30
import re
31
32
# Suppress excessive warnings from transformers and torchvision.
# NOTE(review): the filter below has no category or module argument, so it
# silences every warning raised after this point, not only warnings from the
# two libraries named above. It runs before the torchvision import so that
# import-time warnings are covered as well.
import warnings
warnings.filterwarnings("ignore")

# To disable torchvision beta transforms warnings. The hasattr guard keeps
# torchvision builds that do not expose this hook working.
import torchvision
if hasattr(torchvision, 'disable_beta_transforms_warning'):
    torchvision.disable_beta_transforms_warning()
40
41
# Import torchxrayvision
42
import torchxrayvision as xrv
43
44
# ---------------------- Grayscale Classification ---------------------- #
45
46
def is_grayscale(image: Image.Image, threshold: float = 90.0) -> bool:
    """
    Report whether an image is predominantly grayscale.

    The image is converted to RGB and a pixel counts as grayscale when its
    red, green and blue values are all equal. The comparison runs on a NumPy
    array instead of a per-pixel Python loop, which matters for
    multi-megapixel radiographs.

    Args:
        image (PIL.Image.Image): Image to inspect.
        threshold (float): Percentage (0-100) of grayscale pixels that must be
            strictly exceeded for the image to count as grayscale.

    Returns:
        bool: True when the share of grayscale pixels is greater than
        ``threshold``. False otherwise, including for an image with no pixels
        or when any error occurs while inspecting it (the error is logged).
    """
    try:
        # Ensure image is in RGB so every pixel has exactly three channels
        rgb = np.asarray(image.convert("RGB"))
        red, green, blue = rgb[..., 0], rgb[..., 1], rgb[..., 2]
        gray_mask = (red == green) & (green == blue)

        total_pixels = gray_mask.size
        if total_pixels == 0:
            # Nothing to measure; avoid dividing by zero below.
            logging.error("Error in is_grayscale: image has no pixels")
            return False

        # Plain Python ints keep the percentage arithmetic exact and simple.
        grayscale_pixels = int(np.count_nonzero(gray_mask))
        grayscale_percentage = (grayscale_pixels / total_pixels) * 100
        return bool(grayscale_percentage > threshold)
    except Exception as e:
        # Best-effort check: any failure is treated as "not grayscale".
        logging.error(f"Error in is_grayscale: {e}")
        return False
63
64
# ---------------------- Inference Pipelines ---------------------- #
65
66
class ChestXrayFullInference:
    """
    Chest X-ray report pipeline built from CheXNet, BLIP-2 and the BioClip
    MedicalReportGenerator.

    CheXNet predicts one binary label per condition, the BLIP-2 vision model
    supplies pooled image features, and the report generator receives both to
    produce the report text.
    """

    def __init__(
        self,
        chexnet_model_path: str,
        blip2_model_name: str = "Salesforce/blip2-opt-2.7b",
        blip2_device_map: str = 'auto',
        chexnet_num_classes: int = 14,
        report_generator_checkpoint: str = None,
        device: str = None
    ):
        """
        Initialize the full inference pipeline with CheXNet, BLIP-2, and BioClip MedicalReportGenerator.

        Args:
            chexnet_model_path (str): Path to the trained CheXNet model checkpoint.
            blip2_model_name (str): Hugging Face model name for BLIP-2.
            blip2_device_map (str): Device mapping for BLIP-2 ('auto' by default).
            chexnet_num_classes (int): Number of classes for CheXNet.
            report_generator_checkpoint (str): Path to the BioClip MedicalReportGenerator
                checkpoint. Required in practice; loading fails without it.
            device (str): Device to use ('cuda' or 'cpu'). When omitted, CUDA is
                used if available, otherwise the CPU.

        Raises:
            Exception: Any error raised while loading one of the three models
                is logged and re-raised unchanged.
        """
        self.logger = self._setup_logger()
        self.device = torch.device(device) if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Initialize CheXNet Predictor
        self.chexnet_predictor = self._initialize_chexnet(
            chexnet_model_path, chexnet_num_classes
        )

        # Initialize BLIP-2 Processor and Model
        self.processor, self.blip_model = self._initialize_blip2(
            blip2_model_name, blip2_device_map
        )

        # Initialize BioClip MedicalReportGenerator (needs the BLIP-2 model to
        # read the vision hidden size, so it must come after BLIP-2).
        self.report_generator = self._initialize_report_generator(
            report_generator_checkpoint
        )

        # Label names, paired positionally with the CheXNet outputs.
        self.label_columns = [
            'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
            'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
            'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
            'Pleural Other', 'Fracture', 'Support Devices', 'No Finding'
        ]

    def _setup_logger(self) -> logging.Logger:
        """Return the class logger, attaching a stdout handler on first use only."""
        logger = logging.getLogger('ChestXrayFullInference')
        logger.setLevel(logging.INFO)

        # getLogger returns the same object for the same name, so guard
        # against stacking a new handler for every instance.
        if not logger.handlers:
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            ))
            logger.addHandler(handler)

        return logger

    def _initialize_chexnet(self, model_path: str, num_classes: int) -> ModifiedCheXNet:
        """
        Build ModifiedCheXNet, load its weights and switch it to eval mode.

        Args:
            model_path (str): Path to the CheXNet checkpoint.
            num_classes (int): Number of output classes of the model.

        Returns:
            ModifiedCheXNet: The loaded model on ``self.device``.
        """
        try:
            self.logger.info("Initializing CheXNet model...")
            chexnet = ModifiedCheXNet(num_classes=num_classes).to(self.device)
            # NOTE: torch.load unpickles the file; only load checkpoints from
            # a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device)

            # Handle different checkpoint formats: either a dict wrapping the
            # weights under 'model_state_dict' or the bare state dict.
            if 'model_state_dict' in checkpoint:
                chexnet.load_state_dict(checkpoint['model_state_dict'])
            else:
                chexnet.load_state_dict(checkpoint)

            chexnet.eval()
            self.logger.info("CheXNet model loaded successfully.")
            return chexnet

        except Exception as e:
            self.logger.error(f"Error initializing CheXNet model: {str(e)}")
            raise

    def _initialize_blip2(
        self, model_name: str, device_map: str
    ) -> (Blip2Processor, Blip2ForConditionalGeneration):
        """
        Load the BLIP-2 processor and model and switch the model to eval mode.

        Args:
            model_name (str): Hugging Face model name for BLIP-2.
            device_map (str): Device mapping passed to ``from_pretrained``.

        Returns:
            The ``(processor, model)`` pair.
        """
        try:
            self.logger.info("Initializing BLIP-2 model and processor...")
            # NOTE(review): force_download=True re-fetches the processor files
            # on every initialization and needs network access; kept as in the
            # original, confirm whether the cached copy can be used instead.
            processor = Blip2Processor.from_pretrained(model_name, force_download=True)
            blip_model = Blip2ForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                device_map=device_map
            )
            blip_model.eval()
            self.logger.info("BLIP-2 model and processor loaded successfully.")
            return processor, blip_model

        except Exception as e:
            self.logger.error(f"Error initializing BLIP-2 model: {str(e)}")
            raise

    def _initialize_report_generator(self, checkpoint_path: str) -> BioClipMedicalReportGenerator:
        """
        Build the BioClip MedicalReportGenerator and load its trained weights.

        The generator's input embedding size is taken from the hidden size of
        the already-loaded BLIP-2 vision model.

        Args:
            checkpoint_path (str): Path to the report generator checkpoint.

        Returns:
            BioClipMedicalReportGenerator: The loaded generator in eval mode.

        Raises:
            ValueError: If ``checkpoint_path`` is None.
        """
        try:
            self.logger.info("Initializing BioClip MedicalReportGenerator...")
            if checkpoint_path is None:
                raise ValueError(
                    "report_generator_checkpoint is required to load the "
                    "BioClip MedicalReportGenerator."
                )

            vision_hidden_size = self.blip_model.vision_model.config.hidden_size
            report_gen = BioClipMedicalReportGenerator(input_embedding_dim=vision_hidden_size)

            # Load trained weights.
            # NOTE: torch.load unpickles the file; only load checkpoints from
            # a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            # Accept both layouts, the same way the CheXNet loader does.
            if 'model_state_dict' in checkpoint:
                report_gen.load_state_dict(checkpoint['model_state_dict'])
            else:
                report_gen.load_state_dict(checkpoint)
            report_gen.to(self.device)
            report_gen.eval()
            self.logger.info("BioClip MedicalReportGenerator loaded successfully.")
            return report_gen

        except Exception as e:
            self.logger.error(f"Error initializing BioClip MedicalReportGenerator: {str(e)}")
            raise

    def _get_transform(self) -> transforms.Compose:
        """Return the CheXNet preprocessing: resize, center crop, tensor, normalize."""
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])

    def _convert_labels_to_findings(self, binary_labels: List[int]) -> str:
        """
        Convert binary labels to a comma-separated string of findings.

        Labels are matched to ``self.label_columns`` by position; only entries
        equal to 1 are kept. Returns "No Findings" when none is set.
        """
        findings = [label for label, val in zip(self.label_columns, binary_labels) if val == 1]
        return ", ".join(findings) if findings else "No Findings"

    def predict_labels(self, image: Image.Image, threshold: float = 0.5) -> List[int]:
        """
        Predict binary labels for the given image using CheXNet.

        Args:
            image (PIL.Image.Image): Input image.
            threshold (float): Probability threshold for positive prediction.

        Returns:
            List[int]: Binary labels (0 or 1) for each condition; a label is 1
            when its sigmoid probability is at least ``threshold``.
        """
        try:
            self.logger.info("Predicting labels using CheXNet...")
            transform = self._get_transform()
            # unsqueeze(0) adds the batch dimension the model is called with.
            image_tensor = transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.chexnet_predictor(image_tensor)
                probabilities = torch.sigmoid(output).cpu().numpy()[0]

            binary_labels = [1 if prob >= threshold else 0 for prob in probabilities]
            self.logger.info(f"Predicted binary labels: {binary_labels}")
            return binary_labels

        except Exception as e:
            self.logger.error(f"Error predicting labels: {str(e)}")
            raise

    def extract_image_features(self, image: Image.Image) -> torch.Tensor:
        """
        Extract image features using BLIP-2.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            torch.Tensor: The ``pooler_output`` of the BLIP-2 vision model.
        """
        try:
            self.logger.info("Extracting image features using BLIP-2...")
            processed = self.processor(images=image, return_tensors="pt")
            # NOTE(review): BLIP-2 is placed by device_map, which may differ
            # from self.device - confirm the vision model lives on this device.
            pixel_values = processed.pixel_values.to(self.device)

            with torch.no_grad():
                vision_outputs = self.blip_model.vision_model(pixel_values)
                image_features = vision_outputs.pooler_output

            self.logger.info(f"Extracted image features with shape: {image_features.shape}")
            return image_features

        except Exception as e:
            self.logger.error(f"Error extracting image features: {str(e)}")
            raise

    def generate_report(self, image: Union[str, Path, Image.Image], threshold: float = 0.5) -> Dict:
        """
        Generate a medical report for the given chest X-ray image.

        Args:
            image (str, Path, or PIL.Image.Image): Input image or path to the image.
            threshold (float): Probability threshold for positive prediction.

        Returns:
            Dict: ``{'report': str, 'labels': {label name: 0 or 1}}``.

        Raises:
            FileNotFoundError: If ``image`` is a path that does not exist.
            TypeError: If ``image`` is neither a path nor a PIL image, or the
                report generator returns something other than a string or a
                list/tuple.
            ValueError: If the report generator returns an empty list.
        """
        try:
            if isinstance(image, (str, Path)):
                self.logger.info(f"Generating report for image path: {image}")
                image_path = Path(image)
                if not image_path.exists():
                    raise FileNotFoundError(f"Image file {image_path} does not exist.")
                # Load image
                image = Image.open(image_path).convert('RGB')
            elif isinstance(image, Image.Image):
                self.logger.info("Generating report for uploaded image.")
                # The CheXNet transform normalizes three channels, so bring
                # grayscale/RGBA/palette images to RGB like the path branch.
                if image.mode != 'RGB':
                    image = image.convert('RGB')
            else:
                raise TypeError("Image must be a string path or a PIL.Image.Image object.")

            # Predict labels
            binary_labels = self.predict_labels(image, threshold=threshold)

            # Extract image features
            image_features = self.extract_image_features(image)

            # Start report generation. The generator is conditioned on the
            # label tensor directly; no text prompt is passed, so none is
            # built or tokenized here.
            self.logger.info("Starting report generation...")
            generated_report = self.report_generator.generate_report(
                input_embeddings=image_features,
                labels=torch.tensor(binary_labels, dtype=torch.float32).unsqueeze(0).to(self.device)
            )
            self.logger.info("Report generation completed.")

            # The generator may return one string or a batch of strings.
            if isinstance(generated_report, (list, tuple)):
                if len(generated_report) == 0:
                    raise ValueError("MedicalReportGenerator returned an empty report list.")
                generated_report_text = generated_report[0]
            elif isinstance(generated_report, str):
                generated_report_text = generated_report
            else:
                raise TypeError("MedicalReportGenerator.generate_report returned an unsupported type.")

            # Create labels dictionary
            labels_dict = {
                label: int(val) for label, val in zip(self.label_columns, binary_labels)
            }

            self.logger.info("Report generation successful.")
            return {
                'report': generated_report_text,
                'labels': labels_dict
            }

        except Exception as e:
            self.logger.error(f"Error generating report: {str(e)}")
            raise
334
335
336
class ChestXrayBioViltInference:
    """
    Chest X-ray report pipeline built from CheXNet and BioVilt + BioGPT.

    CheXNet predicts one binary label per condition; the positive labels are
    turned into a text prompt, and the report generator receives that prompt
    together with the BioVilt image embedding.
    """

    # Compiled once: matches 'madeupword' followed by digits, any letter case.
    _MADEUPWORD_RE = re.compile(r'madeupword\d+', flags=re.IGNORECASE)

    def __init__(
        self,
        chexnet_model_path: str,
        biovilt_checkpoint_path: str,
        device: str = None,
        chexnet_num_classes: int = 14
    ):
        """
        Initialize the inference pipeline with CheXNet and BioVilt + BioGPT.

        Args:
            chexnet_model_path (str): Path to the trained CheXNet model checkpoint.
            biovilt_checkpoint_path (str): Path to the BioVilt + BioGPT model checkpoint.
            device (str): Device to use ('cuda' or 'cpu'). When omitted, CUDA is
                used if available, otherwise the CPU.
            chexnet_num_classes (int): Number of classes for CheXNet (14 by default).

        Raises:
            Exception: Any error raised while loading the models is logged and
                re-raised unchanged.
        """
        self.logger = self._setup_logger()
        self.device = torch.device(device) if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f"Using device: {self.device}")

        # Initialize CheXNet Predictor
        self.chexnet_predictor = self._initialize_chexnet(
            chexnet_model_path, num_classes=chexnet_num_classes
        )

        # Initialize BioVilt components
        self.image_encoder, self.alignment_model, self.report_generator = self._initialize_biovilt(
            biovilt_checkpoint_path
        )

        # Label names, paired positionally with the CheXNet outputs.
        self.label_columns = [
            'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
            'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
            'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
            'Pleural Other', 'Fracture', 'Support Devices', 'No Finding'
        ]

    def _setup_logger(self) -> logging.Logger:
        """Return the class logger, attaching a stdout handler on first use only."""
        logger = logging.getLogger('ChestXrayBioViltInference')
        logger.setLevel(logging.INFO)

        # getLogger returns the same object for the same name, so guard
        # against stacking a new handler for every instance.
        if not logger.handlers:
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            ))
            logger.addHandler(handler)

        return logger

    def _initialize_chexnet(self, model_path: str, num_classes: int) -> ModifiedCheXNet:
        """
        Build ModifiedCheXNet, load its weights and switch it to eval mode.

        Args:
            model_path (str): Path to the CheXNet checkpoint.
            num_classes (int): Number of output classes of the model.

        Returns:
            ModifiedCheXNet: The loaded model on ``self.device``.
        """
        try:
            self.logger.info("Initializing CheXNet model for BioVilt pipeline...")
            chexnet = ModifiedCheXNet(num_classes=num_classes).to(self.device)
            # NOTE: torch.load unpickles the file; only load checkpoints from
            # a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device)

            # Handle different checkpoint formats: either a dict wrapping the
            # weights under 'model_state_dict' or the bare state dict.
            if 'model_state_dict' in checkpoint:
                chexnet.load_state_dict(checkpoint['model_state_dict'])
            else:
                chexnet.load_state_dict(checkpoint)

            chexnet.eval()
            self.logger.info("CheXNet model loaded successfully for BioVilt pipeline.")
            return chexnet

        except Exception as e:
            self.logger.error(f"Error initializing CheXNet model for BioVilt pipeline: {str(e)}")
            raise

    def _initialize_biovilt(self, checkpoint_path: str):
        """
        Initialize BioVilt Image Encoder, Alignment Model, and Report Generator.

        Delegates to the module-level ``load_biovilt_checkpoint``.

        Args:
            checkpoint_path (str): Path to the BioVilt checkpoint.

        Returns:
            The ``(image_encoder, alignment_model, report_generator)`` triple.
        """
        try:
            self.logger.info("Initializing BioVilt Image Encoder, Alignment Model, and Report Generator...")
            image_encoder, alignment_model, report_generator = load_biovilt_checkpoint(
                checkpoint_path, self.device
            )
            self.logger.info("BioVilt components loaded successfully.")
            return image_encoder, alignment_model, report_generator

        except Exception as e:
            self.logger.error(f"Error initializing BioVilt components: {str(e)}")
            raise

    def _get_transform(self) -> A.Compose:
        """
        Return the image preprocessing: resize to 224x224, normalize, to tensor.

        The same pipeline feeds both CheXNet and the BioVilt image encoder.
        """
        return A.Compose([
            A.Resize(224, 224),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2()
        ])

    def _convert_labels_to_findings(self, binary_labels: List[int]) -> str:
        """
        Convert binary labels to a comma-separated string of findings.

        Labels are matched to ``self.label_columns`` by position; only entries
        equal to 1 are kept. Returns "No Findings" when none is set.
        """
        findings = [label for label, val in zip(self.label_columns, binary_labels) if val == 1]
        return ", ".join(findings) if findings else "No Findings"

    def predict_labels(self, image: Image.Image, threshold: float = 0.5) -> List[int]:
        """
        Predict binary labels for the given image using CheXNet.

        Args:
            image (PIL.Image.Image): Input image.
            threshold (float): Probability threshold for positive prediction.

        Returns:
            List[int]: Binary labels (0 or 1) for each condition; a label is 1
            when its sigmoid probability is at least ``threshold``.
        """
        try:
            self.logger.info("Predicting labels using CheXNet for BioVilt pipeline...")
            transform = self._get_transform()
            image_np = np.array(image)
            transformed = transform(image=image_np)
            # unsqueeze(0) adds the batch dimension the model is called with.
            image_tensor = transformed['image'].unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.chexnet_predictor(image_tensor)
                probabilities = torch.sigmoid(output).cpu().numpy()[0]

            binary_labels = [1 if prob >= threshold else 0 for prob in probabilities]
            self.logger.info(f"Predicted binary labels for BioVilt pipeline: {binary_labels}")
            return binary_labels

        except Exception as e:
            self.logger.error(f"Error predicting labels for BioVilt pipeline: {str(e)}")
            raise

    def generate_report(self, image: Union[str, Path, Image.Image], threshold: float = 0.5) -> Dict:
        """
        Generate a medical report for the given chest X-ray image using BioVilt + BioGPT.

        Args:
            image (str, Path, or PIL.Image.Image): Input image or path to the image.
            threshold (float): Probability threshold for positive prediction.

        Returns:
            Dict: ``{'report': str, 'labels': {label name: 0 or 1}}`` where the
            report has been passed through ``clean_report``.

        Raises:
            FileNotFoundError: If ``image`` is a path that does not exist.
            TypeError: If ``image`` is neither a path nor a PIL image, or the
                report generator returns something other than a string or a
                list/tuple.
            ValueError: If the report generator returns an empty list.
            AttributeError: If the image encoder output has no ``img_embedding``.
        """
        try:
            if isinstance(image, (str, Path)):
                self.logger.info(f"Generating BioVilt report for image path: {image}")
                image_path = Path(image)
                if not image_path.exists():
                    raise FileNotFoundError(f"Image file {image_path} does not exist.")
                # Load image
                image = Image.open(image_path).convert('RGB')
            elif isinstance(image, Image.Image):
                self.logger.info("Generating BioVilt report for uploaded image.")
                # The transform normalizes three channels, so bring
                # grayscale/RGBA/palette images to RGB like the path branch.
                if image.mode != 'RGB':
                    image = image.convert('RGB')
            else:
                raise TypeError("Image must be a string path or a PIL.Image.Image object.")

            # Predict labels
            binary_labels = self.predict_labels(image, threshold=threshold)

            # Convert binary labels to findings string
            findings = self._convert_labels_to_findings(binary_labels)
            prompt = f"Findings: {findings}."

            # Tokenize prompt
            self.logger.info("Tokenizing prompt...")
            prompt_encoding = self.report_generator.tokenizer(
                [prompt],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)

            # Extract image embeddings using BioVilt Image Encoder, with the
            # same preprocessing that predict_labels applies.
            self.logger.info("Extracting image embeddings using BioVilt Image Encoder...")
            transformed = self._get_transform()(image=np.array(image))
            image_tensor = transformed['image'].unsqueeze(0).to(self.device)

            with torch.no_grad():
                image_encoder_output = self.image_encoder(image_tensor)
                # Extract the tensor from ImageModelOutput
                if hasattr(image_encoder_output, 'img_embedding'):
                    image_embeddings = image_encoder_output.img_embedding
                else:
                    raise AttributeError("Image encoder output does not have 'img_embedding' attribute.")

            # Generate medical report
            self.logger.info("Generating medical report using BioVilt + BioGPT...")
            generated_report = self.report_generator(
                image_embeddings=image_embeddings,
                prompt_input_ids=prompt_encoding['input_ids'],
                target_ids=None  # Not needed during inference
            )
            self.logger.info("Report generation completed using BioVilt + BioGPT.")

            # The generator may return one string or a batch of strings.
            if isinstance(generated_report, (list, tuple)):
                if len(generated_report) == 0:
                    raise ValueError("MedicalReportGenerator returned an empty report list.")
                generated_report_text = generated_report[0]
            elif isinstance(generated_report, str):
                generated_report_text = generated_report
            else:
                raise TypeError("MedicalReportGenerator.generate_report returned an unsupported type.")

            # Clean the generated report
            cleaned_report = self.clean_report(generated_report_text)

            # Create labels dictionary
            labels_dict = {
                label: int(val) for label, val in zip(self.label_columns, binary_labels)
            }

            self.logger.info("BioVilt report generation successful.")
            return {
                'report': cleaned_report,
                'labels': labels_dict
            }

        except Exception as e:
            self.logger.error(f"Error generating BioVilt report: {str(e)}")
            raise

    def clean_report(self, text: str) -> str:
        """
        Remove non-English characters, any occurrence of 'madeupword' followed by digits,
        and discard any text after the last period.

        Steps, in order: strip 'madeupword<digits>' tokens (case-insensitive),
        drop non-ASCII characters, collapse runs of whitespace to single
        spaces, then cut everything after the last '.'.

        Args:
            text (str): The generated medical report text.

        Returns:
            str: The cleaned medical report. When the text contains no period
            it is returned untruncated and a warning is logged.
        """
        try:
            self.logger.info("Cleaning the generated BioVilt report...")

            # Remove 'madeupword' followed by any number of digits
            text = self._MADEUPWORD_RE.sub('', text)

            # Remove any non-ASCII characters
            text = text.encode('ascii', 'ignore').decode('ascii')

            # Remove extra spaces created by removals
            text = ' '.join(text.split())

            # Truncate the text after the last period
            last_period_index = text.rfind('.')
            if last_period_index != -1:
                text = text[:last_period_index + 1]
            else:
                # If no period is found, return the text as is
                self.logger.warning("No period found in the text. Returning the original text.")

            self.logger.info("BioVilt report cleaned successfully.")
            return text

        except Exception as e:
            self.logger.error(f"Error cleaning BioVilt report: {str(e)}")
            raise
605
606
def load_biovilt_checkpoint(checkpoint_path: str, device: torch.device, image_embedding_dim: int = 512):
    """
    Load the BioVilt checkpoint and initialize the models.

    The checkpoint must contain the keys 'image_encoder_state_dict',
    'alignment_model_state_dict' and 'report_generator_state_dict'.

    Args:
        checkpoint_path (str): Path to the BioVilt checkpoint.
        device (torch.device): Device to load the models onto.
        image_embedding_dim (int): Embedding size passed to the alignment
            model and the report generator. Defaults to 512; it has to match
            the dimension the checkpoint was trained with.

    Returns:
        Tuple containing image_encoder, alignment_model, report_generator,
        all moved to ``device`` and set to eval mode.

    Raises:
        KeyError: If one of the expected state-dict keys is missing.
    """
    logging.info(f"Loading BioVilt checkpoint from {checkpoint_path}...")
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Initialize models
    image_encoder = get_biovil_t_image_encoder()
    alignment_model = ImageTextAlignmentModel(image_embedding_dim=image_embedding_dim)
    report_generator = BioViltMedicalReportGenerator(image_embedding_dim=image_embedding_dim)

    # Load state dicts
    image_encoder.load_state_dict(checkpoint['image_encoder_state_dict'])
    alignment_model.load_state_dict(checkpoint['alignment_model_state_dict'])
    report_generator.load_state_dict(checkpoint['report_generator_state_dict'])

    # Move to device
    image_encoder = image_encoder.to(device)
    alignment_model = alignment_model.to(device)
    report_generator = report_generator.to(device)

    # Set to eval mode
    image_encoder.eval()
    alignment_model.eval()
    report_generator.eval()

    logging.info("BioVilt models loaded successfully.")
    return image_encoder, alignment_model, report_generator
642
643
def load_bioclip_checkpoint(
    checkpoint_path: str,
    device: torch.device,
    vision_hidden_size: int = 768
) -> BioClipMedicalReportGenerator:
    """
    Load the BioClip MedicalReportGenerator checkpoint.

    Args:
        checkpoint_path (str): Path to the BioClip MedicalReportGenerator checkpoint.
        device (torch.device): Device to load the model onto.
        vision_hidden_size (int): Input embedding size of the generator.
            Defaults to 768; pass the hidden size of the vision model whose
            features the checkpoint was trained on, otherwise the state dict
            will not fit.

    Returns:
        BioClipMedicalReportGenerator: The loaded MedicalReportGenerator model,
        on ``device`` and in eval mode.

    Raises:
        KeyError: If the checkpoint has no 'model_state_dict' entry.
    """
    logging.info(f"Loading BioClip MedicalReportGenerator checkpoint from {checkpoint_path}...")
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Initialize BioClip MedicalReportGenerator
    report_generator = BioClipMedicalReportGenerator(input_embedding_dim=vision_hidden_size)

    # Load state dict
    report_generator.load_state_dict(checkpoint['model_state_dict'])

    # Move to device and set to eval mode
    report_generator.to(device)
    report_generator.eval()

    logging.info("BioClip MedicalReportGenerator loaded successfully.")
    return report_generator
670
671
# ---------------------- Streamlit Application ---------------------- #
672
673
def _render_report(pipeline, image, model_label: str) -> None:
    """Run one pipeline on the image and show its report, or the failure."""
    with st.spinner(f"Generating report with {model_label}..."):
        try:
            result = pipeline.generate_report(image, threshold=0.65)

            # Display the report
            st.subheader(f"Generated Medical Report ({model_label})")
            st.write(result['report'])

        except Exception as e:
            st.error(f"Failed to generate {model_label} report: {e}")


def main():
    """
    Streamlit entry point: upload an image, screen it, and generate reports.

    Flow: the uploaded file is opened as an RGB image and shown, it is
    screened with ``is_grayscale`` (non-grayscale uploads are rejected), both
    inference pipelines are loaded through ``st.cache_resource``, and one
    button per pipeline triggers report generation with threshold 0.65.

    Checkpoint locations can be overridden with the environment variables
    ``CHEXNET_MODEL_PATH``, ``BLIP2_REPORT_CHECKPOINT`` and
    ``BIOVILT_CHECKPOINT``; when unset, the original hard-coded paths are used.
    """
    st.set_page_config(page_title="Chest X-ray Medical Report Generator", layout="centered")
    st.title("Chest X-ray Medical Report Generator")

    st.markdown("""
    Upload a chest X-ray image, and click the **Generate Report** button to receive a detailed medical report along with predicted conditions.
    """)

    # File uploader
    uploaded_file = st.file_uploader("Upload a chest X-ray image", type=["png", "jpg", "jpeg"])

    if uploaded_file is None:
        return

    # Display the image. A corrupt or mislabeled upload is reported to the
    # user instead of surfacing as an unhandled traceback.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except Exception as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        st.stop()
    st.image(image, caption='Uploaded Chest X-ray Image', use_container_width=True)

    # Perform Grayscale Classification
    with st.spinner("Verifying if the uploaded image is a chest X-ray..."):
        is_cxr = is_grayscale(image, threshold=90.0)  # Adjust threshold as needed

    if not is_cxr:
        st.error("This image is not a chest X-ray image, please upload a chest X-ray image.")
        st.stop()  # Stop further execution
    else:
        st.success("Image verified as a chest X-ray. Proceeding with report generation.")

    # Initialize the inference pipelines. cache_resource keeps the loaded
    # models across Streamlit reruns so they are not rebuilt on every click.
    @st.cache_resource
    def load_inference_pipelines():
        # Checkpoint paths: environment override first, original default second.
        chexnet_model_path = os.environ.get(
            "CHEXNET_MODEL_PATH",
            r"C:\Users\anand\Downloads\best_chexnet_finetuned_16_f1.pth"
        )
        blip2_checkpoint = os.environ.get(
            "BLIP2_REPORT_CHECKPOINT",
            r"C:\Users\anand\Downloads\checkpoint_epoch_20.pt"
        )
        biovilt_checkpoint_path = os.environ.get(
            "BIOVILT_CHECKPOINT",
            r"C:\Users\anand\Downloads\model_epoch_7.pt"
        )

        # BLIP2 + BioGPT
        blip2_pipeline = ChestXrayFullInference(
            chexnet_model_path=chexnet_model_path,
            blip2_model_name="Salesforce/blip2-opt-2.7b",
            blip2_device_map='auto',
            chexnet_num_classes=14,
            report_generator_checkpoint=blip2_checkpoint
        )

        # BioVilt + BioGPT
        biovilt_pipeline = ChestXrayBioViltInference(
            chexnet_model_path=chexnet_model_path,
            biovilt_checkpoint_path=biovilt_checkpoint_path
        )

        return blip2_pipeline, biovilt_pipeline

    try:
        blip2_pipeline, biovilt_pipeline = load_inference_pipelines()
    except Exception as e:
        st.error(f"Failed to load inference pipelines: {e}")
        st.stop()

    # Define buttons for model selection
    col1, col2 = st.columns(2)

    with col1:
        blip2_button = st.button("Generate Report with BLIP2 + BioGPT")

    with col2:
        biovilt_button = st.button("Generate Report with BioVilt + BioGPT")

    # Handle BLIP2 + BioGPT report generation
    if blip2_button:
        _render_report(blip2_pipeline, image, "BLIP2 + BioGPT")

    # Handle BioVilt + BioGPT report generation
    if biovilt_button:
        _render_report(biovilt_pipeline, image, "BioVilt + BioGPT")
765
766
if __name__ == "__main__":
    # NOTE(review): `pd` is not referenced anywhere in the visible file, so
    # this import looks removable - confirm nothing relies on it first.
    import pandas as pd
    main()