Diff of /testing.py [000000] .. [d255cc]

import os
import cv2
import numpy as np
import torch
from torch.cuda.amp import autocast

import re
import glob

# Import SAM2 modules
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def inference_sam2_compare_models(
    image_path,
    mask_paths,
    model_cfg,
    sam2_checkpoint,
    fine_tuned_path=None,
    num_samples_per_mask=10,
    output_dir=None,
    show_results=False,
    device="cuda"
):
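    """
    Compare the original and fine-tuned SAM2 models on one image.

    Points are sampled inside every ground-truth mask and passed to both models
    as positive point prompts. Returns a dict with a "ground_truth" entry and,
    when the corresponding checkpoint is available, "original" and "fine_tuned"
    entries, each holding "map", "rgb" and "blended" arrays (plus a
    "difference" visualization for the two models).
    """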

    def read_image(image_path):
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file does not exist: {image_path}")

        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Unable to read image: {image_path}")
        img = img[..., ::-1]  # BGR to RGB

        # Resize the image so that its longest side is at most 1024 pixels
        resize_factor = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
        img = cv2.resize(img, (int(img.shape[1] * resize_factor), int(img.shape[0] * resize_factor)))
        return img, resize_factor

    # Sample points inside the input mask
    def get_points(mask, num_points):
        points = []
        coords = np.argwhere(mask > 0)

        if len(coords) == 0:
            print("Warning: mask is empty, cannot sample points")
            return np.array([])

        for _ in range(num_points):
            if len(coords) > 0:
                idx = np.random.randint(len(coords))
                yx = coords[idx]
                points.append([[yx[1], yx[0]]])

        return np.array(points) if points else np.array([])
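    # Note: each point is wrapped in an extra list, so the returned array has
    # shape (num_points, 1, 2) in (x, y) order, matching the per-point label
    # arrays of shape (num_points, 1) built below.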

    # Read the image
    image, resize_factor = read_image(image_path)

    # Read all masks and sample points
    all_points = []
    all_point_labels = []
    all_masks = []

    for i, mask_path in enumerate(mask_paths):
        if not os.path.exists(mask_path):
            print(f"Warning: mask file does not exist: {mask_path}, skipping")
            continue

        mask = cv2.imread(mask_path, 0)  # Read the mask as grayscale
        if mask is None:
            print(f"Warning: unable to read mask: {mask_path}, skipping")
            continue

        # Resize the mask to match the image, using the same resize_factor
        mask = cv2.resize(mask, (int(mask.shape[1] * resize_factor), int(mask.shape[0] * resize_factor)),
                          interpolation=cv2.INTER_NEAREST)

        all_masks.append(mask)

        # Sample points for this mask
        points = get_points(mask, num_samples_per_mask)
        if len(points) > 0:
            all_points.append(points)
            all_point_labels.append(np.ones((len(points), 1)))
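            # (label 1 = positive / foreground prompt for every sampled point)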

    if not all_points:
        raise ValueError("No points could be sampled from any mask; cannot run segmentation")

    # Merge the points and labels from all masks
    input_points = np.concatenate(all_points, axis=0)
    input_labels = np.concatenate(all_point_labels, axis=0)

    # Build the ground-truth visualizations
    ground_truth_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    ground_truth_rgb = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)

    # Assign a unique color and ID to each mask
    colors = {}

    for i, mask in enumerate(all_masks):
        # Generate a random color, reused across all visualizations
        if i + 1 not in colors:
            colors[i + 1] = [
                np.random.randint(255),
                np.random.randint(255),
                np.random.randint(255)
            ]

        # Update the label map and the RGB map
        ground_truth_map[mask > 0] = i + 1
        ground_truth_rgb[mask > 0] = colors[i + 1]

    # Blend the ground-truth overlay with the image
    gt_blended_image = (ground_truth_rgb / 2 + image / 2).astype(np.uint8)

    # Initialize the results dictionary
    results = {
        "ground_truth": {
            "map": ground_truth_map,
            "rgb": ground_truth_rgb,
            "blended": gt_blended_image
        }
    }

    # Run inference with both models
    for model_key, model_path in [
        ("original", sam2_checkpoint),
        ("fine_tuned", fine_tuned_path)
    ]:
        # Skip the fine-tuned model if no path was given
        if model_key == "fine_tuned" and model_path is None:
            continue

        print(f"\nRunning inference with the {model_key} model...")

        # Initialize the SAM2 model from the base checkpoint
        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
        predictor = SAM2ImagePredictor(sam2_model)

        # For the fine-tuned model, load the fine-tuned weights on top
        if model_key == "fine_tuned":
            try:
                predictor.model.load_state_dict(torch.load(model_path, weights_only=True))
            except Exception:
                # Fall back for older torch versions (no weights_only argument)
                # or checkpoints that cannot be loaded with weights_only=True
                predictor.model.load_state_dict(torch.load(model_path))
            print(f"Loaded fine-tuned model: {model_path}")

        # Put the model in evaluation mode
        predictor.model.eval()

        # Run inference without gradient tracking (mixed precision via autocast
        # is imported above but not applied here)
        with torch.no_grad():
            # Image encoder
            predictor.set_image(image)

            print(f"{model_key} model is predicting...")
            print("Input points shape:", input_points.shape)

            # Prompt encoder + mask decoder
            masks, scores, _ = predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=True
            )

            print(f"{model_key} model prediction finished")
            print(f"Number of predicted masks: {masks.shape[0] if hasattr(masks, 'shape') else 'unknown'}")

        # Convert masks and scores to numpy and keep only the first mask channel
        if isinstance(masks, torch.Tensor):
            np_masks = masks.cpu().numpy()
            if np_masks.ndim > 3:
                np_masks = np_masks[:, 0]
        else:
            np_masks = masks
            if np_masks.ndim > 3:
                np_masks = np_masks[:, 0]

        if isinstance(scores, torch.Tensor):
            np_scores = scores.cpu().numpy()
            if np_scores.ndim > 1:
                np_scores = np_scores[:, 0]
        else:
            np_scores = scores
            if np_scores.ndim > 1:
                np_scores = np_scores[:, 0]

        # Sort the masks by score (descending)
        sorted_indices = np.argsort(np_scores)[::-1]
        sorted_masks = np_masks[sorted_indices]
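        # Higher-scoring masks are placed first, so they claim pixels in the
        # occupancy mask before lower-scoring, overlapping predictions.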

        # Build the segmentation map
        seg_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
        occupancy_mask = np.zeros((image.shape[0], image.shape[1]), dtype=bool)

        for i, mask in enumerate(sorted_masks):
            # Make sure the mask is boolean
            mask_bool = mask.astype(bool)

            # Fraction of this mask that overlaps already-claimed pixels
            # (logical AND rather than bitwise AND)
            overlap_ratio = np.logical_and(mask_bool, occupancy_mask).sum() / (mask_bool.sum() + 1e-6)

            if mask_bool.sum() > 0 and overlap_ratio > 0.15:
                continue

            # Update the segmentation map and the occupancy mask
            non_overlap = np.logical_and(mask_bool, np.logical_not(occupancy_mask))
            seg_map[non_overlap] = i + 1
            occupancy_mask[mask_bool] = True

        # Build the visualization image (reusing the ground-truth colors)
        rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)
        for id_class in range(1, seg_map.max() + 1):
            # Reuse the ground-truth color when one exists for this ID
            color = colors.get(id_class, [
                np.random.randint(255),
                np.random.randint(255),
                np.random.randint(255)
            ])
            rgb_image[seg_map == id_class] = color

        # Blend the prediction overlay with the image
        blended_image = (rgb_image / 2 + image / 2).astype(np.uint8)

        # Build a difference map (prediction vs. ground truth)
        difference = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)

        # Green: correctly predicted regions
        correct = np.logical_and(ground_truth_map > 0, seg_map > 0)
        difference[correct] = [0, 255, 0]

        # Red: regions that were missed (false negatives)
        missed = np.logical_and(ground_truth_map > 0, seg_map == 0)
        difference[missed] = [255, 0, 0]

        # Blue: wrongly predicted regions (false positives)
        false_positive = np.logical_and(ground_truth_map == 0, seg_map > 0)
        difference[false_positive] = [0, 0, 255]

        # Store everything in the results dictionary
        results[model_key] = {
            "map": seg_map,
            "rgb": rgb_image,
            "blended": blended_image,
            "difference": difference
        }

    # Save results
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

        # Save the ground-truth visualizations
        cv2.imwrite(os.path.join(output_dir, "gt_segmentation.png"),
                    cv2.cvtColor(results["ground_truth"]["rgb"], cv2.COLOR_RGB2BGR))
        cv2.imwrite(os.path.join(output_dir, "gt_blended.png"),
                    cv2.cvtColor(results["ground_truth"]["blended"], cv2.COLOR_RGB2BGR))

        # Save the original model results
        if "original" in results:
            cv2.imwrite(os.path.join(output_dir, "original_segmentation.png"),
                        cv2.cvtColor(results["original"]["rgb"], cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(output_dir, "original_blended.png"),
                        cv2.cvtColor(results["original"]["blended"], cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(output_dir, "original_difference.png"),
                        cv2.cvtColor(results["original"]["difference"], cv2.COLOR_RGB2BGR))

        # Save the fine-tuned model results
        if "fine_tuned" in results:
            cv2.imwrite(os.path.join(output_dir, "fine_tuned_segmentation.png"),
                        cv2.cvtColor(results["fine_tuned"]["rgb"], cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(output_dir, "fine_tuned_blended.png"),
                        cv2.cvtColor(results["fine_tuned"]["blended"], cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(output_dir, "fine_tuned_difference.png"),
                        cv2.cvtColor(results["fine_tuned"]["difference"], cv2.COLOR_RGB2BGR))

        # Build a side-by-side comparison image: one column per model
        col_width = image.shape[1]
        models_count = 1 + ("original" in results) + ("fine_tuned" in results)

        comparison = np.zeros((image.shape[0], col_width * models_count, 3), dtype=np.uint8)

        # Add the ground truth
        col_idx = 0
        comparison[:, col_idx*col_width:(col_idx+1)*col_width] = results["ground_truth"]["blended"]
        col_idx += 1

        # Add the original model result
        if "original" in results:
            comparison[:, col_idx*col_width:(col_idx+1)*col_width] = results["original"]["blended"]
            col_idx += 1

        # Add the fine-tuned model result
        if "fine_tuned" in results:
            comparison[:, col_idx*col_width:(col_idx+1)*col_width] = results["fine_tuned"]["blended"]

        # Add the column labels
        font = cv2.FONT_HERSHEY_SIMPLEX

        col_idx = 0
        cv2.putText(comparison, "Ground Truth", (col_idx*col_width + 10, 30), font, 1, (255, 255, 255), 2)
        col_idx += 1

        if "original" in results:
            cv2.putText(comparison, "Original Model", (col_idx*col_width + 10, 30), font, 1, (255, 255, 255), 2)
            col_idx += 1

        if "fine_tuned" in results:
            cv2.putText(comparison, "Fine-tuned Model", (col_idx*col_width + 10, 30), font, 1, (255, 255, 255), 2)

        cv2.imwrite(os.path.join(output_dir, "model_comparison.png"),
                    cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))

        print(f"All results have been saved to {output_dir}")

    # Show results (only in an environment with a GUI, and only when the
    # comparison image was built above, i.e. when output_dir was given)
    if show_results and output_dir:
        try:
            if "original" in results and "fine_tuned" in results:
                cv2.imshow("Model comparison", cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))
            print("Press any key to close the window...")
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        except Exception as e:
            print(f"Unable to display images, possibly running headless: {e}")
            print("Results have been saved to the output directory")

    return results
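

# Optional helper, not part of the original pipeline: a minimal sketch showing
# how the results dict returned above can be consumed. It computes a pixel-level
# foreground IoU per model, analogous to the precision/recall/F1 printed in the
# __main__ block below.
def foreground_iou(results):
    """Return {model_key: IoU}, treating every labeled pixel as foreground."""
    gt = results["ground_truth"]["map"] > 0
    ious = {}
    for key in ("original", "fine_tuned"):
        if key in results:
            pred = results[key]["map"] > 0
            union = np.logical_or(gt, pred).sum()
            ious[key] = float(np.logical_and(gt, pred).sum() / union) if union > 0 else 0.0
    return ious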


def find_related_masks_advanced(image_path, masks_dir):
    """
    Find the masks related to an image using a more flexible matching strategy.
    """

    # Get the image file name (without extension)
    image_filename = os.path.basename(image_path)
    image_name, _ = os.path.splitext(image_filename)

    # Basic case: the mask name starts with the image name
    pattern1 = os.path.join(masks_dir, f"{image_name}*.png")
    masks1 = glob.glob(pattern1)

    # Extract the patient ID and sequence numbers from the image name for a
    # more complex match, e.g. a4c_PatientD0062_a4c_93.jsonD_116.jpg
    # yields: PatientD0062, 93, 116
    matches = re.search(r'(Patient[A-Za-z0-9]+).*?(\d+)\.json[A-Z]_(\d+)', image_name)

    if matches:
        patient_id = matches.group(1)
        seq_num1 = matches.group(2)
        seq_num2 = matches.group(3)

        # Use the extracted information to try another matching pattern
        pattern2 = os.path.join(masks_dir, f"*{patient_id}*{seq_num1}*{seq_num2}*.png")
        masks2 = glob.glob(pattern2)

        # Merge and deduplicate
        all_masks = list(set(masks1 + masks2))
    else:
        all_masks = masks1

    print(f"Found {len(all_masks)} related masks for image {image_filename}:")
    for mask in all_masks:
        print(f"  - {os.path.basename(mask)}")

    return all_masks


def process_image_by_name(data_dir, image_filename, model_cfg, sam2_checkpoint, fine_tuned_path=None, output_dir=None):
    """
    Process a single image by file name, automatically finding its related masks.

    Args:
        data_dir: data root directory
        image_filename: image file name (without path)
        model_cfg: path to the model config file
        sam2_checkpoint: path to the original model checkpoint
        fine_tuned_path: path to the fine-tuned model weights
        output_dir: output directory
    """
    images_dir = os.path.join(data_dir, "JPEGImages")
    masks_dir = os.path.join(data_dir, "Annotations")

    # Build the full image path
    image_path = os.path.join(images_dir, image_filename)

    if not os.path.exists(image_path):
        print(f"Image file does not exist: {image_path}")
        return

    # Automatically find the related masks
    mask_paths = find_related_masks_advanced(image_path, masks_dir)

    if not mask_paths:
        print("No related mask files found")
        return

    # Set the output directory (using the image name as a subdirectory)
    if output_dir is None:
        image_name, _ = os.path.splitext(image_filename)
        output_dir = os.path.join("results", image_name)

    # Run the model comparison
    results = inference_sam2_compare_models(
        image_path=image_path,
        mask_paths=mask_paths,
        model_cfg=model_cfg,
        sam2_checkpoint=sam2_checkpoint,
        fine_tuned_path=fine_tuned_path,
        output_dir=output_dir,
        show_results=False
    )

    return results
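

# Example call (values mirror the __main__ block below; adjust for your data):
#
#   results = process_image_by_name(
#       data_dir="/media/ps/data/zhy/Sam2_new/sam2/data_train",
#       image_filename="a4c_PatientD0062_a4c_93.jsonD_116.jpg",
#       model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
#       sam2_checkpoint="checkpoints/sam2.1_hiera_large.pt",
#       fine_tuned_path="models/model_final.torch",
#   )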


if __name__ == "__main__":
    # Basic paths
    data_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train"
    images_dir = os.path.join(data_dir, "JPEGImages")
    masks_dir = os.path.join(data_dir, "Annotations")

    # Model configuration
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
    fine_tuned_path = "models/model_final.torch"

    # Target image file
    image_filename = "a4c_PatientD0062_a4c_93.jsonD_116.jpg"
    image_path = os.path.join(images_dir, image_filename)

    print("Checking image file...")
    if not os.path.exists(image_path):
        print(f"Image file does not exist: {image_path}")
        exit(1)

    # Automatically find the related masks
    mask_paths = find_related_masks_advanced(image_path, masks_dir)

    # Check the mask files
    valid_masks = []
    for mask_path in mask_paths:
        if os.path.exists(mask_path):
            valid_masks.append(mask_path)
            print(f"Mask file exists: {mask_path}")
        else:
            print(f"Mask file does not exist: {mask_path}")

    if not valid_masks:
        print("No valid mask files found")
        exit(1)

    # Check the remaining files
    for path, desc in [
        (model_cfg, "model config file"),
        (sam2_checkpoint, "pretrained checkpoint"),
        (fine_tuned_path, "fine-tuned model")
    ]:
        exists = os.path.exists(path)
        print(f"{desc}: {path} {'exists' if exists else 'does not exist'}")

    # Set the output directory
    output_dir = os.path.join("results_comparison", os.path.splitext(image_filename)[0])

    try:
        # Run inference
        print("\nStarting model comparison inference...")
        results = inference_sam2_compare_models(
            image_path=image_path,
            mask_paths=valid_masks,
            model_cfg=model_cfg,
            sam2_checkpoint=sam2_checkpoint,
            fine_tuned_path=fine_tuned_path,
            num_samples_per_mask=10,
            output_dir=output_dir,
            show_results=False
        )
        print("Comparison inference finished!")

        # Report simple performance metrics
        if "original" in results and "fine_tuned" in results:
            gt_map = results["ground_truth"]["map"]

            for model_key in ["original", "fine_tuned"]:
                model_map = results[model_key]["map"]

                # Basic pixel counts
                true_positive = np.sum(np.logical_and(gt_map > 0, model_map > 0))
                false_negative = np.sum(np.logical_and(gt_map > 0, model_map == 0))
                false_positive = np.sum(np.logical_and(gt_map == 0, model_map > 0))

                # Performance metrics
                precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
                recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
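                # Note: these are binary foreground/background pixel metrics;
                # which instance ID a pixel belongs to is ignored.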

                print(f"\n{model_key.capitalize()} model performance:")
                print(f"Precision: {precision:.4f}")
                print(f"Recall: {recall:.4f}")
                print(f"F1 score: {f1:.4f}")

    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()