--- a
+++ b/streamlit.py
@@ -0,0 +1,768 @@
+# app.py
+# NOTE: if this module is saved as streamlit.py (as in the diff path above),
+# it shadows the `streamlit` package and `import streamlit` resolves to this
+# file instead; rename it to app.py before running.
+
+import streamlit as st
+import torch
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+from PIL import Image
+from pathlib import Path
+import logging
+import sys
+import numpy as np
+import torchvision.transforms as transforms
+from typing import Union, List, Dict, Tuple, Optional
+
+# Import the MedicalReportGenerator from the appropriate modules with aliases
+from report_generator_bioclip import MedicalReportGenerator as BioClipMedicalReportGenerator
+from report_generator_concat import MedicalReportGenerator as BioViltMedicalReportGenerator
+
+# Import the ModifiedCheXNet model class
+from chexnet_train import ModifiedCheXNet
+
+# Import BioVilt specific modules
+from alignment_concat import ImageTextAlignmentModel
+from biovil_t.pretrained import get_biovil_t_image_encoder  # Ensure this import path is correct
+
+# Additional imports for the BioVilt pipeline
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+import re
+
+# Suppress excessive warnings from transformers and torchvision
+import warnings
+warnings.filterwarnings("ignore")
+
+# Disable torchvision beta transforms warnings
+import torchvision
+if hasattr(torchvision, 'disable_beta_transforms_warning'):
+    torchvision.disable_beta_transforms_warning()
+
+# ---------------------- Grayscale Classification ---------------------- #
+
+def is_grayscale(image: Image.Image, threshold: float = 90.0) -> bool:
+    """
+    Determine whether the image is predominantly grayscale, i.e. whether the
+    percentage of pixels with identical R, G, and B values exceeds `threshold`.
+    """
+    try:
+        # Ensure image is in RGB, then compare channels with numpy
+        # (much faster than iterating over pixels in Python)
+        arr = np.asarray(image.convert("RGB"))
+        grayscale_pixels = int(np.all(arr == arr[..., :1], axis=-1).sum())
+        total_pixels = arr.shape[0] * arr.shape[1]
+        grayscale_percentage = (grayscale_pixels / total_pixels) * 100
+        return grayscale_percentage > threshold
+    except Exception as e:
+        logging.error(f"Error in is_grayscale: {e}")
+        return False
+
+# ---------------------- Inference Pipelines ---------------------- #
+
+class ChestXrayFullInference:
+    def __init__(
+        self,
+        chexnet_model_path: str,
+        blip2_model_name: str = "Salesforce/blip2-opt-2.7b",
+        blip2_device_map: str = 'auto',
+        chexnet_num_classes: int = 14,
+        report_generator_checkpoint: Optional[str] = None,
+        device: Optional[str] = None
+    ):
+        """
+        Initialize the full inference pipeline with CheXNet, BLIP-2, and the BioClip MedicalReportGenerator.
+
+        Args:
+            chexnet_model_path (str): Path to the trained CheXNet model checkpoint.
+            blip2_model_name (str): Hugging Face model name for BLIP-2.
+            blip2_device_map (str): Device mapping for BLIP-2 ('auto' by default).
+            chexnet_num_classes (int): Number of classes for CheXNet.
+            report_generator_checkpoint (str): Path to the BioClip MedicalReportGenerator checkpoint.
+            device (str): Device to use ('cuda' or 'cpu').
+ """ + self.logger = self._setup_logger() + self.device = torch.device(device) if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.logger.info(f"Using device: {self.device}") + + # Initialize CheXNet Predictor + self.chexnet_predictor = self._initialize_chexnet( + chexnet_model_path, chexnet_num_classes + ) + + # Initialize BLIP-2 Processor and Model + self.processor, self.blip_model = self._initialize_blip2( + blip2_model_name, blip2_device_map + ) + + # Initialize BioClip MedicalReportGenerator + self.report_generator = self._initialize_report_generator( + report_generator_checkpoint + ) + + # Define label columns + self.label_columns = [ + 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', + 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', + 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', + 'Pleural Other', 'Fracture', 'Support Devices', 'No Finding' + ] + + def _setup_logger(self) -> logging.Logger: + """Set up logging configuration.""" + logger = logging.getLogger('ChestXrayFullInference') + logger.setLevel(logging.INFO) + + if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + )) + logger.addHandler(handler) + + return logger + + def _initialize_chexnet(self, model_path: str, num_classes: int) -> ModifiedCheXNet: + """Initialize the CheXNet model.""" + try: + self.logger.info("Initializing CheXNet model...") + chexnet = ModifiedCheXNet(num_classes=num_classes).to(self.device) + checkpoint = torch.load(model_path, map_location=self.device) + + # Handle different checkpoint formats + if 'model_state_dict' in checkpoint: + chexnet.load_state_dict(checkpoint['model_state_dict']) + else: + chexnet.load_state_dict(checkpoint) + + chexnet.eval() + self.logger.info("CheXNet model loaded successfully.") + return chexnet + + except Exception as e: + self.logger.error(f"Error initializing CheXNet model: {str(e)}") + raise + + def _initialize_blip2( + self, model_name: str, device_map: str + ) -> (Blip2Processor, Blip2ForConditionalGeneration): + """Initialize the BLIP-2 processor and model.""" + try: + self.logger.info("Initializing BLIP-2 model and processor...") + processor = Blip2Processor.from_pretrained(model_name, force_download=True) + blip_model = Blip2ForConditionalGeneration.from_pretrained( + model_name, + torch_dtype=torch.float32, + device_map=device_map + ) + blip_model.eval() + self.logger.info("BLIP-2 model and processor loaded successfully.") + return processor, blip_model + + except Exception as e: + self.logger.error(f"Error initializing BLIP-2 model: {str(e)}") + raise + + def _initialize_report_generator(self, checkpoint_path: str) -> BioClipMedicalReportGenerator: + """Initialize the BioClip MedicalReportGenerator.""" + try: + self.logger.info("Initializing BioClip MedicalReportGenerator...") + vision_hidden_size = self.blip_model.vision_model.config.hidden_size + report_gen = BioClipMedicalReportGenerator(input_embedding_dim=vision_hidden_size) + + # Load trained weights + checkpoint = torch.load(checkpoint_path, map_location=self.device) + report_gen.load_state_dict(checkpoint['model_state_dict']) + report_gen.to(self.device) + report_gen.eval() + self.logger.info("BioClip MedicalReportGenerator loaded successfully.") + return report_gen + + except Exception as e: + self.logger.error(f"Error initializing BioClip MedicalReportGenerator: {str(e)}") + raise + + def _get_transform(self) -> transforms.Compose: 
+ """Get the transformation pipeline for CheXNet.""" + return transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ]) + + def _convert_labels_to_findings(self, binary_labels: List[int]) -> str: + """Convert binary labels to a comma-separated string of findings.""" + findings = [label for label, val in zip(self.label_columns, binary_labels) if val == 1] + return ", ".join(findings) if findings else "No Findings" + + def predict_labels(self, image: Image.Image, threshold: float = 0.5) -> List[int]: + """ + Predict binary labels for the given image using CheXNet. + + Args: + image (PIL.Image.Image): Input image. + threshold (float): Probability threshold for positive prediction. + + Returns: + List[int]: Binary labels (0 or 1) for each condition. + """ + try: + self.logger.info("Predicting labels using CheXNet...") + transform = self._get_transform() + image_tensor = transform(image).unsqueeze(0).to(self.device) + + with torch.no_grad(): + output = self.chexnet_predictor(image_tensor) + probabilities = torch.sigmoid(output).cpu().numpy()[0] + + binary_labels = [1 if prob >= threshold else 0 for prob in probabilities] + self.logger.info(f"Predicted binary labels: {binary_labels}") + return binary_labels + + except Exception as e: + self.logger.error(f"Error predicting labels: {str(e)}") + raise + + def extract_image_features(self, image: Image.Image) -> torch.Tensor: + """ + Extract image features using BLIP-2. + + Args: + image (PIL.Image.Image): Input image. + + Returns: + torch.Tensor: Image features tensor. + """ + try: + self.logger.info("Extracting image features using BLIP-2...") + processed = self.processor(images=image, return_tensors="pt") + pixel_values = processed.pixel_values.to(self.device) + + with torch.no_grad(): + vision_outputs = self.blip_model.vision_model(pixel_values) + image_features = vision_outputs.pooler_output + + self.logger.info(f"Extracted image features with shape: {image_features.shape}") + return image_features + + except Exception as e: + self.logger.error(f"Error extracting image features: {str(e)}") + raise + + def generate_report(self, image: Union[str, Path, Image.Image], threshold: float = 0.5) -> Dict: + """ + Generate a medical report for the given chest X-ray image. + + Args: + image (str, Path, or PIL.Image.Image): Input image or path to the image. + threshold (float): Probability threshold for positive prediction. + + Returns: + Dict: Contains the generated report and binary labels. + """ + try: + if isinstance(image, (str, Path)): + self.logger.info(f"Generating report for image path: {image}") + image_path = Path(image) + if not image_path.exists(): + raise FileNotFoundError(f"Image file {image_path} does not exist.") + # Load image + image = Image.open(image_path).convert('RGB') + elif isinstance(image, Image.Image): + self.logger.info("Generating report for uploaded image.") + else: + raise TypeError("Image must be a string path or a PIL.Image.Image object.") + + # Predict labels + binary_labels = self.predict_labels(image, threshold=threshold) + + # Convert binary labels to findings string + findings = self._convert_labels_to_findings(binary_labels) + prompt = f"Findings: {findings}." 
+            self.logger.info(f"Findings prompt: {prompt}")
+
+            # Extract image features
+            image_features = self.extract_image_features(image)
+
+            # Generate the report; the BioClip generator takes the image
+            # features and binary labels (no prompt argument is passed)
+            self.logger.info("Starting report generation...")
+            generated_report = self.report_generator.generate_report(
+                input_embeddings=image_features,
+                labels=torch.tensor(binary_labels, dtype=torch.float32).unsqueeze(0).to(self.device)
+            )
+            self.logger.info("Report generation completed.")
+
+            # generate_report may return a single string or a batch of strings
+            if isinstance(generated_report, (list, tuple)):
+                if len(generated_report) == 0:
+                    raise ValueError("MedicalReportGenerator returned an empty report list.")
+                generated_report_text = generated_report[0]
+            elif isinstance(generated_report, str):
+                generated_report_text = generated_report
+            else:
+                raise TypeError("MedicalReportGenerator.generate_report returned an unsupported type.")
+
+            # Create labels dictionary
+            labels_dict = {
+                label: int(val) for label, val in zip(self.label_columns, binary_labels)
+            }
+
+            self.logger.info("Report generation successful.")
+            return {
+                'report': generated_report_text,
+                'labels': labels_dict
+            }
+
+        except Exception as e:
+            self.logger.error(f"Error generating report: {str(e)}")
+            raise
+
+
+class ChestXrayBioViltInference:
+    def __init__(
+        self,
+        chexnet_model_path: str,
+        biovilt_checkpoint_path: str,
+        device: Optional[str] = None
+    ):
+        """
+        Initialize the inference pipeline with CheXNet and BioVilt + BioGPT.
+
+        Args:
+            chexnet_model_path (str): Path to the trained CheXNet model checkpoint.
+            biovilt_checkpoint_path (str): Path to the BioVilt + BioGPT model checkpoint.
+            device (str): Device to use ('cuda' or 'cpu').
+ """ + self.logger = self._setup_logger() + self.device = torch.device(device) if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.logger.info(f"Using device: {self.device}") + + # Initialize CheXNet Predictor + self.chexnet_predictor = self._initialize_chexnet( + chexnet_model_path, num_classes=14 # Corrected parameter name + ) + + # Initialize BioVilt components + self.image_encoder, self.alignment_model, self.report_generator = self._initialize_biovilt( + biovilt_checkpoint_path + ) + + # Define label columns + self.label_columns = [ + 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', + 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', + 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', + 'Pleural Other', 'Fracture', 'Support Devices', 'No Finding' + ] + + def _setup_logger(self) -> logging.Logger: + """Set up logging configuration.""" + logger = logging.getLogger('ChestXrayBioViltInference') + logger.setLevel(logging.INFO) + + if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + )) + logger.addHandler(handler) + + return logger + + def _initialize_chexnet(self, model_path: str, num_classes: int) -> ModifiedCheXNet: + """Initialize the CheXNet model.""" + try: + self.logger.info("Initializing CheXNet model for BioVilt pipeline...") + chexnet = ModifiedCheXNet(num_classes=num_classes).to(self.device) + checkpoint = torch.load(model_path, map_location=self.device) + + # Handle different checkpoint formats + if 'model_state_dict' in checkpoint: + chexnet.load_state_dict(checkpoint['model_state_dict']) + else: + chexnet.load_state_dict(checkpoint) + + chexnet.eval() + self.logger.info("CheXNet model loaded successfully for BioVilt pipeline.") + return chexnet + + except Exception as e: + self.logger.error(f"Error initializing CheXNet model for BioVilt pipeline: {str(e)}") + raise + + def _initialize_biovilt(self, checkpoint_path: str): + """Initialize BioVilt Image Encoder, Alignment Model, and Report Generator.""" + try: + self.logger.info("Initializing BioVilt Image Encoder, Alignment Model, and Report Generator...") + image_encoder, alignment_model, report_generator = load_biovilt_checkpoint( + checkpoint_path, self.device + ) + self.logger.info("BioVilt components loaded successfully.") + return image_encoder, alignment_model, report_generator + + except Exception as e: + self.logger.error(f"Error initializing BioVilt components: {str(e)}") + raise + + def _get_transform(self) -> A.Compose: + """Get the transformation pipeline for CheXNet.""" + return A.Compose([ + A.Resize(224, 224), + A.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ToTensorV2() + ]) + + def _convert_labels_to_findings(self, binary_labels: List[int]) -> str: + """Convert binary labels to a comma-separated string of findings.""" + findings = [label for label, val in zip(self.label_columns, binary_labels) if val == 1] + return ", ".join(findings) if findings else "No Findings" + + def predict_labels(self, image: Image.Image, threshold: float = 0.5) -> List[int]: + """ + Predict binary labels for the given image using CheXNet. + + Args: + image (PIL.Image.Image): Input image. + threshold (float): Probability threshold for positive prediction. + + Returns: + List[int]: Binary labels (0 or 1) for each condition. 
+ """ + try: + self.logger.info("Predicting labels using CheXNet for BioVilt pipeline...") + transform = self._get_transform() + image_np = np.array(image) + transformed = transform(image=image_np) + image_tensor = transformed['image'].unsqueeze(0).to(self.device) + + with torch.no_grad(): + output = self.chexnet_predictor(image_tensor) + probabilities = torch.sigmoid(output).cpu().numpy()[0] + + binary_labels = [1 if prob >= threshold else 0 for prob in probabilities] + self.logger.info(f"Predicted binary labels for BioVilt pipeline: {binary_labels}") + return binary_labels + + except Exception as e: + self.logger.error(f"Error predicting labels for BioVilt pipeline: {str(e)}") + raise + + def generate_report(self, image: Union[str, Path, Image.Image], threshold: float = 0.5) -> Dict: + """ + Generate a medical report for the given chest X-ray image using BioVilt + BioGPT. + + Args: + image (str, Path, or PIL.Image.Image): Input image or path to the image. + threshold (float): Probability threshold for positive prediction. + + Returns: + Dict: Contains the generated report and binary labels. + """ + try: + if isinstance(image, (str, Path)): + self.logger.info(f"Generating BioVilt report for image path: {image}") + image_path = Path(image) + if not image_path.exists(): + raise FileNotFoundError(f"Image file {image_path} does not exist.") + # Load image + image = Image.open(image_path).convert('RGB') + elif isinstance(image, Image.Image): + self.logger.info("Generating BioVilt report for uploaded image.") + else: + raise TypeError("Image must be a string path or a PIL.Image.Image object.") + + # Predict labels + binary_labels = self.predict_labels(image, threshold=threshold) + + # Convert binary labels to findings string + findings = self._convert_labels_to_findings(binary_labels) + prompt = f"Findings: {findings}." 
+
+            # Tokenize prompt
+            self.logger.info("Tokenizing prompt...")
+            prompt_encoding = self.report_generator.tokenizer(
+                [prompt],
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+                max_length=512
+            ).to(self.device)
+
+            # Extract image embeddings using BioVilt Image Encoder
+            self.logger.info("Extracting image embeddings using BioVilt Image Encoder...")
+            image_np = np.array(image)
+            transform = self._get_transform()  # reuse the shared pipeline instead of redefining it inline
+            transformed = transform(image=image_np)
+            image_tensor = transformed['image'].unsqueeze(0).to(self.device)
+
+            with torch.no_grad():
+                image_encoder_output = self.image_encoder(image_tensor)
+                # Extract the tensor from ImageModelOutput
+                if hasattr(image_encoder_output, 'img_embedding'):
+                    image_embeddings = image_encoder_output.img_embedding
+                else:
+                    raise AttributeError("Image encoder output does not have 'img_embedding' attribute.")
+
+            # Generate medical report
+            self.logger.info("Generating medical report using BioVilt + BioGPT...")
+            generated_report = self.report_generator(
+                image_embeddings=image_embeddings,
+                prompt_input_ids=prompt_encoding['input_ids'],
+                target_ids=None  # Not needed during inference
+            )
+            self.logger.info("Report generation completed using BioVilt + BioGPT.")
+
+            # The generator may return a single string or a batch of strings
+            if isinstance(generated_report, (list, tuple)):
+                if len(generated_report) == 0:
+                    raise ValueError("MedicalReportGenerator returned an empty report list.")
+                generated_report_text = generated_report[0]
+            elif isinstance(generated_report, str):
+                generated_report_text = generated_report
+            else:
+                raise TypeError("MedicalReportGenerator returned an unsupported type.")
+
+            # Clean the generated report
+            cleaned_report = self.clean_report(generated_report_text)
+
+            # Create labels dictionary
+            labels_dict = {
+                label: int(val) for label, val in zip(self.label_columns, binary_labels)
+            }
+
+            self.logger.info("BioVilt report generation successful.")
+            return {
+                'report': cleaned_report,
+                'labels': labels_dict
+            }
+
+        except Exception as e:
+            self.logger.error(f"Error generating BioVilt report: {str(e)}")
+            raise
+
+    def clean_report(self, text: str) -> str:
+        """
+        Remove non-ASCII characters and any occurrence of 'madeupword' followed
+        by digits, then discard any text after the last period.
+
+        Args:
+            text (str): The generated medical report text.
+
+        Returns:
+            str: The cleaned medical report.
+        """
+        try:
+            self.logger.info("Cleaning the generated BioVilt report...")
+
+            # Remove 'madeupword' followed by any number of digits
+            text = re.sub(r'madeupword\d+', '', text, flags=re.IGNORECASE)
+
+            # Remove any non-ASCII characters
+            text = text.encode('ascii', 'ignore').decode('ascii')
+
+            # Remove extra spaces created by removals
+            text = ' '.join(text.split())
+
+            # Truncate the text after the last period
+            last_period_index = text.rfind('.')
+            if last_period_index != -1:
+                text = text[:last_period_index + 1]
+            else:
+                # If no period is found, return the text as is
+                self.logger.warning("No period found in the text. Returning the original text.")
+
+            self.logger.info("BioVilt report cleaned successfully.")
+            return text
+
+        except Exception as e:
+            self.logger.error(f"Error cleaning BioVilt report: {str(e)}")
+            raise
+
+def load_biovilt_checkpoint(checkpoint_path: str, device: torch.device):
+    """
+    Load the BioVilt checkpoint and initialize the models.
+
+    Args:
+        checkpoint_path (str): Path to the BioVilt checkpoint.
+        device (torch.device): Device to load the models onto.
+
+    Returns:
+        Tuple containing image_encoder, alignment_model, report_generator
+    """
+    logging.info(f"Loading BioVilt checkpoint from {checkpoint_path}...")
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    # Initialize models
+    image_encoder = get_biovil_t_image_encoder()
+    alignment_model = ImageTextAlignmentModel(image_embedding_dim=512)
+    report_generator = BioViltMedicalReportGenerator(image_embedding_dim=512)
+
+    # Load state dicts
+    image_encoder.load_state_dict(checkpoint['image_encoder_state_dict'])
+    alignment_model.load_state_dict(checkpoint['alignment_model_state_dict'])
+    report_generator.load_state_dict(checkpoint['report_generator_state_dict'])
+
+    # Move to device
+    image_encoder = image_encoder.to(device)
+    alignment_model = alignment_model.to(device)
+    report_generator = report_generator.to(device)
+
+    # Set to eval mode
+    image_encoder.eval()
+    alignment_model.eval()
+    report_generator.eval()
+
+    logging.info("BioVilt models loaded successfully.")
+    return image_encoder, alignment_model, report_generator
+
+def load_bioclip_checkpoint(checkpoint_path: str, device: torch.device) -> BioClipMedicalReportGenerator:
+    """
+    Load the BioClip MedicalReportGenerator checkpoint.
+
+    NOTE: currently unused by the app (ChestXrayFullInference loads its own
+    generator); kept as a standalone loader.
+
+    Args:
+        checkpoint_path (str): Path to the BioClip MedicalReportGenerator checkpoint.
+        device (torch.device): Device to load the model onto.
+
+    Returns:
+        BioClipMedicalReportGenerator: The loaded MedicalReportGenerator model.
+    """
+    logging.info(f"Loading BioClip MedicalReportGenerator checkpoint from {checkpoint_path}...")
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    # Initialize BioClip MedicalReportGenerator
+    vision_hidden_size = 768  # Update this based on your model's hidden size
+    report_generator = BioClipMedicalReportGenerator(input_embedding_dim=vision_hidden_size)
+
+    # Load state dict
+    report_generator.load_state_dict(checkpoint['model_state_dict'])
+
+    # Move to device and set to eval mode
+    report_generator.to(device)
+    report_generator.eval()
+
+    logging.info("BioClip MedicalReportGenerator loaded successfully.")
+    return report_generator
+
+# ---------------------- Streamlit Application ---------------------- #
+
+def main():
+    st.set_page_config(page_title="Chest X-ray Medical Report Generator", layout="centered")
+    st.title("Chest X-ray Medical Report Generator")
+
+    st.markdown("""
+    Upload a chest X-ray image and click one of the **Generate Report** buttons to receive a detailed medical report along with predicted conditions.
+    """)
+
+    # File uploader
+    uploaded_file = st.file_uploader("Upload a chest X-ray image", type=["png", "jpg", "jpeg"])
+
+    if uploaded_file is not None:
+        # Display the image
+        image = Image.open(uploaded_file).convert('RGB')
+        st.image(image, caption='Uploaded Chest X-ray Image', use_container_width=True)
+
+        # Perform Grayscale Classification
+        with st.spinner("Verifying if the uploaded image is a chest X-ray..."):
+            is_cxr = is_grayscale(image, threshold=90.0)  # Adjust threshold as needed
+
+        if not is_cxr:
+            st.error("This image is not a chest X-ray image, please upload a chest X-ray image.")
+            st.stop()  # Stop further execution
+        else:
+            st.success("Image verified as a chest X-ray. Proceeding with report generation.")
+
+        # Initialize the inference pipelines
+        @st.cache_resource
+        def load_inference_pipelines():
+            # Paths for BLIP2 + BioGPT
+            blip2_model_name = "Salesforce/blip2-opt-2.7b"
+            blip2_device_map = 'auto'
+            blip2_checkpoint = r"C:\Users\anand\Downloads\checkpoint_epoch_20.pt"  # Update path as needed
+
+            blip2_pipeline = ChestXrayFullInference(
+                chexnet_model_path=r"C:\Users\anand\Downloads\best_chexnet_finetuned_16_f1.pth",  # Update path as needed
+                blip2_model_name=blip2_model_name,
+                blip2_device_map=blip2_device_map,
+                chexnet_num_classes=14,
+                report_generator_checkpoint=blip2_checkpoint
+            )
+
+            # Paths for BioVilt + BioGPT
+            biovilt_checkpoint_path = r"C:\Users\anand\Downloads\model_epoch_7.pt"  # Update path as needed
+
+            biovilt_pipeline = ChestXrayBioViltInference(
+                chexnet_model_path=r"C:\Users\anand\Downloads\best_chexnet_finetuned_16_f1.pth",  # Update path as needed
+                biovilt_checkpoint_path=biovilt_checkpoint_path
+            )
+
+            return blip2_pipeline, biovilt_pipeline
+
+        try:
+            blip2_pipeline, biovilt_pipeline = load_inference_pipelines()
+        except Exception as e:
+            st.error(f"Failed to load inference pipelines: {e}")
+            st.stop()
+
+        # Define buttons for model selection
+        col1, col2 = st.columns(2)
+
+        with col1:
+            blip2_button = st.button("Generate Report with BLIP2 + BioGPT")
+
+        with col2:
+            biovilt_button = st.button("Generate Report with BioVilt + BioGPT")
+
+        # Handle BLIP2 + BioGPT report generation
+        if blip2_button:
+            with st.spinner("Generating report with BLIP2 + BioGPT..."):
+                try:
+                    result = blip2_pipeline.generate_report(image, threshold=0.65)
+
+                    # Display the report
+                    st.subheader("Generated Medical Report (BLIP2 + BioGPT)")
+                    st.write(result['report'])
+
+                except Exception as e:
+                    st.error(f"Failed to generate BLIP2 + BioGPT report: {e}")
+
+        # Handle BioVilt + BioGPT report generation
+        if biovilt_button:
+            with st.spinner("Generating report with BioVilt + BioGPT..."):
+                try:
+                    result = biovilt_pipeline.generate_report(image, threshold=0.65)
+
+                    # Display the report
+                    st.subheader("Generated Medical Report (BioVilt + BioGPT)")
+                    st.write(result['report'])
+
+                except Exception as e:
+                    st.error(f"Failed to generate BioVilt + BioGPT report: {e}")
+
+if __name__ == "__main__":
+    main()
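+
+# Usage (a minimal sketch; assumes the checkpoint paths above exist locally
+# and the sibling modules such as chexnet_train and biovil_t are importable):
+#   streamlit run app.py   # see the filename note at the top of this file
+# Then upload a chest X-ray (PNG/JPG/JPEG) in the browser tab that opens and
+# click one of the "Generate Report" buttons.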