|
a |
|
b/data_process.ipynb |
|
|
1 |
{ |
|
|
2 |
"cells": [ |
|
|
3 |
{ |
|
|
4 |
"cell_type": "code", |
|
|
5 |
"execution_count": 7, |
|
|
6 |
"id": "30cc0acc", |
|
|
7 |
"metadata": {}, |
|
|
8 |
"outputs": [ |
|
|
9 |
{ |
|
|
10 |
"name": "stdout", |
|
|
11 |
"output_type": "stream", |
|
|
12 |
"text": [ |
|
|
13 |
"找到的CSV文件: ['train.csv']\n", |
|
|
14 |
"\n", |
|
|
15 |
"读取文件: train.csv\n", |
|
|
16 |
"CSV文件列名: ['ImageId', 'MaskId', 'Label', 'PointX', 'PointY']\n", |
|
|
17 |
"CSV文件前5行:\n", |
|
|
18 |
" ImageId MaskId Label \\\n", |
|
|
19 |
"0 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_LV_0.png LV \n", |
|
|
20 |
"1 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_M_1.png M \n", |
|
|
21 |
"2 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_LA_2.png LA \n", |
|
|
22 |
"3 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_RV_3.png RV \n", |
|
|
23 |
"4 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_RA_4.png RA \n", |
|
|
24 |
"\n", |
|
|
25 |
" PointX PointY \n", |
|
|
26 |
"0 373 164 \n", |
|
|
27 |
"1 522 296 \n", |
|
|
28 |
"2 454 404 \n", |
|
|
29 |
"3 289 287 \n", |
|
|
30 |
"4 334 457 \n", |
|
|
31 |
"CSV文件共有 6620 行\n", |
|
|
32 |
"\n", |
|
|
33 |
"图像文件夹存在: True\n", |
|
|
34 |
"图像文件夹中的文件数量: 1808\n", |
|
|
35 |
"图像文件夹中的前5个文件: ['a4c_PatientA0195_a4c_1D_141.jpg', 'a2c_PatientA0107_a2c_20S_959.jpg', 'a4c_PatientA0141_a4c_45D_508.jpg', 'a4c_PatientA0047_a4c_41D_63.jpg', 'a2c_PatientD0059_a2c_27.jsonS_996.jpg']\n", |
|
|
36 |
"\n", |
|
|
37 |
"标注文件夹存在: True\n", |
|
|
38 |
"标注文件夹中的文件数量: 6620\n", |
|
|
39 |
"标注文件夹中的前5个文件: ['a3c_PatientA0138_a3c_27D_1769_M_2.png', 'a4c_PatientA0157_a4c_28S_549_LA_1.png', 'a4c_PatientC0035_a4c_83D_35_M_1.png', 'a4c_PatientA0089_a4c_52D_512_LV_3.png', 'a4c_PatientC0038_a4c_43D_563_LV_0.png']\n" |
|
|
40 |
] |
|
|
41 |
} |
|
|
42 |
], |
|
|
43 |
"source": [ |
|
|
import os
import pandas as pd

# Root of the generated training set. This cell looks for CSV files directly
# inside it and for the JPEGImages/ and Annotations/ sub-folders.
data_path = "/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train"

# Collect the CSV files that sit directly in the data directory.
csv_files = [name for name in os.listdir(data_path) if name.endswith('.csv')]

print(f"找到的CSV文件: {csv_files}")

# Preview every CSV that was found: column names, first five rows, row count.
if csv_files:
    for csv_file in csv_files:
        file_path = os.path.join(data_path, csv_file)
        print(f"\n读取文件: {csv_file}")
        try:
            df = pd.read_csv(file_path)
            print(f"CSV文件列名: {df.columns.tolist()}")
            print("CSV文件前5行:")
            print(df.head(5))
            print(f"CSV文件共有 {len(df)} 行")
        except Exception as e:
            # Best-effort preview: report the failure and move on to the next CSV.
            print(f"读取文件出错: {e}")
else:
    print("数据目录中没有找到CSV文件")


def _report_folder(display_name, folder):
    """Print whether `folder` exists and, if it does, its entry count and first five entries.

    The directory is listed once so that the count and the sample are taken
    from the same snapshot.
    """
    exists = os.path.exists(folder)
    print(f"\n{display_name}文件夹存在: {exists}")
    if exists:
        entries = os.listdir(folder)
        print(f"{display_name}文件夹中的文件数量: {len(entries)}")
        print(f"{display_name}文件夹中的前5个文件: {entries[:5]}")


# Check the image and annotation folders.
images_dir = os.path.join(data_path, "JPEGImages")
annot_dir = os.path.join(data_path, "Annotations")

_report_folder("图像", images_dir)
_report_folder("标注", annot_dir)
|
|
87 |
] |
|
|
88 |
}, |
|
|
89 |
{ |
|
|
90 |
"cell_type": "code", |
|
|
91 |
"execution_count": 3, |
|
|
92 |
"id": "79042bdd", |
|
|
93 |
"metadata": {}, |
|
|
94 |
"outputs": [ |
|
|
95 |
{ |
|
|
96 |
"name": "stdout", |
|
|
97 |
"output_type": "stream", |
|
|
98 |
"text": [ |
|
|
99 |
"找到 1808 个JSON文件在目录: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/sam2/sam2/data/heart_chambers_dataset\n", |
|
|
100 |
"子目录: ['a4c', 'a2c', 'a3c']\n" |
|
|
101 |
] |
|
|
102 |
}, |
|
|
103 |
{ |
|
|
104 |
"name": "stderr", |
|
|
105 |
"output_type": "stream", |
|
|
106 |
"text": [ |
|
|
107 |
"处理文件: 0%| | 3/1808 [00:00<01:03, 28.51it/s]/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21407 (\\N{CJK UNIFIED IDEOGRAPH-539F}) missing from font(s) DejaVu Sans.\n", |
|
|
108 |
" plt.tight_layout()\n", |
|
|
109 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 22987 (\\N{CJK UNIFIED IDEOGRAPH-59CB}) missing from font(s) DejaVu Sans.\n", |
|
|
110 |
" plt.tight_layout()\n", |
|
|
111 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 22270 (\\N{CJK UNIFIED IDEOGRAPH-56FE}) missing from font(s) DejaVu Sans.\n", |
|
|
112 |
" plt.tight_layout()\n", |
|
|
113 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 20687 (\\N{CJK UNIFIED IDEOGRAPH-50CF}) missing from font(s) DejaVu Sans.\n", |
|
|
114 |
" plt.tight_layout()\n", |
|
|
115 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 25513 (\\N{CJK UNIFIED IDEOGRAPH-63A9}) missing from font(s) DejaVu Sans.\n", |
|
|
116 |
" plt.tight_layout()\n", |
|
|
117 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 30721 (\\N{CJK UNIFIED IDEOGRAPH-7801}) missing from font(s) DejaVu Sans.\n", |
|
|
118 |
" plt.tight_layout()\n", |
|
|
119 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 19981 (\\N{CJK UNIFIED IDEOGRAPH-4E0D}) missing from font(s) DejaVu Sans.\n", |
|
|
120 |
" plt.tight_layout()\n", |
|
|
121 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21516 (\\N{CJK UNIFIED IDEOGRAPH-540C}) missing from font(s) DejaVu Sans.\n", |
|
|
122 |
" plt.tight_layout()\n", |
|
|
123 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 39068 (\\N{CJK UNIFIED IDEOGRAPH-989C}) missing from font(s) DejaVu Sans.\n", |
|
|
124 |
" plt.tight_layout()\n", |
|
|
125 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 33394 (\\N{CJK UNIFIED IDEOGRAPH-8272}) missing from font(s) DejaVu Sans.\n", |
|
|
126 |
" plt.tight_layout()\n", |
|
|
127 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 24102 (\\N{CJK UNIFIED IDEOGRAPH-5E26}) missing from font(s) DejaVu Sans.\n", |
|
|
128 |
" plt.tight_layout()\n", |
|
|
129 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 26377 (\\N{CJK UNIFIED IDEOGRAPH-6709}) missing from font(s) DejaVu Sans.\n", |
|
|
130 |
" plt.tight_layout()\n", |
|
|
131 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21644 (\\N{CJK UNIFIED IDEOGRAPH-548C}) missing from font(s) DejaVu Sans.\n", |
|
|
132 |
" plt.tight_layout()\n", |
|
|
133 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 28857 (\\N{CJK UNIFIED IDEOGRAPH-70B9}) missing from font(s) DejaVu Sans.\n", |
|
|
134 |
" plt.tight_layout()\n", |
|
|
135 |
"/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 30340 (\\N{CJK UNIFIED IDEOGRAPH-7684}) missing from font(s) DejaVu Sans.\n", |
|
|
136 |
" plt.tight_layout()\n", |
|
|
137 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21407 (\\N{CJK UNIFIED IDEOGRAPH-539F}) missing from font(s) DejaVu Sans.\n", |
|
|
138 |
" plt.savefig(output_path)\n", |
|
|
139 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 22987 (\\N{CJK UNIFIED IDEOGRAPH-59CB}) missing from font(s) DejaVu Sans.\n", |
|
|
140 |
" plt.savefig(output_path)\n", |
|
|
141 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 22270 (\\N{CJK UNIFIED IDEOGRAPH-56FE}) missing from font(s) DejaVu Sans.\n", |
|
|
142 |
" plt.savefig(output_path)\n", |
|
|
143 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 20687 (\\N{CJK UNIFIED IDEOGRAPH-50CF}) missing from font(s) DejaVu Sans.\n", |
|
|
144 |
" plt.savefig(output_path)\n", |
|
|
145 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 25513 (\\N{CJK UNIFIED IDEOGRAPH-63A9}) missing from font(s) DejaVu Sans.\n", |
|
|
146 |
" plt.savefig(output_path)\n", |
|
|
147 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 30721 (\\N{CJK UNIFIED IDEOGRAPH-7801}) missing from font(s) DejaVu Sans.\n", |
|
|
148 |
" plt.savefig(output_path)\n", |
|
|
149 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 19981 (\\N{CJK UNIFIED IDEOGRAPH-4E0D}) missing from font(s) DejaVu Sans.\n", |
|
|
150 |
" plt.savefig(output_path)\n", |
|
|
151 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21516 (\\N{CJK UNIFIED IDEOGRAPH-540C}) missing from font(s) DejaVu Sans.\n", |
|
|
152 |
" plt.savefig(output_path)\n", |
|
|
153 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 39068 (\\N{CJK UNIFIED IDEOGRAPH-989C}) missing from font(s) DejaVu Sans.\n", |
|
|
154 |
" plt.savefig(output_path)\n", |
|
|
155 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 33394 (\\N{CJK UNIFIED IDEOGRAPH-8272}) missing from font(s) DejaVu Sans.\n", |
|
|
156 |
" plt.savefig(output_path)\n", |
|
|
157 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 24102 (\\N{CJK UNIFIED IDEOGRAPH-5E26}) missing from font(s) DejaVu Sans.\n", |
|
|
158 |
" plt.savefig(output_path)\n", |
|
|
159 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 26377 (\\N{CJK UNIFIED IDEOGRAPH-6709}) missing from font(s) DejaVu Sans.\n", |
|
|
160 |
" plt.savefig(output_path)\n", |
|
|
161 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21644 (\\N{CJK UNIFIED IDEOGRAPH-548C}) missing from font(s) DejaVu Sans.\n", |
|
|
162 |
" plt.savefig(output_path)\n", |
|
|
163 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 28857 (\\N{CJK UNIFIED IDEOGRAPH-70B9}) missing from font(s) DejaVu Sans.\n", |
|
|
164 |
" plt.savefig(output_path)\n", |
|
|
165 |
"/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 30340 (\\N{CJK UNIFIED IDEOGRAPH-7684}) missing from font(s) DejaVu Sans.\n", |
|
|
166 |
" plt.savefig(output_path)\n", |
|
|
167 |
"处理文件: 1%| | 18/1808 [00:01<01:33, 19.19it/s]" |
|
|
168 |
] |
|
|
169 |
}, |
|
|
170 |
{ |
|
|
171 |
"name": "stderr", |
|
|
172 |
"output_type": "stream", |
|
|
173 |
"text": [ |
|
|
174 |
"处理文件: 100%|██████████| 1808/1808 [02:45<00:00, 10.89it/s]" |
|
|
175 |
] |
|
|
176 |
}, |
|
|
177 |
{ |
|
|
178 |
"name": "stdout", |
|
|
179 |
"output_type": "stream", |
|
|
180 |
"text": [ |
|
|
181 |
"数据集创建完成! 总共处理了 1808 个文件,生成了 6620 个掩码。\n", |
|
|
182 |
"图像保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/JPEGImages\n", |
|
|
183 |
"掩码保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/Annotations\n", |
|
|
184 |
"可视化结果保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/Visualization\n", |
|
|
185 |
"训练CSV文件: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/train.csv\n" |
|
|
186 |
] |
|
|
187 |
}, |
|
|
188 |
{ |
|
|
189 |
"name": "stderr", |
|
|
190 |
"output_type": "stream", |
|
|
191 |
"text": [ |
|
|
192 |
"\n" |
|
|
193 |
] |
|
|
194 |
} |
|
|
195 |
], |
|
|
196 |
"source": [ |
|
|
197 |
"\n", |
|
|
import base64
import io
import json
import os
import random
import shutil

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from shapely.geometry import Polygon
from tqdm import tqdm
|
|
210 |
"\n", |
|
|
def ensure_dir(directory):
    """Create `directory`, including missing parent directories, if it is absent.

    Calling this on a directory that already exists is a no-op.

    Raises:
        FileExistsError: if `directory` exists but is not a directory.
        OSError: for other creation failures (permissions, empty path, ...).
    """
    # exist_ok=True removes the window between "check" and "create" in which
    # another process could create the directory and make makedirs fail.
    os.makedirs(directory, exist_ok=True)
|
|
215 |
"\n", |
|
|
def decode_base64_to_image(base64_string):
    """Decode a Base64 string into a fully loaded PIL image.

    Args:
        base64_string: Base64 text of an encoded image file.

    Returns:
        The decoded PIL image, or None when the input is empty, shorter than
        100 characters (treated as a missing or truncated payload), or cannot
        be decoded as an image.
    """
    # Cheap sanity check so obviously truncated payloads are never decoded.
    if not base64_string or len(base64_string) < 100:
        return None

    try:
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        # Image.open only reads the header; force the pixel data to be decoded
        # here so that corrupt payloads are reported by this function rather
        # than failing later when the caller first touches the image.
        image.load()
        return image
    except Exception as e:
        print(f"解码Base64图像时出错: {e}")
        return None
|
|
229 |
"\n", |
|
|
def draw_polygon_mask(width, height, points, label):
    """Rasterise one polygon into a binary mask.

    Args:
        width: Mask width in pixels.
        height: Mask height in pixels.
        points: Sequence of (x, y) polygon vertices; may be floats.
        label: Unused; kept so existing callers keep working.

    Returns:
        A (height, width) uint8 array that is 255 inside the polygon and 0
        elsewhere.
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    # Round to the nearest pixel before casting. A direct int32 cast would
    # truncate toward zero and shift fractional vertices by up to one pixel.
    vertices = np.rint(np.asarray(points, dtype=np.float64)).astype(np.int32)

    cv2.fillPoly(mask, [vertices], 255)

    return mask
|
|
242 |
"\n", |
|
|
def generate_point_in_mask(mask, num_points=1):
    """Draw random foreground pixels from a binary mask.

    Args:
        mask: 2-D array; pixels equal to 255 count as foreground.
        num_points: How many points to draw. Draws are independent, so the
            same pixel may be returned more than once.

    Returns:
        A list of (x, y) tuples of Python ints, or an empty list when the mask
        contains no foreground pixel.
    """
    rows, cols = np.nonzero(mask == 255)
    candidate_count = len(rows)

    # Nothing to sample from.
    if candidate_count == 0:
        return []

    def _sample_one():
        # One uniform draw over all foreground pixels, returned as (x, y).
        pick = random.randint(0, candidate_count - 1)
        return (int(cols[pick]), int(rows[pick]))

    return [_sample_one() for _ in range(num_points)]
|
|
259 |
"\n", |
|
|
def visualize_masks(image, masks, labels, points, output_path):
    """Save a three-panel figure: source image, colour-coded masks, and overlay.

    Args:
        image: Source image as a PIL image or a numpy array. Grayscale and
            RGBA inputs are converted to three channels for display.
        masks: List of 2-D numpy arrays; pixels > 0 are foreground.
        labels: List of label strings, one per mask.
        points: List with one (x, y) tuple (or a falsy value) per mask.
        output_path: Path the figure is written to.

    Notes:
        The panel titles contain CJK text. With Matplotlib's default
        DejaVu Sans font this produces "Glyph ... missing" warnings unless a
        CJK-capable font is configured.
    """
    # Normalise the input to a contiguous H x W x 3 array so the per-channel
    # colouring below also works for grayscale (2-D) and RGBA images.
    # NOTE(review): assumes 8-bit pixel data -- confirm for other dtypes.
    if isinstance(image, Image.Image):
        img_np = np.array(image.convert("RGB"))
    else:
        img_np = np.array(image)
        if img_np.ndim == 2:
            img_np = np.stack([img_np] * 3, axis=-1)
        elif img_np.shape[2] == 1:
            img_np = np.repeat(img_np, 3, axis=2)
        elif img_np.shape[2] > 3:
            img_np = img_np[:, :, :3]
        img_np = np.ascontiguousarray(img_np)

    # Label -> RGB colour (the arrays are shown with imshow, i.e. as RGB).
    color_map = {
        'LV': (255, 0, 0),           # red
        'LA': (0, 255, 0),           # green
        'RV': (0, 0, 255),           # blue
        'RA': (0, 255, 255),         # cyan
        'M': (255, 0, 255),          # magenta
        'unknown': (128, 128, 128)   # gray, also used for unlisted labels
    }

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    try:
        # Panel 1: the source image.
        axes[0].imshow(img_np)
        axes[0].set_title('原始图像')
        axes[0].axis('off')

        # Panel 2: all masks painted onto a black canvas; later masks
        # overwrite earlier ones where they overlap.
        color_mask = np.zeros((*img_np.shape[:2], 3), dtype=np.uint8)
        for mask, label in zip(masks, labels):
            color_mask[mask > 0] = color_map.get(label, color_map['unknown'])

        axes[1].imshow(color_mask)
        axes[1].set_title('掩码 (不同颜色)')
        axes[1].axis('off')

        # Panel 3: masks blended into the image (70% image, 30% colour) with
        # the prompt point of each mask marked.
        overlay = img_np.copy()
        img_height, img_width = overlay.shape[:2]
        for mask, label, point in zip(masks, labels, points):
            color = color_map.get(label, color_map['unknown'])
            region = mask > 0
            overlay[region] = (overlay[region] * 0.7 + np.array(color) * 0.3).astype(np.uint8)

            if point:
                x, y = int(point[0]), int(point[1])
                # Only mark points that lie inside the image.
                if 0 <= x < img_width and 0 <= y < img_height:
                    cv2.circle(overlay, (x, y), 5, (255, 255, 255), -1)  # white filled dot
                    cv2.circle(overlay, (x, y), 5, (0, 0, 0), 1)         # black outline

        axes[2].imshow(overlay)
        axes[2].set_title('带有掩码和点的图像')

        # Legend entries only for labels that actually occur; empty plots are
        # used purely as legend handles.
        legend_entries = 0
        for label, color in color_map.items():
            if label in labels:
                normalized_color = [c / 255 for c in color]
                axes[2].plot([], [], 'o', color=normalized_color, label=label)
                legend_entries += 1
        if legend_entries:
            # Skipping the call avoids Matplotlib's "no artists with labels" warning.
            axes[2].legend(loc='upper right')
        axes[2].axis('off')

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving failed, so a
        # long batch run does not accumulate open figures.
        plt.close(fig)
|
|
351 |
"\n", |
|
|
def process_json_file(json_file, output_dir, index, subdir=""):
    """Convert one annotation JSON file into an image, per-shape masks and CSV rows.

    The JSON is expected to carry a Base64 image under 'imageData' and a list
    of shapes under 'shapes', each with a 'label' and polygon 'points'.

    Args:
        json_file: Path of the JSON file to process.
        output_dir: Dataset root; must already contain 'JPEGImages' and
            'Annotations' sub-directories.
        index: Running number appended to the output names to keep them unique.
        subdir: Path of the JSON file's directory relative to the input root;
            used as a prefix of the output names.

    Returns:
        A list of dicts with keys 'ImageId', 'MaskId', 'Label', 'PointX' and
        'PointY' (one per non-empty mask that was written), or None when the
        file could not be processed.
    """
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # File name without extension.
        base_name = os.path.splitext(os.path.basename(json_file))[0]

        # Build a unique id from sub-directory, base name and index. Nested
        # sub-directories are flattened so the id never contains a path
        # separator and can be used directly as a file name.
        if subdir:
            flat_subdir = subdir.replace(os.sep, "_")
            unique_id = f"{flat_subdir}_{base_name}_{index}"
        else:
            unique_id = f"{base_name}_{index}"

        image = decode_base64_to_image(data.get('imageData', ''))

        if image is None:
            print(f"无法解码图像: {json_file}")
            return None

        width, height = image.size

        # JPEG cannot store alpha or palette modes; convert those to RGB so
        # saving does not fail.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")

        image_path = os.path.join(output_dir, 'JPEGImages', f"{unique_id}.jpg")
        image.save(image_path)

        shapes = data.get('shapes', [])

        mask_info = []   # rows for train.csv
        all_masks = []   # data for the optional visualisation
        all_labels = []
        all_points = []

        for i, shape in enumerate(shapes):
            label = shape.get('label', 'unknown')
            points = shape.get('points', [])

            if not points:
                continue

            mask = draw_polygon_mask(width, height, points, label)

            # One random prompt point inside the mask. An empty mask yields
            # no point and is skipped before anything is written, so the
            # Annotations folder and the CSV stay in sync.
            random_points = generate_point_in_mask(mask)
            if not random_points:
                continue

            mask_filename = f"{unique_id}_{label}_{i}.png"
            mask_path = os.path.join(output_dir, 'Annotations', mask_filename)

            # cv2.imwrite reports failure through its return value instead of
            # raising; do not list a mask that was never written.
            if not cv2.imwrite(mask_path, mask):
                print(f"保存掩码失败: {mask_path}")
                continue

            mask_info.append({
                'ImageId': f"{unique_id}.jpg",
                'MaskId': mask_filename,
                'Label': label,
                'PointX': random_points[0][0],
                'PointY': random_points[0][1]
            })

            all_masks.append(mask)
            all_labels.append(label)
            all_points.append(random_points[0])

        # Visualise roughly 20% of the files that produced at least one mask.
        if all_masks and random.random() < 0.2:
            vis_dir = os.path.join(output_dir, 'Visualization')
            ensure_dir(vis_dir)
            vis_path = os.path.join(vis_dir, f"{unique_id}_visualization.png")
            visualize_masks(image, all_masks, all_labels, all_points, vis_path)

        return mask_info

    except Exception as e:
        # Best-effort batch processing: report the failure with a traceback
        # and let the caller continue with the next file.
        print(f"处理文件 {json_file} 时出错: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
445 |
"\n", |
|
|
def get_all_json_files(root_dir):
    """Recursively collect the JSON files below `root_dir` in a deterministic order.

    Args:
        root_dir: Directory to search.

    Returns:
        A tuple (json_files, subdirs):
          * json_files: list of (full_path, rel_dir) tuples, where rel_dir is
            the file's directory relative to `root_dir` ('' for files directly
            in `root_dir`).
          * subdirs: sorted names of the first-level sub-directories.
        Both lists are empty when `root_dir` does not exist.
    """
    json_files = []
    subdirs = []

    for dirpath, dirnames, filenames in os.walk(root_dir):
        # os.walk yields entries in filesystem order, which varies between
        # machines and runs. Sorting `dirnames` in place also fixes the order
        # in which os.walk descends, so the running index that callers derive
        # from this list is reproducible.
        dirnames.sort()

        rel_path = os.path.relpath(dirpath, root_dir)
        if rel_path == '.':
            rel_path = ''

        # Remember the first-level sub-directory names only.
        if dirpath == root_dir:
            subdirs.extend(dirnames)

        for filename in sorted(filenames):
            if filename.endswith('.json'):
                json_files.append((os.path.join(dirpath, filename), rel_path))

    return json_files, subdirs
|
|
467 |
"\n", |
|
|
def create_dataset_structure(json_dir, output_dir, seed=None):
    """Build the dataset layout (images, masks, visualisations, train.csv) from JSON files.

    WARNING: an existing `output_dir` is deleted before the dataset is rebuilt.

    Args:
        json_dir: Directory that is searched recursively for annotation JSON files.
        output_dir: Dataset root that receives 'JPEGImages', 'Annotations',
            'Visualization' and 'train.csv'.
        seed: Optional seed for the `random` module so that the sampled prompt
            points and the choice of visualised files are reproducible. The
            global random state is left untouched when None.

    Raises:
        ValueError: if `json_dir` is `output_dir` or lies inside it, because
            wiping the output directory would delete the inputs.
    """
    # Refuse to wipe the output directory when the inputs live inside it.
    source_root = os.path.realpath(json_dir)
    output_root = os.path.realpath(output_dir)
    try:
        inputs_inside_output = os.path.commonpath([source_root, output_root]) == output_root
    except ValueError:
        # No common path at all (e.g. different drives): cannot be nested.
        inputs_inside_output = False
    if inputs_inside_output:
        raise ValueError(
            f"json_dir ({json_dir}) is inside output_dir ({output_dir}); "
            "refusing to delete the input data"
        )

    if seed is not None:
        random.seed(seed)

    # Start from a clean output directory.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    for folder_name in ('JPEGImages', 'Annotations', 'Visualization'):
        ensure_dir(os.path.join(output_dir, folder_name))

    json_files, subdirs = get_all_json_files(json_dir)

    print(f"找到 {len(json_files)} 个JSON文件在目录: {json_dir}")
    print(f"子目录: {subdirs}")

    all_mask_info = []

    # Files that fail to process return None and are simply skipped.
    for i, (json_file, subdir) in enumerate(tqdm(json_files, desc="处理文件")):
        mask_info = process_json_file(json_file, output_dir, i, subdir)
        if mask_info:
            all_mask_info.extend(mask_info)

    # Explicit columns keep the CSV header present even when no mask was produced.
    df = pd.DataFrame(all_mask_info, columns=['ImageId', 'MaskId', 'Label', 'PointX', 'PointY'])
    csv_path = os.path.join(output_dir, 'train.csv')
    df.to_csv(csv_path, index=False)

    print(f"数据集创建完成! 总共处理了 {len(json_files)} 个文件,生成了 {len(all_mask_info)} 个掩码。")
    print(f"图像保存在: {os.path.join(output_dir, 'JPEGImages')}")
    print(f"掩码保存在: {os.path.join(output_dir, 'Annotations')}")
    print(f"可视化结果保存在: {os.path.join(output_dir, 'Visualization')}")
    print(f"训练CSV文件: {csv_path}")
|
|
503 |
"\n", |
|
|
# Example usage.
if __name__ == "__main__":
    # Directory holding the source JSON annotation files -- replace with your own.
    json_dir = "/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/sam2/sam2/data/heart_chambers_dataset"
    # Destination of the generated dataset; create_dataset_structure removes
    # this directory first if it already exists.
    output_dir = "/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train"
    create_dataset_structure(json_dir=json_dir, output_dir=output_dir)
|
|
510 |
" \n" |
|
|
511 |
] |
|
|
512 |
}, |
|
|
513 |
{ |
|
|
514 |
"cell_type": "code", |
|
|
515 |
"execution_count": null, |
|
|
516 |
"id": "aa78b4c5", |
|
|
517 |
"metadata": {}, |
|
|
518 |
"outputs": [], |
|
|
519 |
"source": [] |
|
|
520 |
} |
|
|
521 |
], |
|
|
522 |
"metadata": { |
|
|
523 |
"kernelspec": { |
|
|
524 |
"display_name": "Python 3", |
|
|
525 |
"language": "python", |
|
|
526 |
"name": "python3" |
|
|
527 |
}, |
|
|
528 |
"language_info": { |
|
|
529 |
"codemirror_mode": { |
|
|
530 |
"name": "ipython", |
|
|
531 |
"version": 3 |
|
|
532 |
}, |
|
|
533 |
"file_extension": ".py", |
|
|
534 |
"mimetype": "text/x-python", |
|
|
535 |
"name": "python", |
|
|
536 |
"nbconvert_exporter": "python", |
|
|
537 |
"pygments_lexer": "ipython3", |
|
|
538 |
"version": "3.10.12" |
|
|
539 |
} |
|
|
540 |
}, |
|
|
541 |
"nbformat": 4, |
|
|
542 |
"nbformat_minor": 5 |
|
|
543 |
} |