|
a |
|
b/testing.py |
|
|
1 |
import os |
|
|
2 |
import cv2 |
|
|
3 |
import numpy as np |
|
|
4 |
import torch |
|
|
5 |
from torch.cuda.amp import autocast |
|
|
6 |
|
|
|
7 |
import re |
|
|
8 |
import glob |
|
|
9 |
|
|
|
10 |
# 导入SAM2相关模块 |
|
|
11 |
from sam2.build_sam import build_sam2 |
|
|
12 |
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def _load_rgb_image(image_path, max_side=1024):
    """Read an image from disk as RGB and rescale it to fit ``max_side``.

    The scale factor is ``min(max_side / width, max_side / height)``, so the
    longer side ends up at (approximately, because of integer truncation)
    ``max_side`` pixels. Smaller images are enlarged as well.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
        ValueError: if OpenCV cannot decode the file.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图像文件不存在: {image_path}")

    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"无法读取图像: {image_path}")
    img = img[..., ::-1]  # BGR -> RGB

    resize_factor = np.min([max_side / img.shape[1], max_side / img.shape[0]])
    new_size = (int(img.shape[1] * resize_factor), int(img.shape[0] * resize_factor))
    return cv2.resize(img, new_size)


def _sample_prompt_points(mask, num_points):
    """Draw ``num_points`` random foreground pixels of ``mask`` (with replacement).

    Returns an integer array of shape ``(num_points, 1, 2)`` holding ``(x, y)``
    coordinates, or an empty array when the mask has no foreground pixel or
    ``num_points`` is not positive.
    """
    coords = np.argwhere(mask > 0)
    if len(coords) == 0:
        print("警告:掩码为空,无法采样点")
        return np.array([])
    if num_points <= 0:
        return np.array([])

    picks = np.random.randint(len(coords), size=num_points)
    # argwhere yields (row, col); reverse the last axis to get (x, y) and add
    # a singleton axis so that every point is its own prompt.
    return coords[picks][:, None, ::-1]


def _random_color():
    """Return a random RGB color as a list of three ints in [0, 255)."""
    return np.random.randint(255, size=3).tolist()


def _id_dtype(max_id):
    """Pick the smallest id-map dtype that can hold ``max_id`` without overflow."""
    return np.uint8 if max_id <= 255 else np.uint16


def _load_finetuned_state_dict(model_path, device):
    """Load fine-tuned weights, preferring the restricted ``weights_only`` loader."""
    try:
        return torch.load(model_path, map_location=device, weights_only=True)
    except Exception as load_err:
        # NOTE(security): the fallback performs a full pickle load, which can
        # execute arbitrary code. Only use checkpoints from a trusted source.
        print(f"警告:无法以 weights_only=True 加载微调权重 ({load_err}),改用完整加载方式")
        return torch.load(model_path, map_location=device)


def _first_candidate(values, max_ndim):
    """Convert predictor output to NumPy and keep the first candidate per prompt.

    When the array has more than ``max_ndim`` dimensions, index ``[:, 0]`` is
    taken (presumably the first of the multimask candidates of each prompt —
    verify against the SAM2 predictor's output layout).
    """
    arr = values.cpu().numpy() if isinstance(values, torch.Tensor) else values
    if arr.ndim > max_ndim:
        arr = arr[:, 0]
    return arr


def _merge_masks_by_score(np_masks, np_scores, height, width, max_overlap=0.15):
    """Paint masks into one id map, highest score first.

    A mask is skipped when more than ``max_overlap`` of its area is already
    occupied by better-scored masks. Ids are the 1-based score rank, so skipped
    masks leave gaps in the id sequence.
    """
    ranked_masks = np_masks[np.argsort(np_scores)[::-1]]

    seg_map = np.zeros((height, width), dtype=_id_dtype(len(ranked_masks)))
    occupied = np.zeros((height, width), dtype=bool)

    for rank, mask in enumerate(ranked_masks, start=1):
        mask_bool = mask.astype(bool)
        area = mask_bool.sum()
        overlap_ratio = np.logical_and(mask_bool, occupied).sum() / (area + 1e-6)
        if area > 0 and overlap_ratio > max_overlap:
            continue

        seg_map[np.logical_and(mask_bool, np.logical_not(occupied))] = rank
        occupied[mask_bool] = True

    return seg_map


def _difference_map(ground_truth_map, seg_map):
    """Color-code agreement between ground truth and prediction (RGB).

    Green: labelled in both; red: labelled only in the ground truth (missed);
    blue: labelled only in the prediction (false positive).
    """
    difference = np.zeros(ground_truth_map.shape + (3,), dtype=np.uint8)
    gt_fg = ground_truth_map > 0
    pred_fg = seg_map > 0
    difference[np.logical_and(gt_fg, pred_fg)] = [0, 255, 0]
    difference[np.logical_and(gt_fg, np.logical_not(pred_fg))] = [255, 0, 0]
    difference[np.logical_and(np.logical_not(gt_fg), pred_fg)] = [0, 0, 255]
    return difference


def _build_comparison(results):
    """Place the blended images side by side and label each column."""
    panels = [("Ground Truth", results["ground_truth"]["blended"])]
    if "original" in results:
        panels.append(("Original Model", results["original"]["blended"]))
    if "fine_tuned" in results:
        panels.append(("Fine-tuned Model", results["fine_tuned"]["blended"]))

    comparison = np.concatenate([panel for _, panel in panels], axis=1)
    col_width = panels[0][1].shape[1]
    font = cv2.FONT_HERSHEY_SIMPLEX
    for col_idx, (label, _) in enumerate(panels):
        cv2.putText(comparison, label, (col_idx * col_width + 10, 30), font, 1, (255, 255, 255), 2)
    return comparison


def _save_comparison_outputs(results, comparison, output_dir):
    """Write every visualization in ``results`` plus the comparison strip as PNG."""
    os.makedirs(output_dir, exist_ok=True)

    def save_rgb(filename, rgb):
        # Arrays are kept in RGB; OpenCV expects BGR when writing.
        cv2.imwrite(os.path.join(output_dir, filename), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

    save_rgb("gt_segmentation.png", results["ground_truth"]["rgb"])
    save_rgb("gt_blended.png", results["ground_truth"]["blended"])

    for model_key in ("original", "fine_tuned"):
        if model_key not in results:
            continue
        save_rgb(f"{model_key}_segmentation.png", results[model_key]["rgb"])
        save_rgb(f"{model_key}_blended.png", results[model_key]["blended"])
        save_rgb(f"{model_key}_difference.png", results[model_key]["difference"])

    save_rgb("model_comparison.png", comparison)
    print(f"所有结果已保存到 {output_dir} 目录")


def inference_sam2_compare_models(
    image_path,
    mask_paths,
    model_cfg,
    sam2_checkpoint,
    fine_tuned_path=None,
    num_samples_per_mask=10,
    output_dir=None,
    show_results=False,
    device="cuda"
):
    """Segment one image with the original and the fine-tuned SAM2 model.

    Point prompts are sampled at random inside each ground-truth mask; the
    same prompts are given to both models so that their outputs can be
    compared against the ground truth.

    Args:
        image_path: path of the input image.
        mask_paths: iterable of ground-truth mask image paths. Missing or
            unreadable files are skipped with a warning.
        model_cfg: SAM2 model config passed to ``build_sam2``.
        sam2_checkpoint: checkpoint of the original model. It is also used to
            build the network that receives the fine-tuned weights.
        fine_tuned_path: state dict of the fine-tuned model, or ``None`` to
            evaluate the original model only.
        num_samples_per_mask: number of point prompts sampled per mask.
        output_dir: directory for the PNG outputs; nothing is written if falsy.
        show_results: if true, display the side-by-side comparison in an
            OpenCV window and wait for a key press.
        device: device string passed to ``build_sam2``.

    Returns:
        dict with key ``"ground_truth"`` (``map``, ``rgb``, ``blended``) and
        one entry per evaluated model (``"original"``, ``"fine_tuned"``), each
        holding ``map``, ``rgb``, ``blended`` and ``difference`` arrays.

    Raises:
        FileNotFoundError: if the image does not exist.
        ValueError: if the image cannot be decoded or no prompt point could
            be sampled from any mask.
    """
    image = _load_rgb_image(image_path)
    height, width = image.shape[:2]

    # Load the ground-truth masks and sample prompt points inside each one.
    all_points = []
    all_point_labels = []
    all_masks = []
    for mask_path in mask_paths:
        if not os.path.exists(mask_path):
            print(f"警告:掩码文件不存在: {mask_path},将跳过")
            continue

        mask = cv2.imread(mask_path, 0)
        if mask is None:
            print(f"警告:无法读取掩码: {mask_path},将跳过")
            continue

        # Resize to the exact image size (not by a scale factor) so that the
        # mask always lines up with the image, even if it was stored at a
        # different resolution.
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        all_masks.append(mask)

        points = _sample_prompt_points(mask, num_samples_per_mask)
        if len(points) > 0:
            all_points.append(points)
            all_point_labels.append(np.ones((len(points), 1)))

    if not all_points:
        raise ValueError("所有掩码都无法采样点,无法进行分割")

    input_points = np.concatenate(all_points, axis=0)
    input_labels = np.concatenate(all_point_labels, axis=0)

    # Ground-truth visualization: one id and one random color per mask.
    ground_truth_map = np.zeros((height, width), dtype=_id_dtype(len(all_masks)))
    ground_truth_rgb = np.zeros((height, width, 3), dtype=np.uint8)
    colors = {}
    for mask_id, mask in enumerate(all_masks, start=1):
        colors[mask_id] = _random_color()
        ground_truth_map[mask > 0] = mask_id
        ground_truth_rgb[mask > 0] = colors[mask_id]

    results = {
        "ground_truth": {
            "map": ground_truth_map,
            "rgb": ground_truth_rgb,
            "blended": (ground_truth_rgb / 2 + image / 2).astype(np.uint8)
        }
    }

    model_specs = [("original", sam2_checkpoint)]
    if fine_tuned_path is not None:
        model_specs.append(("fine_tuned", fine_tuned_path))

    for model_key, model_path in model_specs:
        print(f"\n使用{model_key}模型进行推理...")

        # Both variants start from the base checkpoint; the fine-tuned one
        # then overwrites the weights with its own state dict.
        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
        predictor = SAM2ImagePredictor(sam2_model)
        if model_key == "fine_tuned":
            predictor.model.load_state_dict(_load_finetuned_state_dict(model_path, device))
            print(f"已加载微调模型: {model_path}")

        predictor.model.eval()

        with torch.no_grad():
            predictor.set_image(image)

            print(f"{model_key}模型正在进行预测...")
            print("输入点形状:", input_points.shape)

            masks, scores, _ = predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=True
            )

        print(f"{model_key}模型预测完成")
        print(f"预测掩码数量: {masks.shape[0] if hasattr(masks, 'shape') else '未知'}")

        np_masks = _first_candidate(masks, max_ndim=3)
        np_scores = _first_candidate(scores, max_ndim=1)
        seg_map = _merge_masks_by_score(np_masks, np_scores, height, width)

        # Reuse the ground-truth color when the id exists there; note that
        # prediction ids are score ranks, so equal ids do not imply the same
        # anatomical structure.
        rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
        # int(): avoid doing "max + 1" in the map's narrow integer dtype.
        for id_class in range(1, int(seg_map.max()) + 1):
            color = colors.get(id_class)
            if color is None:
                color = _random_color()
            rgb_image[seg_map == id_class] = color

        results[model_key] = {
            "map": seg_map,
            "rgb": rgb_image,
            "blended": (rgb_image / 2 + image / 2).astype(np.uint8),
            "difference": _difference_map(ground_truth_map, seg_map)
        }

    # The comparison strip is needed for saving and for display, so it is
    # built independently of whether an output directory was given.
    comparison = None
    if output_dir or show_results:
        comparison = _build_comparison(results)

    if output_dir:
        _save_comparison_outputs(results, comparison, output_dir)

    if show_results:
        try:
            cv2.imshow("模型对比", cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))
            print("按任意键关闭窗口...")
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        except Exception as e:
            # Best effort: headless environments cannot open a window.
            print(f"无法显示图像,可能是在无GUI环境中运行: {e}")
            if output_dir:
                print("结果已保存到指定目录")

    return results
|
|
335 |
|
|
|
336 |
|
|
|
337 |
def find_related_masks_advanced(image_path, masks_dir):
    """Find the PNG masks in ``masks_dir`` that belong to an image.

    Two lookups are combined:

    1. masks whose file name starts with the image's base name;
    2. if the image name looks like
       ``a4c_PatientD0062_a4c_93.jsonD_116`` (patient id, then
       ``<number>.json<letter>_<number>``), masks whose name contains the
       patient id and both numbers in that order.

    Args:
        image_path: path of the image; only its file name is used.
        masks_dir: directory that is searched (non-recursively) for ``*.png``.

    Returns:
        Sorted list of matching mask paths without duplicates. Sorting keeps
        the mask order — and with it the ids assigned downstream — stable
        between runs.
    """
    image_filename = os.path.basename(image_path)
    image_name, _ = os.path.splitext(image_filename)

    # Escape glob metacharacters so that names such as "img[1]" are matched
    # literally instead of being interpreted as a character class.
    safe_dir = glob.escape(masks_dir)
    found = set(glob.glob(os.path.join(safe_dir, f"{glob.escape(image_name)}*.png")))

    # Extract patient id and the two sequence numbers for the looser match.
    matches = re.search(r'(Patient[A-Za-z0-9]+).*?(\d+)\.json[A-Z]_(\d+)', image_name)
    if matches:
        patient_id, seq_num1, seq_num2 = matches.groups()
        # The captured groups are alphanumeric only, so they need no escaping.
        pattern = os.path.join(safe_dir, f"*{patient_id}*{seq_num1}*{seq_num2}*.png")
        found.update(glob.glob(pattern))

    all_masks = sorted(found)

    print(f"为图像 {image_filename} 找到 {len(all_masks)} 个相关掩码:")
    for mask in all_masks:
        print(f" - {os.path.basename(mask)}")

    return all_masks
|
|
374 |
|
|
|
375 |
|
|
|
376 |
# Process a single image by file name, looking up its masks automatically |
|
|
377 |
def process_image_by_name(data_dir, image_filename, model_cfg, sam2_checkpoint, fine_tuned_path=None, output_dir=None):
    """Run the model comparison for one image identified by its file name.

    The image is expected at ``<data_dir>/JPEGImages/<image_filename>`` and its
    masks are looked up in ``<data_dir>/Annotations``.

    Args:
        data_dir: dataset root directory.
        image_filename: image file name without any directory part.
        model_cfg: model config passed on to the comparison routine.
        sam2_checkpoint: checkpoint of the original model.
        fine_tuned_path: weights of the fine-tuned model (optional).
        output_dir: where results are written; defaults to
            ``results/<image name without extension>``.

    Returns:
        The result dict of ``inference_sam2_compare_models``, or ``None`` when
        the image is missing or no mask was found.
    """
    image_path = os.path.join(data_dir, "JPEGImages", image_filename)
    if not os.path.exists(image_path):
        print(f"图像文件不存在: {image_path}")
        return None

    # Look up the masks that belong to this image.
    annotations_dir = os.path.join(data_dir, "Annotations")
    mask_paths = find_related_masks_advanced(image_path, annotations_dir)
    if not mask_paths:
        print("未找到相关掩码文件")
        return None

    # Default output location: one sub-directory per image.
    if output_dir is None:
        stem = os.path.splitext(image_filename)[0]
        output_dir = os.path.join("results", stem)

    return inference_sam2_compare_models(
        image_path=image_path,
        mask_paths=mask_paths,
        model_cfg=model_cfg,
        sam2_checkpoint=sam2_checkpoint,
        fine_tuned_path=fine_tuned_path,
        output_dir=output_dir,
        show_results=False
    )
|
|
423 |
|
|
|
424 |
|
|
|
425 |
if __name__ == "__main__":
    # Script entry point: compare the original and the fine-tuned SAM2 model
    # on one hard-coded image and print simple foreground metrics.

    # Dataset layout.
    data_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train"
    images_dir = os.path.join(data_dir, "JPEGImages")
    masks_dir = os.path.join(data_dir, "Annotations")

    # Model configuration.
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
    fine_tuned_path = "models/model_final.torch"

    # Image to evaluate.
    image_filename = "a4c_PatientD0062_a4c_93.jsonD_116.jpg"
    image_path = os.path.join(images_dir, image_filename)

    print("检查图像文件...")
    if not os.path.exists(image_path):
        print(f"图像文件不存在: {image_path}")
        # SystemExit instead of exit(): exit() is injected by the site module
        # and is not guaranteed to be available.
        raise SystemExit(1)

    # Look up the masks that belong to the image and keep those that exist.
    mask_paths = find_related_masks_advanced(image_path, masks_dir)

    valid_masks = []
    for mask_path in mask_paths:
        if os.path.exists(mask_path):
            valid_masks.append(mask_path)
            print(f"掩码文件存在: {mask_path}")
        else:
            print(f"掩码文件不存在: {mask_path}")

    if not valid_masks:
        print("没有有效的掩码文件")
        raise SystemExit(1)

    # Report which of the model files can be found. This is informational
    # only: the config may be resolved by the SAM2 package rather than the
    # current working directory.
    for path, desc in [
        (model_cfg, "模型配置文件"),
        (sam2_checkpoint, "预训练检查点"),
        (fine_tuned_path, "微调模型")
    ]:
        exists = os.path.exists(path)
        print(f"{desc}: {path} {'存在' if exists else '不存在'}")

    # Without the fine-tuned weights the comparison would fail only after the
    # original model has already been evaluated, discarding that work; fall
    # back to evaluating the original model alone.
    if not os.path.exists(fine_tuned_path):
        print(f"警告:微调模型不存在: {fine_tuned_path},仅评估原始模型")
        fine_tuned_path = None

    output_dir = os.path.join("results_comparison", os.path.splitext(image_filename)[0])

    try:
        print("\n开始模型对比推理...")
        results = inference_sam2_compare_models(
            image_path=image_path,
            mask_paths=valid_masks,
            model_cfg=model_cfg,
            sam2_checkpoint=sam2_checkpoint,
            fine_tuned_path=fine_tuned_path,
            num_samples_per_mask=10,
            output_dir=output_dir,
            show_results=False
        )
        print("对比推理完成!")

        # Binary foreground/background metrics for every model that ran.
        gt_foreground = results["ground_truth"]["map"] > 0
        for model_key in ("original", "fine_tuned"):
            if model_key not in results:
                continue
            pred_foreground = results[model_key]["map"] > 0

            true_positive = np.sum(np.logical_and(gt_foreground, pred_foreground))
            false_negative = np.sum(np.logical_and(gt_foreground, np.logical_not(pred_foreground)))
            false_positive = np.sum(np.logical_and(np.logical_not(gt_foreground), pred_foreground))

            # Guard every ratio against an empty denominator.
            predicted_total = true_positive + false_positive
            actual_total = true_positive + false_negative
            precision = true_positive / predicted_total if predicted_total > 0 else 0
            recall = true_positive / actual_total if actual_total > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

            print(f"\n{model_key.capitalize()}模型性能:")
            print(f"精确率(Precision): {precision:.4f}")
            print(f"召回率(Recall): {recall:.4f}")
            print(f"F1分数: {f1:.4f}")

    except Exception as e:
        # Top-level boundary: report the failure with its traceback.
        print(f"推理过程中出错: {e}")
        import traceback
        traceback.print_exc()