"""Compare a stock SAM2 checkpoint with a fine-tuned one on a single image.

The module samples point prompts inside ground-truth masks, runs both models
on those prompts, renders colour-coded segmentation / difference images and
prints simple foreground precision / recall / F1 figures.
"""
import glob
import os
import pickle
import re
import sys
import traceback

import cv2
import numpy as np
import torch
from torch.cuda.amp import autocast  # noqa: F401  (unused here; kept from the original import list)

# SAM2 modules
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Longest image side (in pixels) after resizing.
_MAX_IMAGE_SIDE = 1024

# A predicted mask is dropped when more than this fraction of its area is
# already covered by higher-scoring masks.
_OVERLAP_SKIP_RATIO = 0.15

# Extracts patient id and the two sequence numbers from names such as
# "a4c_PatientD0062_a4c_93.jsonD_116" -> ("PatientD0062", "93", "116").
_IMAGE_NAME_RE = re.compile(r'(Patient[A-Za-z0-9]+).*?(\d+)\.json[A-Z]_(\d+)')


def _random_color():
    """Return a random RGB colour as a list of three ints in [0, 255)."""
    return [np.random.randint(255) for _ in range(3)]


def _label_dtype(num_labels):
    """Return an integer dtype able to hold label ids 0..num_labels.

    uint8 is used whenever it is sufficient; a wider type is chosen only when
    uint8 would overflow.
    """
    return np.uint8 if num_labels <= np.iinfo(np.uint8).max else np.int32


def _blend(rgb, image):
    """Return the 50/50 blend of a colour overlay and the image as uint8."""
    return (rgb / 2 + image / 2).astype(np.uint8)


