Diff of /testing.py [000000] .. [d255cc]

Switch to side-by-side view

--- a
+++ b/testing.py
@@ -0,0 +1,513 @@
+import os
+import cv2
+import numpy as np
+import torch
+from torch.cuda.amp import autocast
+
+import re
+import glob
+
+# 导入SAM2相关模块
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+
+def inference_sam2_compare_models(
+    image_path,
+    mask_paths,  
+    model_cfg,
+    sam2_checkpoint,
+    fine_tuned_path=None,
+    num_samples_per_mask=10,
+    output_dir=None,  
+    show_results=False,
+    device="cuda"
+):
+
+
+
+    def read_image(image_path):
+        if not os.path.exists(image_path):
+            raise FileNotFoundError(f"图像文件不存在: {image_path}")
+            
+        img = cv2.imread(image_path)
+        if img is None:
+            raise ValueError(f"无法读取图像: {image_path}")
+        img = img[..., ::-1]  # BGR to RGB
+        
+        # 将图像调整到最大尺寸1024
+        resize_factor = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
+        img = cv2.resize(img, (int(img.shape[1] * resize_factor), int(img.shape[0] * resize_factor)))
+        return img, resize_factor
+    
+    # 在输入掩码内采样点
+    def get_points(mask, num_points):
+        points = []
+        coords = np.argwhere(mask > 0)
+        
+        if len(coords) == 0:
+            print("警告:掩码为空,无法采样点")
+            return np.array([])
+            
+        for _ in range(num_points):
+            if len(coords) > 0:
+                idx = np.random.randint(len(coords))
+                yx = coords[idx]
+                points.append([[yx[1], yx[0]]])
+        
+        return np.array(points) if points else np.array([])
+    
+    # 读取图像
+    image, resize_factor = read_image(image_path)
+    
+    # 读取所有掩码并采样点
+    all_points = []
+    all_point_labels = []
+    all_masks = []
+    
+    for i, mask_path in enumerate(mask_paths):
+        if not os.path.exists(mask_path):
+            print(f"警告:掩码文件不存在: {mask_path},将跳过")
+            continue
+            
+        mask = cv2.imread(mask_path, 0)  # 读取掩码
+        if mask is None:
+            print(f"警告:无法读取掩码: {mask_path},将跳过")
+            continue
+            
+        # 调整掩码大小以匹配图像 - 使用resize_factor而不是r
+        mask = cv2.resize(mask, (int(mask.shape[1] * resize_factor), int(mask.shape[0] * resize_factor)), 
+                         interpolation=cv2.INTER_NEAREST)
+        
+        all_masks.append(mask)
+        
+        # 获取该掩码的点
+        points = get_points(mask, num_samples_per_mask)
+        if len(points) > 0:
+            all_points.append(points)
+            all_point_labels.append(np.ones((len(points), 1)))
+    
+    if not all_points:
+        raise ValueError("所有掩码都无法采样点,无法进行分割")
+    
+    
+    # 合并所有掩码的点和标签
+    input_points = np.concatenate(all_points, axis=0)
+    input_labels = np.concatenate(all_point_labels, axis=0)
+    
+    # 创建真实掩码的可视化
+    ground_truth_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
+    ground_truth_rgb = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
+    
+    # 为每个掩码分配唯一的颜色和ID
+    colors = {}
+    
+    for i, mask in enumerate(all_masks):
+        # 生成随机颜色,但确保所有可视化中使用相同颜色
+        if i+1 not in colors:
+            colors[i+1] = [
+                np.random.randint(255),
+                np.random.randint(255),
+                np.random.randint(255)
+            ]
+        
+        # 更新掩码图和RGB图
+        ground_truth_map[mask > 0] = i + 1
+        ground_truth_rgb[mask > 0] = colors[i+1]
+    
+    # 创建真实掩码的混合图像
+    gt_blended_image = (ground_truth_rgb / 2 + image / 2).astype(np.uint8)
+    
+    # 初始化结果字典
+    results = {
+        "ground_truth": {
+            "map": ground_truth_map,
+            "rgb": ground_truth_rgb,
+            "blended": gt_blended_image
+        }
+    }
+    
+    # 使用两个模型进行推理
+    for model_key, model_path in [
+        ("original", sam2_checkpoint), 
+        ("fine_tuned", fine_tuned_path)
+    ]:
+        # 如果是微调模型但路径为None,则跳过
+        if model_key == "fine_tuned" and model_path is None:
+            continue
+            
+        print(f"\n使用{model_key}模型进行推理...")
+        
+        # 初始化SAM2模型
+        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
+        predictor = SAM2ImagePredictor(sam2_model)
+        
+        # 如果是微调模型,加载微调权重
+        if model_key == "fine_tuned":
+            try:
+                predictor.model.load_state_dict(torch.load(model_path, weights_only=True))
+            except:
+                predictor.model.load_state_dict(torch.load(model_path))
+            print(f"已加载微调模型: {model_path}")
+        
+        # 设置模型为评估模式
+        predictor.model.eval()
+        
+        # 使用混合精度进行推理
+        with torch.no_grad():
+            # 图像编码器
+            predictor.set_image(image)
+            
+            print(f"{model_key}模型正在进行预测...")
+            print("输入点形状:", input_points.shape)
+            
+            # prompt编码器 + mask解码器
+            masks, scores, _ = predictor.predict(
+                point_coords=input_points,
+                point_labels=input_labels,
+                multimask_output=True
+            )
+            
+            print(f"{model_key}模型预测完成")
+            print(f"预测掩码数量: {masks.shape[0] if hasattr(masks, 'shape') else '未知'}")
+        
+        # 处理masks和scores
+        if isinstance(masks, torch.Tensor):
+            np_masks = masks.cpu().numpy()
+            if np_masks.ndim > 3:
+                np_masks = np_masks[:, 0]
+        else:
+            np_masks = masks
+            if np_masks.ndim > 3:
+                np_masks = np_masks[:, 0]
+        
+        if isinstance(scores, torch.Tensor):
+            np_scores = scores.cpu().numpy()
+            if np_scores.ndim > 1:
+                np_scores = np_scores[:, 0]
+        else:
+            np_scores = scores
+            if np_scores.ndim > 1:
+                np_scores = np_scores[:, 0]
+        
+        # 根据分数排序掩码(降序)
+        sorted_indices = np.argsort(np_scores)[::-1]
+        sorted_masks = np_masks[sorted_indices]
+        
+        # 创建分割图
+        seg_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
+        occupancy_mask = np.zeros((image.shape[0], image.shape[1]), dtype=bool)
+        
+        for i, mask in enumerate(sorted_masks):
+            # 确保掩码是布尔类型
+            mask_bool = mask.astype(bool)
+            
+            # 计算重叠比例
+            # 使用逻辑与运算而不是位与运算
+            overlap_ratio = np.logical_and(mask_bool, occupancy_mask).sum() / (mask_bool.sum() + 1e-6)
+            
+            if mask_bool.sum() > 0 and overlap_ratio > 0.15:
+                continue
+                
+            # 更新掩码和分割图
+            non_overlap = np.logical_and(mask_bool, np.logical_not(occupancy_mask))
+            seg_map[non_overlap] = i + 1
+            occupancy_mask[mask_bool] = True
+        
+        # 创建可视化图像(使用与真实掩码相同的颜色)
+        rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)
+        for id_class in range(1, seg_map.max() + 1):
+            # 尝试使用与真实掩码相同的颜色
+            color = colors.get(id_class, [
+                np.random.randint(255),
+                np.random.randint(255),
+                np.random.randint(255)
+            ])
+            rgb_image[seg_map == id_class] = color
+        
+        # 创建混合图像
+        blended_image = (rgb_image / 2 + image / 2).astype(np.uint8)
+        
+        # 创建差异图 (显示预测与真实值的不同)
+        difference = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
+        
+        # 绿色表示正确预测的区域
+        correct = np.logical_and(ground_truth_map > 0, seg_map > 0)
+        difference[correct] = [0, 255, 0]
+        
+        # 红色表示未预测到的区域 (漏检)
+        missed = np.logical_and(ground_truth_map > 0, seg_map == 0)
+        difference[missed] = [255, 0, 0]
+        
+        # 蓝色表示错误预测的区域 (误检)
+        false_positive = np.logical_and(ground_truth_map == 0, seg_map > 0)
+        difference[false_positive] = [0, 0, 255]
+        
+        # 保存到结果字典
+        results[model_key] = {
+            "map": seg_map,
+            "rgb": rgb_image,
+            "blended": blended_image,
+            "difference": difference
+        }
+    
+    # 保存结果
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+        
+        # 保存真实掩码
+        cv2.imwrite(os.path.join(output_dir, "gt_segmentation.png"), 
+                    cv2.cvtColor(results["ground_truth"]["rgb"], cv2.COLOR_RGB2BGR))
+        cv2.imwrite(os.path.join(output_dir, "gt_blended.png"), 
+                    cv2.cvtColor(results["ground_truth"]["blended"], cv2.COLOR_RGB2BGR))
+        
+        # 保存原始模型结果
+        if "original" in results:
+            cv2.imwrite(os.path.join(output_dir, "original_segmentation.png"), 
+                        cv2.cvtColor(results["original"]["rgb"], cv2.COLOR_RGB2BGR))
+            cv2.imwrite(os.path.join(output_dir, "original_blended.png"), 
+                        cv2.cvtColor(results["original"]["blended"], cv2.COLOR_RGB2BGR))
+            cv2.imwrite(os.path.join(output_dir, "original_difference.png"), 
+                        cv2.cvtColor(results["original"]["difference"], cv2.COLOR_RGB2BGR))
+        
+        # 保存微调模型结果
+        if "fine_tuned" in results:
+            cv2.imwrite(os.path.join(output_dir, "fine_tuned_segmentation.png"), 
+                        cv2.cvtColor(results["fine_tuned"]["rgb"], cv2.COLOR_RGB2BGR))
+            cv2.imwrite(os.path.join(output_dir, "fine_tuned_blended.png"), 
+                        cv2.cvtColor(results["fine_tuned"]["blended"], cv2.COLOR_RGB2BGR))
+            cv2.imwrite(os.path.join(output_dir, "fine_tuned_difference.png"), 
+                        cv2.cvtColor(results["fine_tuned"]["difference"], cv2.COLOR_RGB2BGR))
+        
+        # 创建三模型对比图 (并排显示)
+        # 计算每个模型的列宽
+        col_width = image.shape[1]
+        models_count = 1 + ("original" in results) + ("fine_tuned" in results)
+        
+        comparison = np.zeros((image.shape[0], col_width * models_count, 3), dtype=np.uint8)
+        
+        # 添加真实掩码
+        col_idx = 0
+        comparison[:, col_idx*col_width:(col_idx+1)*col_width] = results["ground_truth"]["blended"]
+        col_idx += 1
+        
+        # 添加原始模型结果
+        if "original" in results:
+            comparison[:, col_idx*col_width:(col_idx+1)*col_width] = results["original"]["blended"]
+            col_idx += 1
+        
+        # 添加微调模型结果
+        if "fine_tuned" in results:
+            comparison[:, col_idx*col_width:(col_idx+1)*col_width] = results["fine_tuned"]["blended"]
+        
+        # 添加标签
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        
+        col_idx = 0
+        cv2.putText(comparison, "Ground Truth", (col_idx*col_width + 10, 30), font, 1, (255, 255, 255), 2)
+        col_idx += 1
+        
+        if "original" in results:
+            cv2.putText(comparison, "Original Model", (col_idx*col_width + 10, 30), font, 1, (255, 255, 255), 2)
+            col_idx += 1
+        
+        if "fine_tuned" in results:
+            cv2.putText(comparison, "Fine-tuned Model", (col_idx*col_width + 10, 30), font, 1, (255, 255, 255), 2)
+        
+        cv2.imwrite(os.path.join(output_dir, "model_comparison.png"), 
+                    cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))
+        
+        print(f"所有结果已保存到 {output_dir} 目录")
+    
+    # 显示结果(在有GUI的环境中)
+    if show_results:
+        try:
+            if "original" in results and "fine_tuned" in results:
+                cv2.imshow("模型对比", cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))
+            print("按任意键关闭窗口...")
+            cv2.waitKey(0)
+            cv2.destroyAllWindows()
+        except Exception as e:
+            print(f"无法显示图像,可能是在无GUI环境中运行: {e}")
+            print("结果已保存到指定目录")
+    
+    return results
+
+
+def find_related_masks_advanced(image_path, masks_dir):
+        """
+        使用更高级的方法查找与图像相关的掩码
+        """
+        
+        # 获取图像文件名(不含扩展名)
+        image_filename = os.path.basename(image_path)
+        image_name, _ = os.path.splitext(image_filename)
+        
+        # 基本情况:掩码名以图像名开头
+        pattern1 = os.path.join(masks_dir, f"{image_name}*.png")
+        masks1 = glob.glob(pattern1)
+        
+        # 从图像名中提取患者ID和序列号,用于更复杂的匹配
+        # 例如: a4c_PatientD0062_a4c_93.jsonD_116.jpg
+        # 提取: PatientD0062, 93, 116
+        matches = re.search(r'(Patient[A-Za-z0-9]+).*?(\d+)\.json[A-Z]_(\d+)', image_name)
+        
+        if matches:
+            patient_id = matches.group(1)
+            seq_num1 = matches.group(2)
+            seq_num2 = matches.group(3)
+            
+            # 使用提取的信息尝试其他可能的匹配模式
+            pattern2 = os.path.join(masks_dir, f"*{patient_id}*{seq_num1}*{seq_num2}*.png")
+            masks2 = glob.glob(pattern2)
+            
+            # 合并去重
+            all_masks = list(set(masks1 + masks2))
+        else:
+            all_masks = masks1
+        
+        print(f"为图像 {image_filename} 找到 {len(all_masks)} 个相关掩码:")
+        for mask in all_masks:
+            print(f"  - {os.path.basename(mask)}")
+        
+        return all_masks
+ 
+    
+    # 读取图像
+def process_image_by_name(data_dir, image_filename, model_cfg, sam2_checkpoint, fine_tuned_path=None, output_dir=None):
+        """
+        通过图像文件名处理单个图像,自动查找相关掩码
+        
+        参数:
+            data_dir: 数据根目录
+            image_filename: 图像文件名(不含路径)
+            model_cfg: 模型配置文件路径
+            sam2_checkpoint: 原始模型检查点路径
+            fine_tuned_path: 微调模型路径
+            output_dir: 输出目录
+        """
+        images_dir = os.path.join(data_dir, "JPEGImages")
+        masks_dir = os.path.join(data_dir, "Annotations")
+        
+        # 构建完整图像路径
+        image_path = os.path.join(images_dir, image_filename)
+        
+        if not os.path.exists(image_path):
+            print(f"图像文件不存在: {image_path}")
+            return
+        
+        # 自动查找相关掩码
+        mask_paths = find_related_masks_advanced(image_path, masks_dir)
+        
+        if not mask_paths:
+            print("未找到相关掩码文件")
+            return
+        
+        # 设置输出目录(使用图像名作为子目录)
+        if output_dir is None:
+            image_name, _ = os.path.splitext(image_filename)
+            output_dir = os.path.join("results", image_name)
+        
+        # 运行模型比较
+        results = inference_sam2_compare_models(
+            image_path=image_path,
+            mask_paths=mask_paths,
+            model_cfg=model_cfg,
+            sam2_checkpoint=sam2_checkpoint,
+            fine_tuned_path=fine_tuned_path,
+            output_dir=output_dir,
+            show_results=False
+        )
+        
+        return results
+
+
+if __name__ == "__main__":
+    # 设置基本路径
+    data_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train"
+    images_dir = os.path.join(data_dir, "JPEGImages")
+    masks_dir = os.path.join(data_dir, "Annotations")
+    
+    # 模型配置
+    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+    sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
+    fine_tuned_path = "models/model_final.torch"
+    
+    # 指定图像文件
+    image_filename = "a4c_PatientD0062_a4c_93.jsonD_116.jpg"
+    image_path = os.path.join(images_dir, image_filename)
+    
+    print(f"检查图像文件...")
+    if not os.path.exists(image_path):
+        print(f"图像文件不存在: {image_path}")
+        exit(1)
+    
+    # 自动查找相关掩码
+    mask_paths = find_related_masks_advanced(image_path, masks_dir)
+    
+    # 检查掩码文件
+    valid_masks = []
+    for mask_path in mask_paths:
+        if os.path.exists(mask_path):
+            valid_masks.append(mask_path)
+            print(f"掩码文件存在: {mask_path}")
+        else:
+            print(f"掩码文件不存在: {mask_path}")
+    
+    if not valid_masks:
+        print("没有有效的掩码文件")
+        exit(1)
+    
+    # 检查其他文件
+    for path, desc in [
+        (model_cfg, "模型配置文件"),
+        (sam2_checkpoint, "预训练检查点"),
+        (fine_tuned_path, "微调模型")
+    ]:
+        exists = os.path.exists(path)
+        print(f"{desc}: {path} {'存在' if exists else '不存在'}")
+    
+    # 设置输出目录
+    output_dir = os.path.join("results_comparison", os.path.splitext(image_filename)[0])
+    
+    try:
+        # 运行推理
+        print("\n开始模型对比推理...")
+        results = inference_sam2_compare_models(
+            image_path=image_path,
+            mask_paths=valid_masks,
+            model_cfg=model_cfg,
+            sam2_checkpoint=sam2_checkpoint,
+            fine_tuned_path=fine_tuned_path,
+            num_samples_per_mask=10,
+            output_dir=output_dir,
+            show_results=False
+        )
+        print("对比推理完成!")
+        
+        # 输出简单的性能指标
+        if "original" in results and "fine_tuned" in results:
+            gt_map = results["ground_truth"]["map"]
+            
+            for model_key in ["original", "fine_tuned"]:
+                model_map = results[model_key]["map"]
+                
+                # 计算基本指标
+                true_positive = np.sum(np.logical_and(gt_map > 0, model_map > 0))
+                false_negative = np.sum(np.logical_and(gt_map > 0, model_map == 0))
+                false_positive = np.sum(np.logical_and(gt_map == 0, model_map > 0))
+                
+                # 计算性能指标
+                precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
+                recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
+                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+                
+                print(f"\n{model_key.capitalize()}模型性能:")
+                print(f"精确率(Precision): {precision:.4f}")
+                print(f"召回率(Recall): {recall:.4f}")  
+                print(f"F1分数: {f1:.4f}")
+    
+    except Exception as e:
+        print(f"推理过程中出错: {e}")
+        import traceback
+        traceback.print_exc()
\ No newline at end of file