def _read_image(image_path):
    """Read an image as RGB and scale it so it fits in a 1024x1024 box.

    Returns:
        (image, resize_factor): the resized RGB image and the scale applied.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
        ValueError: if OpenCV cannot decode the file.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图像文件不存在: {image_path}")

    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"无法读取图像: {image_path}")
    img = img[..., ::-1]  # BGR -> RGB

    height, width = img.shape[:2]
    resize_factor = min(_MAX_IMAGE_SIDE / width, _MAX_IMAGE_SIDE / height)
    img = cv2.resize(img, (int(width * resize_factor), int(height * resize_factor)))
    return img, resize_factor


def _sample_points(mask, num_points):
    """Sample ``num_points`` random (x, y) prompts from the mask foreground.

    Sampling is with replacement. Returns an array of shape
    ``(num_points, 1, 2)``, or an empty array when nothing can be sampled.
    """
    coords = np.argwhere(mask > 0)  # rows are (y, x)
    if len(coords) == 0:
        print("警告:掩码为空,无法采样点")
        return np.array([])
    if num_points <= 0:
        return np.array([])

    picked = coords[np.random.randint(len(coords), size=num_points)]
    # Swap (y, x) -> (x, y) and add the per-prompt axis.
    return np.ascontiguousarray(picked[:, ::-1][:, np.newaxis, :])


def _load_masks_and_points(mask_paths, image_shape, num_samples_per_mask):
    """Load every readable mask and sample point prompts inside it.

    Missing or unreadable files are skipped with a warning. Each mask is
    resized to exactly the image's height/width so it can index image-sized
    arrays regardless of rounding or of the mask's original size.

    Returns:
        (masks, points, labels): parallel lists; ``points``/``labels`` only
        contain entries for masks from which points could be sampled.
    """
    height, width = image_shape[:2]
    masks, points, labels = [], [], []

    for mask_path in mask_paths:
        if not os.path.exists(mask_path):
            print(f"警告:掩码文件不存在: {mask_path},将跳过")
            continue

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"警告:无法读取掩码: {mask_path},将跳过")
            continue

        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        masks.append(mask)

        sampled = _sample_points(mask, num_samples_per_mask)
        if len(sampled) > 0:
            points.append(sampled)
            labels.append(np.ones((len(sampled), 1)))  # all prompts are positive

    return masks, points, labels


def _render_ground_truth(masks, image, colors):
    """Build the ground-truth label map, its RGB rendering and blended view.

    Mask ``i`` (0-based) gets label ``i + 1``; later masks overwrite earlier
    ones where they overlap. Colours are stored in ``colors`` so predictions
    can reuse them.
    """
    gt_map = np.zeros(image.shape[:2], dtype=_label_dtype(len(masks)))
    gt_rgb = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)

    for label, mask in enumerate(masks, start=1):
        if label not in colors:
            colors[label] = _random_color()
        foreground = mask > 0
        gt_map[foreground] = label
        gt_rgb[foreground] = colors[label]

    return {"map": gt_map, "rgb": gt_rgb, "blended": _blend(gt_rgb, image)}


def _as_numpy(value):
    """Return ``value`` as a NumPy array, moving torch tensors to the CPU."""
    if isinstance(value, torch.Tensor):
        return value.cpu().numpy()
    return value


def _load_fine_tuned_weights(model, model_path, device):
    """Load a fine-tuned state dict from ``model_path`` into ``model``."""
    try:
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
    except (TypeError, pickle.UnpicklingError):
        # TypeError: this torch version has no ``weights_only`` argument.
        # UnpicklingError: the checkpoint holds objects the safe loader rejects.
        # NOTE(security): this fallback performs full unpickling and can run
        # arbitrary code -- only use checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)


def _predict_masks(model_key, model_path, image, input_points, input_labels,
                   model_cfg, sam2_checkpoint, device):
    """Run one model on the point prompts and return ``(masks, scores)``.

    The base checkpoint is always built first; for ``model_key ==
    "fine_tuned"`` the weights at ``model_path`` are loaded on top of it.
    When the predictor returns several candidates per prompt, only the first
    candidate (and its score) is kept.
    """
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2_model)

    if model_key == "fine_tuned":
        _load_fine_tuned_weights(predictor.model, model_path, device)
        print(f"已加载微调模型: {model_path}")

    predictor.model.eval()

    # Inference only: gradients are disabled (no autocast is applied).
    with torch.no_grad():
        predictor.set_image(image)  # image encoder

        print(f"{model_key}模型正在进行预测...")
        print("输入点形状:", input_points.shape)

        # prompt encoder + mask decoder
        masks, scores, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=True,
        )

    print(f"{model_key}模型预测完成")
    print(f"预测掩码数量: {masks.shape[0] if hasattr(masks, 'shape') else '未知'}")

    np_masks = _as_numpy(masks)
    if np_masks.ndim > 3:
        np_masks = np_masks[:, 0]

    np_scores = _as_numpy(scores)
    if np_scores.ndim > 1:
        np_scores = np_scores[:, 0]

    return np_masks, np_scores


def _build_segmentation_map(masks, scores, shape):
    """Merge predicted masks into one label map, best score first.

    Masks are visited in descending score order and labelled by rank
    (1 = best). A non-empty mask that overlaps already-claimed pixels by more
    than ``_OVERLAP_SKIP_RATIO`` of its own area is dropped; otherwise it
    claims its still-unclaimed pixels.
    """
    order = np.argsort(scores)[::-1]
    seg_map = np.zeros(shape[:2], dtype=_label_dtype(len(order)))
    occupied = np.zeros(shape[:2], dtype=bool)

    for rank, index in enumerate(order, start=1):
        mask_bool = masks[index].astype(bool)
        area = mask_bool.sum()
        overlap_ratio = np.logical_and(mask_bool, occupied).sum() / (area + 1e-6)

        if area > 0 and overlap_ratio > _OVERLAP_SKIP_RATIO:
            continue

        seg_map[np.logical_and(mask_bool, np.logical_not(occupied))] = rank
        occupied[mask_bool] = True

    return seg_map


def _colorize(seg_map, colors):
    """Render a label map as RGB, reusing and extending the ``colors`` table.

    Colours created here are stored in ``colors`` so every model rendered in
    the same run uses the same colour for the same label id.
    """
    rgb = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)
    for label in range(1, int(seg_map.max()) + 1):
        if label not in colors:
            colors[label] = _random_color()
        rgb[seg_map == label] = colors[label]
    return rgb


def _difference_map(gt_map, seg_map):
    """Colour-code foreground agreement between ground truth and prediction.

    Green: foreground in both. Red: ground-truth foreground that was missed.
    Blue: predicted foreground that is background in the ground truth.
    """
    gt_fg = gt_map > 0
    pred_fg = seg_map > 0

    difference = np.zeros((gt_map.shape[0], gt_map.shape[1], 3), dtype=np.uint8)
    difference[np.logical_and(gt_fg, pred_fg)] = [0, 255, 0]
    difference[np.logical_and(gt_fg, np.logical_not(pred_fg))] = [255, 0, 0]
    difference[np.logical_and(np.logical_not(gt_fg), pred_fg)] = [0, 0, 255]
    return difference


def _build_comparison(results):
    """Return the labelled side-by-side image of all blended views."""
    panels = [("Ground Truth", results["ground_truth"]["blended"])]
    if "original" in results:
        panels.append(("Original Model", results["original"]["blended"]))
    if "fine_tuned" in results:
        panels.append(("Fine-tuned Model", results["fine_tuned"]["blended"]))

    comparison = np.ascontiguousarray(np.hstack([panel for _, panel in panels]))
    col_width = panels[0][1].shape[1]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for col_idx, (title, _) in enumerate(panels):
        cv2.putText(comparison, title, (col_idx * col_width + 10, 30),
                    font, 1, (255, 255, 255), 2)
    return comparison


def _save_results(results, comparison, output_dir):
    """Write every rendered view and the comparison image to ``output_dir``."""
    os.makedirs(output_dir, exist_ok=True)

    def write_rgb(filename, rgb):
        cv2.imwrite(os.path.join(output_dir, filename),
                    cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

    write_rgb("gt_segmentation.png", results["ground_truth"]["rgb"])
    write_rgb("gt_blended.png", results["ground_truth"]["blended"])

    for model_key in ("original", "fine_tuned"):
        if model_key not in results:
            continue
        write_rgb(f"{model_key}_segmentation.png", results[model_key]["rgb"])
        write_rgb(f"{model_key}_blended.png", results[model_key]["blended"])
        write_rgb(f"{model_key}_difference.png", results[model_key]["difference"])

    write_rgb("model_comparison.png", comparison)
    print(f"所有结果已保存到 {output_dir} 目录")


def inference_sam2_compare_models(
    image_path,
    mask_paths,
    model_cfg,
    sam2_checkpoint,
    fine_tuned_path=None,
    num_samples_per_mask=10,
    output_dir=None,
    show_results=False,
    device="cuda"
):
    """Segment one image with the base and (optionally) fine-tuned SAM2 model.

    Point prompts are sampled inside each ground-truth mask and fed to both
    models; the predictions are rendered next to the ground truth.

    Args:
        image_path: path of the image to segment.
        mask_paths: paths of the ground-truth mask images for that image.
        model_cfg: SAM2 model config passed to ``build_sam2``.
        sam2_checkpoint: base checkpoint passed to ``build_sam2``.
        fine_tuned_path: state dict loaded on top of the base model; the
            fine-tuned run is skipped when ``None``.
        num_samples_per_mask: number of point prompts sampled per mask.
        output_dir: when set, all rendered images are written there.
        show_results: when true and both models ran, show the comparison
            image in an OpenCV window.
        device: device string passed to ``build_sam2``.

    Returns:
        dict with key ``"ground_truth"`` (``map``, ``rgb``, ``blended``) and
        one key per model that ran, ``"original"`` / ``"fine_tuned"``
        (``map``, ``rgb``, ``blended``, ``difference``).

    Raises:
        FileNotFoundError: if the image does not exist.
        ValueError: if the image cannot be decoded or no points could be
            sampled from any mask.
    """
    image, _ = _read_image(image_path)

    all_masks, all_points, all_point_labels = _load_masks_and_points(
        mask_paths, image.shape, num_samples_per_mask
    )
    if not all_points:
        raise ValueError("所有掩码都无法采样点,无法进行分割")

    input_points = np.concatenate(all_points, axis=0)
    input_labels = np.concatenate(all_point_labels, axis=0)

    colors = {}  # label id -> RGB colour, shared by every rendering
    results = {"ground_truth": _render_ground_truth(all_masks, image, colors)}
    ground_truth_map = results["ground_truth"]["map"]

    for model_key, model_path in (
        ("original", sam2_checkpoint),
        ("fine_tuned", fine_tuned_path),
    ):
        if model_key == "fine_tuned" and model_path is None:
            continue

        print(f"\n使用{model_key}模型进行推理...")

        np_masks, np_scores = _predict_masks(
            model_key, model_path, image, input_points, input_labels,
            model_cfg, sam2_checkpoint, device,
        )

        seg_map = _build_segmentation_map(np_masks, np_scores, image.shape)
        rgb_image = _colorize(seg_map, colors)

        results[model_key] = {
            "map": seg_map,
            "rgb": rgb_image,
            "blended": _blend(rgb_image, image),
            "difference": _difference_map(ground_truth_map, seg_map),
        }

    # Built unconditionally: it is needed both for saving and for display.
    comparison = _build_comparison(results)

    if output_dir:
        _save_results(results, comparison, output_dir)

    # Display (only possible in an environment with a GUI).
    if show_results:
        try:
            if "original" in results and "fine_tuned" in results:
                cv2.imshow("模型对比", cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))
                print("按任意键关闭窗口...")
                cv2.waitKey(0)
                cv2.destroyAllWindows()
        except cv2.error as e:
            print(f"无法显示图像,可能是在无GUI环境中运行: {e}")
            print("结果已保存到指定目录")

    return results


def find_related_masks_advanced(image_path, masks_dir):
    """Find the mask files in ``masks_dir`` that belong to ``image_path``.

    Two strategies are combined:

    1. every ``.png`` whose name starts with the image's base name;
    2. when the image name looks like
       ``a4c_PatientD0062_a4c_93.jsonD_116``, every ``.png`` whose name
       contains the patient id and both sequence numbers in that order.

    Returns:
        Sorted list of unique mask paths (sorted so that label ids and
        colours derived from the order are reproducible between runs).
    """
    image_filename = os.path.basename(image_path)
    image_name, _ = os.path.splitext(image_filename)

    # Escape the name so characters such as "[" are matched literally.
    prefix_pattern = os.path.join(masks_dir, f"{glob.escape(image_name)}*.png")
    found = set(glob.glob(prefix_pattern))

    matches = _IMAGE_NAME_RE.search(image_name)
    if matches:
        patient_id, seq_num1, seq_num2 = matches.groups()
        loose_pattern = os.path.join(
            masks_dir, f"*{patient_id}*{seq_num1}*{seq_num2}*.png"
        )
        found.update(glob.glob(loose_pattern))

    all_masks = sorted(found)

    print(f"为图像 {image_filename} 找到 {len(all_masks)} 个相关掩码:")
    for mask in all_masks:
        print(f" - {os.path.basename(mask)}")

    return all_masks


def process_image_by_name(data_dir, image_filename, model_cfg, sam2_checkpoint, fine_tuned_path=None, output_dir=None):
    """Process one image by file name, locating its masks automatically.

    Args:
        data_dir: dataset root containing ``JPEGImages`` and ``Annotations``.
        image_filename: image file name (without directory).
        model_cfg: SAM2 model config path.
        sam2_checkpoint: base model checkpoint path.
        fine_tuned_path: fine-tuned weights path, or ``None``.
        output_dir: output directory; defaults to ``results/<image name>``.

    Returns:
        The result dict of ``inference_sam2_compare_models``, or ``None``
        when the image or its masks cannot be found.
    """
    images_dir = os.path.join(data_dir, "JPEGImages")
    masks_dir = os.path.join(data_dir, "Annotations")

    image_path = os.path.join(images_dir, image_filename)
    if not os.path.exists(image_path):
        print(f"图像文件不存在: {image_path}")
        return None

    mask_paths = find_related_masks_advanced(image_path, masks_dir)
    if not mask_paths:
        print("未找到相关掩码文件")
        return None

    # Default output location: a sub-directory named after the image.
    if output_dir is None:
        image_name, _ = os.path.splitext(image_filename)
        output_dir = os.path.join("results", image_name)

    return inference_sam2_compare_models(
        image_path=image_path,
        mask_paths=mask_paths,
        model_cfg=model_cfg,
        sam2_checkpoint=sam2_checkpoint,
        fine_tuned_path=fine_tuned_path,
        output_dir=output_dir,
        show_results=False
    )


def _foreground_metrics(gt_map, model_map):
    """Return ``(precision, recall, f1)`` of foreground-vs-background pixels.

    Any non-zero label counts as foreground; label ids are not compared.
    A metric with an empty denominator is reported as 0.
    """
    gt_fg = gt_map > 0
    pred_fg = model_map > 0

    true_positive = np.sum(np.logical_and(gt_fg, pred_fg))
    false_negative = np.sum(np.logical_and(gt_fg, np.logical_not(pred_fg)))
    false_positive = np.sum(np.logical_and(np.logical_not(gt_fg), pred_fg))

    predicted = true_positive + false_positive
    actual = true_positive + false_negative
    precision = true_positive / predicted if predicted > 0 else 0
    recall = true_positive / actual if actual > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1


def main():
    """Run the comparison on the hard-coded sample image and print metrics."""
    # Base paths
    data_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train"
    images_dir = os.path.join(data_dir, "JPEGImages")
    masks_dir = os.path.join(data_dir, "Annotations")

    # Model configuration
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
    fine_tuned_path = "models/model_final.torch"

    # Image to process
    image_filename = "a4c_PatientD0062_a4c_93.jsonD_116.jpg"
    image_path = os.path.join(images_dir, image_filename)

    print("检查图像文件...")
    if not os.path.exists(image_path):
        print(f"图像文件不存在: {image_path}")
        sys.exit(1)

    mask_paths = find_related_masks_advanced(image_path, masks_dir)

    valid_masks = []
    for mask_path in mask_paths:
        if os.path.exists(mask_path):
            valid_masks.append(mask_path)
            print(f"掩码文件存在: {mask_path}")
        else:
            print(f"掩码文件不存在: {mask_path}")

    if not valid_masks:
        print("没有有效的掩码文件")
        sys.exit(1)

    # Report (but do not enforce) the presence of the model files.
    for path, desc in [
        (model_cfg, "模型配置文件"),
        (sam2_checkpoint, "预训练检查点"),
        (fine_tuned_path, "微调模型")
    ]:
        exists = os.path.exists(path)
        print(f"{desc}: {path} {'存在' if exists else '不存在'}")

    output_dir = os.path.join("results_comparison", os.path.splitext(image_filename)[0])

    try:
        print("\n开始模型对比推理...")
        results = inference_sam2_compare_models(
            image_path=image_path,
            mask_paths=valid_masks,
            model_cfg=model_cfg,
            sam2_checkpoint=sam2_checkpoint,
            fine_tuned_path=fine_tuned_path,
            num_samples_per_mask=10,
            output_dir=output_dir,
            show_results=False
        )
        print("对比推理完成!")

        # Print simple foreground metrics for each model.
        if "original" in results and "fine_tuned" in results:
            gt_map = results["ground_truth"]["map"]

            for model_key in ["original", "fine_tuned"]:
                precision, recall, f1 = _foreground_metrics(
                    gt_map, results[model_key]["map"]
                )

                print(f"\n{model_key.capitalize()}模型性能:")
                print(f"精确率(Precision): {precision:.4f}")
                print(f"召回率(Recall): {recall:.4f}")
                print(f"F1分数: {f1:.4f}")

    except Exception as e:  # top-level boundary: report and show the traceback
        print(f"推理过程中出错: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()