Diff of /data_process.ipynb [000000] .. [d255cc]

Switch to unified view

a b/data_process.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": 7,
6
   "id": "30cc0acc",
7
   "metadata": {},
8
   "outputs": [
9
    {
10
     "name": "stdout",
11
     "output_type": "stream",
12
     "text": [
13
      "找到的CSV文件: ['train.csv']\n",
14
      "\n",
15
      "读取文件: train.csv\n",
16
      "CSV文件列名: ['ImageId', 'MaskId', 'Label', 'PointX', 'PointY']\n",
17
      "CSV文件前5行:\n",
18
      "                          ImageId                               MaskId Label  \\\n",
19
      "0  a4c_PatientA0030_a4c_20S_0.jpg  a4c_PatientA0030_a4c_20S_0_LV_0.png    LV   \n",
20
      "1  a4c_PatientA0030_a4c_20S_0.jpg   a4c_PatientA0030_a4c_20S_0_M_1.png     M   \n",
21
      "2  a4c_PatientA0030_a4c_20S_0.jpg  a4c_PatientA0030_a4c_20S_0_LA_2.png    LA   \n",
22
      "3  a4c_PatientA0030_a4c_20S_0.jpg  a4c_PatientA0030_a4c_20S_0_RV_3.png    RV   \n",
23
      "4  a4c_PatientA0030_a4c_20S_0.jpg  a4c_PatientA0030_a4c_20S_0_RA_4.png    RA   \n",
24
      "\n",
25
      "   PointX  PointY  \n",
26
      "0     373     164  \n",
27
      "1     522     296  \n",
28
      "2     454     404  \n",
29
      "3     289     287  \n",
30
      "4     334     457  \n",
31
      "CSV文件共有 6620 行\n",
32
      "\n",
33
      "图像文件夹存在: True\n",
34
      "图像文件夹中的文件数量: 1808\n",
35
      "图像文件夹中的前5个文件: ['a4c_PatientA0195_a4c_1D_141.jpg', 'a2c_PatientA0107_a2c_20S_959.jpg', 'a4c_PatientA0141_a4c_45D_508.jpg', 'a4c_PatientA0047_a4c_41D_63.jpg', 'a2c_PatientD0059_a2c_27.jsonS_996.jpg']\n",
36
      "\n",
37
      "标注文件夹存在: True\n",
38
      "标注文件夹中的文件数量: 6620\n",
39
      "标注文件夹中的前5个文件: ['a3c_PatientA0138_a3c_27D_1769_M_2.png', 'a4c_PatientA0157_a4c_28S_549_LA_1.png', 'a4c_PatientC0035_a4c_83D_35_M_1.png', 'a4c_PatientA0089_a4c_52D_512_LV_3.png', 'a4c_PatientC0038_a4c_43D_563_LV_0.png']\n"
40
     ]
41
    }
42
   ],
43
   "source": [
44
    "import os\n",
45
    "import pandas as pd\n",
46
    "\n",
47
    "# 您的数据路径\n",
48
    "data_path = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train\"\n",
49
    "\n",
50
    "# 查找CSV文件\n",
51
    "csv_files = []\n",
52
    "for file in os.listdir(data_path):\n",
53
    "    if file.endswith('.csv'):\n",
54
    "        csv_files.append(file)\n",
55
    "\n",
56
    "print(f\"找到的CSV文件: {csv_files}\")\n",
57
    "\n",
58
    "# 如果存在CSV文件,读取并显示其内容\n",
59
    "if csv_files:\n",
60
    "    for csv_file in csv_files:\n",
61
    "        file_path = os.path.join(data_path, csv_file)\n",
62
    "        print(f\"\\n读取文件: {csv_file}\")\n",
63
    "        try:\n",
64
    "            df = pd.read_csv(file_path)\n",
65
    "            print(f\"CSV文件列名: {df.columns.tolist()}\")\n",
66
    "            print(f\"CSV文件前5行:\")\n",
67
    "            print(df.head(5))\n",
68
    "            print(f\"CSV文件共有 {len(df)} 行\")\n",
69
    "        except Exception as e:\n",
70
    "            print(f\"读取文件出错: {e}\")\n",
71
    "else:\n",
72
    "    print(\"数据目录中没有找到CSV文件\")\n",
73
    "\n",
74
    "# 检查图像和标注文件夹\n",
75
    "images_dir = os.path.join(data_path, \"JPEGImages\")\n",
76
    "annot_dir = os.path.join(data_path, \"Annotations\")\n",
77
    "\n",
78
    "print(f\"\\n图像文件夹存在: {os.path.exists(images_dir)}\")\n",
79
    "if os.path.exists(images_dir):\n",
80
    "    print(f\"图像文件夹中的文件数量: {len(os.listdir(images_dir))}\")\n",
81
    "    print(f\"图像文件夹中的前5个文件: {os.listdir(images_dir)[:5]}\")\n",
82
    "\n",
83
    "print(f\"\\n标注文件夹存在: {os.path.exists(annot_dir)}\")\n",
84
    "if os.path.exists(annot_dir):\n",
85
    "    print(f\"标注文件夹中的文件数量: {len(os.listdir(annot_dir))}\")\n",
86
    "    print(f\"标注文件夹中的前5个文件: {os.listdir(annot_dir)[:5]}\")"
87
   ]
88
  },
89
  {
90
   "cell_type": "code",
91
   "execution_count": 3,
92
   "id": "79042bdd",
93
   "metadata": {},
94
   "outputs": [
95
    {
96
     "name": "stdout",
97
     "output_type": "stream",
98
     "text": [
99
      "找到 1808 个JSON文件在目录: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/sam2/sam2/data/heart_chambers_dataset\n",
100
      "子目录: ['a4c', 'a2c', 'a3c']\n"
101
     ]
102
    },
103
    {
104
     "name": "stderr",
105
     "output_type": "stream",
106
     "text": [
107
      "处理文件:   0%|          | 3/1808 [00:00<01:03, 28.51it/s]/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21407 (\\N{CJK UNIFIED IDEOGRAPH-539F}) missing from font(s) DejaVu Sans.\n",
108
      "  plt.tight_layout()\n",
109
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 22987 (\\N{CJK UNIFIED IDEOGRAPH-59CB}) missing from font(s) DejaVu Sans.\n",
110
      "  plt.tight_layout()\n",
111
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 22270 (\\N{CJK UNIFIED IDEOGRAPH-56FE}) missing from font(s) DejaVu Sans.\n",
112
      "  plt.tight_layout()\n",
113
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 20687 (\\N{CJK UNIFIED IDEOGRAPH-50CF}) missing from font(s) DejaVu Sans.\n",
114
      "  plt.tight_layout()\n",
115
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 25513 (\\N{CJK UNIFIED IDEOGRAPH-63A9}) missing from font(s) DejaVu Sans.\n",
116
      "  plt.tight_layout()\n",
117
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 30721 (\\N{CJK UNIFIED IDEOGRAPH-7801}) missing from font(s) DejaVu Sans.\n",
118
      "  plt.tight_layout()\n",
119
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 19981 (\\N{CJK UNIFIED IDEOGRAPH-4E0D}) missing from font(s) DejaVu Sans.\n",
120
      "  plt.tight_layout()\n",
121
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21516 (\\N{CJK UNIFIED IDEOGRAPH-540C}) missing from font(s) DejaVu Sans.\n",
122
      "  plt.tight_layout()\n",
123
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 39068 (\\N{CJK UNIFIED IDEOGRAPH-989C}) missing from font(s) DejaVu Sans.\n",
124
      "  plt.tight_layout()\n",
125
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 33394 (\\N{CJK UNIFIED IDEOGRAPH-8272}) missing from font(s) DejaVu Sans.\n",
126
      "  plt.tight_layout()\n",
127
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 24102 (\\N{CJK UNIFIED IDEOGRAPH-5E26}) missing from font(s) DejaVu Sans.\n",
128
      "  plt.tight_layout()\n",
129
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 26377 (\\N{CJK UNIFIED IDEOGRAPH-6709}) missing from font(s) DejaVu Sans.\n",
130
      "  plt.tight_layout()\n",
131
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21644 (\\N{CJK UNIFIED IDEOGRAPH-548C}) missing from font(s) DejaVu Sans.\n",
132
      "  plt.tight_layout()\n",
133
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 28857 (\\N{CJK UNIFIED IDEOGRAPH-70B9}) missing from font(s) DejaVu Sans.\n",
134
      "  plt.tight_layout()\n",
135
      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 30340 (\\N{CJK UNIFIED IDEOGRAPH-7684}) missing from font(s) DejaVu Sans.\n",
136
      "  plt.tight_layout()\n",
137
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21407 (\\N{CJK UNIFIED IDEOGRAPH-539F}) missing from font(s) DejaVu Sans.\n",
138
      "  plt.savefig(output_path)\n",
139
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 22987 (\\N{CJK UNIFIED IDEOGRAPH-59CB}) missing from font(s) DejaVu Sans.\n",
140
      "  plt.savefig(output_path)\n",
141
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 22270 (\\N{CJK UNIFIED IDEOGRAPH-56FE}) missing from font(s) DejaVu Sans.\n",
142
      "  plt.savefig(output_path)\n",
143
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 20687 (\\N{CJK UNIFIED IDEOGRAPH-50CF}) missing from font(s) DejaVu Sans.\n",
144
      "  plt.savefig(output_path)\n",
145
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 25513 (\\N{CJK UNIFIED IDEOGRAPH-63A9}) missing from font(s) DejaVu Sans.\n",
146
      "  plt.savefig(output_path)\n",
147
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 30721 (\\N{CJK UNIFIED IDEOGRAPH-7801}) missing from font(s) DejaVu Sans.\n",
148
      "  plt.savefig(output_path)\n",
149
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 19981 (\\N{CJK UNIFIED IDEOGRAPH-4E0D}) missing from font(s) DejaVu Sans.\n",
150
      "  plt.savefig(output_path)\n",
151
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21516 (\\N{CJK UNIFIED IDEOGRAPH-540C}) missing from font(s) DejaVu Sans.\n",
152
      "  plt.savefig(output_path)\n",
153
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 39068 (\\N{CJK UNIFIED IDEOGRAPH-989C}) missing from font(s) DejaVu Sans.\n",
154
      "  plt.savefig(output_path)\n",
155
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 33394 (\\N{CJK UNIFIED IDEOGRAPH-8272}) missing from font(s) DejaVu Sans.\n",
156
      "  plt.savefig(output_path)\n",
157
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 24102 (\\N{CJK UNIFIED IDEOGRAPH-5E26}) missing from font(s) DejaVu Sans.\n",
158
      "  plt.savefig(output_path)\n",
159
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 26377 (\\N{CJK UNIFIED IDEOGRAPH-6709}) missing from font(s) DejaVu Sans.\n",
160
      "  plt.savefig(output_path)\n",
161
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21644 (\\N{CJK UNIFIED IDEOGRAPH-548C}) missing from font(s) DejaVu Sans.\n",
162
      "  plt.savefig(output_path)\n",
163
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 28857 (\\N{CJK UNIFIED IDEOGRAPH-70B9}) missing from font(s) DejaVu Sans.\n",
164
      "  plt.savefig(output_path)\n",
165
      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 30340 (\\N{CJK UNIFIED IDEOGRAPH-7684}) missing from font(s) DejaVu Sans.\n",
166
      "  plt.savefig(output_path)\n",
167
      "处理文件:   1%|          | 18/1808 [00:01<01:33, 19.19it/s]"
168
     ]
169
    },
170
    {
171
     "name": "stderr",
172
     "output_type": "stream",
173
     "text": [
174
      "处理文件: 100%|██████████| 1808/1808 [02:45<00:00, 10.89it/s]"
175
     ]
176
    },
177
    {
178
     "name": "stdout",
179
     "output_type": "stream",
180
     "text": [
181
      "数据集创建完成! 总共处理了 1808 个文件,生成了 6620 个掩码。\n",
182
      "图像保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/JPEGImages\n",
183
      "掩码保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/Annotations\n",
184
      "可视化结果保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/Visualization\n",
185
      "训练CSV文件: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/train.csv\n"
186
     ]
187
    },
188
    {
189
     "name": "stderr",
190
     "output_type": "stream",
191
     "text": [
192
      "\n"
193
     ]
194
    }
195
   ],
196
   "source": [
197
    "\n",
198
    "import json\n",
199
    "import base64\n",
200
    "import numpy as np\n",
201
    "import cv2\n",
202
    "from PIL import Image\n",
203
    "import io\n",
204
    "import shutil\n",
205
    "from shapely.geometry import Polygon\n",
206
    "import random\n",
207
    "import pandas as pd\n",
208
    "from tqdm import tqdm\n",
209
    "import matplotlib.pyplot as plt\n",
210
    "\n",
211
    "def ensure_dir(directory):\n",
212
    "    \"\"\"确保目录存在,如果不存在则创建\"\"\"\n",
213
    "    if not os.path.exists(directory):\n",
214
    "        os.makedirs(directory)\n",
215
    "\n",
216
    "def decode_base64_to_image(base64_string):\n",
217
    "    \"\"\"将Base64编码的图像数据解码为PIL图像\"\"\"\n",
218
    "    if not base64_string or len(base64_string) < 100:  # 简单检查以避免处理截断的字符串\n",
219
    "        return None\n",
220
    "    \n",
221
    "    try:\n",
222
    "        # 尝试解码Base64字符串\n",
223
    "        image_data = base64.b64decode(base64_string)\n",
224
    "        image = Image.open(io.BytesIO(image_data))\n",
225
    "        return image\n",
226
    "    except Exception as e:\n",
227
    "        print(f\"解码Base64图像时出错: {e}\")\n",
228
    "        return None\n",
229
    "\n",
230
    "def draw_polygon_mask(width, height, points, label):\n",
231
    "    \"\"\"根据多边形点创建二值掩码\"\"\"\n",
232
    "    # 创建空白掩码\n",
233
    "    mask = np.zeros((height, width), dtype=np.uint8)\n",
234
    "    \n",
235
    "    # 将点格式转换为OpenCV需要的格式\n",
236
    "    points_array = np.array(points, dtype=np.int32)\n",
237
    "    \n",
238
    "    # 绘制填充多边形\n",
239
    "    cv2.fillPoly(mask, [points_array], 255)\n",
240
    "    \n",
241
    "    return mask\n",
242
    "\n",
243
    "def generate_point_in_mask(mask, num_points=1):\n",
244
    "    \"\"\"在掩码内生成随机点\"\"\"\n",
245
    "    # 找到掩码中值为255的像素位置\n",
246
    "    y_indices, x_indices = np.where(mask == 255)\n",
247
    "    \n",
248
    "    if len(y_indices) == 0:\n",
249
    "        return []  # 如果掩码为空,返回空列表\n",
250
    "    \n",
251
    "    points = []\n",
252
    "    for _ in range(num_points):\n",
253
    "        # 随机选择一个位置\n",
254
    "        idx = random.randint(0, len(y_indices) - 1)\n",
255
    "        x, y = int(x_indices[idx]), int(y_indices[idx])\n",
256
    "        points.append((x, y))\n",
257
    "    \n",
258
    "    return points\n",
259
    "\n",
260
    "def visualize_masks(image, masks, labels, points, output_path):\n",
261
    "    \"\"\"\n",
262
    "    可视化图像、掩码和点\n",
263
    "    \n",
264
    "    参数:\n",
265
    "    - image: 原始图像 (PIL图像对象或numpy数组)\n",
266
    "    - masks: 掩码列表 (每个掩码是numpy数组)\n",
267
    "    - labels: 标签列表\n",
268
    "    - points: 每个掩码中的点列表 (每个点是(x,y)元组)\n",
269
    "    - output_path: 输出图像路径\n",
270
    "    \"\"\"\n",
271
    "    # 确保image是numpy数组\n",
272
    "    if isinstance(image, Image.Image):\n",
273
    "        img_np = np.array(image)\n",
274
    "    else:\n",
275
    "        img_np = image.copy()\n",
276
    "    \n",
277
    "    # 创建可视化图像\n",
278
    "    vis_img = img_np.copy()\n",
279
    "    \n",
280
    "    # 颜色映射 (标签到RGB颜色)\n",
281
    "    color_map = {\n",
282
    "        'LV': (255, 0, 0),    # 红色\n",
283
    "        'LA': (0, 255, 0),    # 绿色\n",
284
    "        'RV': (0, 0, 255),    # 蓝色\n",
285
    "        'RA': (0, 255, 255),  # 黄色\n",
286
    "        'M': (255, 0, 255),   # 紫色\n",
287
    "        'unknown': (128, 128, 128)  # 灰色\n",
288
    "    }\n",
289
    "    \n",
290
    "    # 创建图形和子图\n",
291
    "    fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
292
    "    \n",
293
    "    # 显示原始图像\n",
294
    "    axes[0].imshow(img_np)\n",
295
    "    axes[0].set_title('原始图像')\n",
296
    "    axes[0].axis('off')\n",
297
    "    \n",
298
    "    # 创建合并掩码的彩色图像\n",
299
    "    color_mask = np.zeros((*img_np.shape[:2], 3), dtype=np.uint8)\n",
300
    "    \n",
301
    "    # 将掩码叠加到图像上,使用不同颜色\n",
302
    "    for mask, label, point in zip(masks, labels, points):\n",
303
    "        # 获取标签对应的颜色\n",
304
    "        color = color_map.get(label, color_map['unknown'])\n",
305
    "        \n",
306
    "        # 添加掩码颜色\n",
307
    "        for i in range(3):\n",
308
    "            color_mask[:, :, i][mask > 0] = color[i]\n",
309
    "    \n",
310
    "    # 显示彩色掩码\n",
311
    "    axes[1].imshow(color_mask)\n",
312
    "    axes[1].set_title('掩码 (不同颜色)')\n",
313
    "    axes[1].axis('off')\n",
314
    "    \n",
315
    "    # 将掩码和点叠加到图像上\n",
316
    "    overlay = img_np.copy()\n",
317
    "    for mask, label, point in zip(masks, labels, points):\n",
318
    "        # 获取标签对应的颜色\n",
319
    "        color = color_map.get(label, color_map['unknown'])\n",
320
    "        \n",
321
    "        # 叠加掩码\n",
322
    "        for i in range(3):\n",
323
    "            overlay[:, :, i] = np.where(mask > 0, \n",
324
    "                                      (overlay[:, :, i] * 0.7 + color[i] * 0.3).astype(np.uint8), \n",
325
    "                                      overlay[:, :, i])\n",
326
    "        \n",
327
    "        # 在图像上标记点\n",
328
    "        if point:\n",
329
    "            # 绘制点\n",
330
    "            cv2.circle(overlay, point, 5, (255, 255, 255), -1)  # 白色实心圆\n",
331
    "            # 确保点在图像内\n",
332
    "            y, x = point[1], point[0]\n",
333
    "            if 0 <= y < overlay.shape[0] and 0 <= x < overlay.shape[1]:\n",
334
    "                cv2.circle(overlay, point, 5, (0, 0, 0), 1)  # 黑色圆边框\n",
335
    "    \n",
336
    "    # 显示带有掩码和点的图像\n",
337
    "    axes[2].imshow(overlay)\n",
338
    "    axes[2].set_title('带有掩码和点的图像')\n",
339
    "    for i, (label, color) in enumerate(color_map.items()):\n",
340
    "        if label in labels:\n",
341
    "            # 将RGB颜色转换为0-1范围\n",
342
    "            normalized_color = [c/255 for c in color]\n",
343
    "            axes[2].plot([], [], 'o', color=normalized_color, label=label)\n",
344
    "    axes[2].legend(loc='upper right')\n",
345
    "    axes[2].axis('off')\n",
346
    "    \n",
347
    "    # 保存图像\n",
348
    "    plt.tight_layout()\n",
349
    "    plt.savefig(output_path)\n",
350
    "    plt.close(fig)\n",
351
    "\n",
352
    "def process_json_file(json_file, output_dir, index, subdir=\"\"):\n",
353
    "    \"\"\"处理单个JSON文件,生成图像和掩码\"\"\"\n",
354
    "    try:\n",
355
    "        with open(json_file, 'r') as f:\n",
356
    "            data = json.load(f)\n",
357
    "        \n",
358
    "        # 提取文件名(不含扩展名)\n",
359
    "        base_name = os.path.splitext(os.path.basename(json_file))[0]\n",
360
    "        \n",
361
    "        # 使用子目录和索引创建唯一文件名\n",
362
    "        if subdir:\n",
363
    "            unique_id = f\"{subdir}_{base_name}_{index}\"\n",
364
    "        else:\n",
365
    "            unique_id = f\"{base_name}_{index}\"\n",
366
    "        \n",
367
    "        # 解码Base64图像\n",
368
    "        image_data = data.get('imageData', '')\n",
369
    "        image = decode_base64_to_image(image_data)\n",
370
    "        \n",
371
    "        if image is None:\n",
372
    "            print(f\"无法解码图像: {json_file}\")\n",
373
    "            return None\n",
374
    "        \n",
375
    "        # 获取图像尺寸\n",
376
    "        width, height = image.size\n",
377
    "        \n",
378
    "        # 保存原始图像\n",
379
    "        image_path = os.path.join(output_dir, 'JPEGImages', f\"{unique_id}.jpg\")\n",
380
    "        image.save(image_path)\n",
381
    "        \n",
382
    "        shapes = data.get('shapes', [])\n",
383
    "        \n",
384
    "        # 用于存储每个掩码的信息\n",
385
    "        mask_info = []\n",
386
    "        \n",
387
    "        # 存储可视化相关的数据\n",
388
    "        all_masks = []\n",
389
    "        all_labels = []\n",
390
    "        all_points = []\n",
391
    "        \n",
392
    "        # 处理每个形状/掩码\n",
393
    "        for i, shape in enumerate(shapes):\n",
394
    "            label = shape.get('label', 'unknown')\n",
395
    "            points = shape.get('points', [])\n",
396
    "            \n",
397
    "            if not points:\n",
398
    "                continue\n",
399
    "                \n",
400
    "            # 创建掩码\n",
401
    "            mask = draw_polygon_mask(width, height, points, label)\n",
402
    "            \n",
403
    "            # 生成掩码文件名\n",
404
    "            mask_filename = f\"{unique_id}_{label}_{i}.png\"\n",
405
    "            mask_path = os.path.join(output_dir, 'Annotations', mask_filename)\n",
406
    "            \n",
407
    "            # 保存掩码\n",
408
    "            cv2.imwrite(mask_path, mask)\n",
409
    "            \n",
410
    "            # 在掩码内生成一个随机点\n",
411
    "            random_points = generate_point_in_mask(mask)\n",
412
    "            \n",
413
    "            if random_points:\n",
414
    "                mask_info.append({\n",
415
    "                    'ImageId': f\"{unique_id}.jpg\",\n",
416
    "                    'MaskId': mask_filename,\n",
417
    "                    'Label': label,\n",
418
    "                    'PointX': random_points[0][0],\n",
419
    "                    'PointY': random_points[0][1]\n",
420
    "                })\n",
421
    "                \n",
422
    "                # 存储可视化数据\n",
423
    "                all_masks.append(mask)\n",
424
    "                all_labels.append(label)\n",
425
    "                all_points.append(random_points[0])\n",
426
    "        \n",
427
    "        # 如果有掩码,则随机可视化一个\n",
428
    "        if all_masks:\n",
429
    "            # 创建可视化目录\n",
430
    "            vis_dir = os.path.join(output_dir, 'Visualization')\n",
431
    "            ensure_dir(vis_dir)\n",
432
    "            \n",
433
    "            # 随机选择可视化\n",
434
    "            if random.random() < 0.2:  # 20%的概率进行可视化\n",
435
    "                vis_path = os.path.join(vis_dir, f\"{unique_id}_visualization.png\")\n",
436
    "                visualize_masks(image, all_masks, all_labels, all_points, vis_path)\n",
437
    "        \n",
438
    "        return mask_info\n",
439
    "    \n",
440
    "    except Exception as e:\n",
441
    "        print(f\"处理文件 {json_file} 时出错: {e}\")\n",
442
    "        import traceback\n",
443
    "        traceback.print_exc()\n",
444
    "        return None\n",
445
    "\n",
446
    "def get_all_json_files(root_dir):\n",
447
    "    \"\"\"递归获取目录及其子目录中的所有JSON文件\"\"\"\n",
448
    "    json_files = []\n",
449
    "    subdirs = []\n",
450
    "    \n",
451
    "    # 遍历目录\n",
452
    "    for dirpath, dirnames, filenames in os.walk(root_dir):\n",
453
    "        rel_path = os.path.relpath(dirpath, root_dir)\n",
454
    "        if rel_path == '.':\n",
455
    "            rel_path = ''\n",
456
    "            \n",
457
    "        # 获取子目录名称(仅一级子目录)\n",
458
    "        if dirpath == root_dir:\n",
459
    "            subdirs.extend(dirnames)\n",
460
    "            \n",
461
    "        # 添加JSON文件和它们的相对路径\n",
462
    "        for filename in filenames:\n",
463
    "            if filename.endswith('.json'):\n",
464
    "                json_files.append((os.path.join(dirpath, filename), rel_path))\n",
465
    "    \n",
466
    "    return json_files, subdirs\n",
467
    "\n",
468
    "def create_dataset_structure(json_dir, output_dir):\n",
469
    "    \"\"\"创建SAM模型训练所需的数据集结构\"\"\"\n",
470
    "    # 清空并创建输出目录\n",
471
    "    if os.path.exists(output_dir):\n",
472
    "        shutil.rmtree(output_dir)\n",
473
    "    \n",
474
    "    # 创建必要的子目录\n",
475
    "    ensure_dir(os.path.join(output_dir, 'JPEGImages'))\n",
476
    "    ensure_dir(os.path.join(output_dir, 'Annotations'))\n",
477
    "    ensure_dir(os.path.join(output_dir, 'Visualization'))\n",
478
    "    \n",
479
    "    # 获取所有JSON文件和子目录\n",
480
    "    json_files, subdirs = get_all_json_files(json_dir)\n",
481
    "    \n",
482
    "    print(f\"找到 {len(json_files)} 个JSON文件在目录: {json_dir}\")\n",
483
    "    print(f\"子目录: {subdirs}\")\n",
484
    "    \n",
485
    "    all_mask_info = []\n",
486
    "    \n",
487
    "    # 处理每个JSON文件\n",
488
    "    for i, (json_file, subdir) in enumerate(tqdm(json_files, desc=\"处理文件\")):\n",
489
    "        mask_info = process_json_file(json_file, output_dir, i, subdir)\n",
490
    "        if mask_info:\n",
491
    "            all_mask_info.extend(mask_info)\n",
492
    "    \n",
493
    "    # 创建CSV文件\n",
494
    "    df = pd.DataFrame(all_mask_info)\n",
495
    "    csv_path = os.path.join(output_dir, 'train.csv')\n",
496
    "    df.to_csv(csv_path, index=False)\n",
497
    "    \n",
498
    "    print(f\"数据集创建完成! 总共处理了 {len(json_files)} 个文件,生成了 {len(all_mask_info)} 个掩码。\")\n",
499
    "    print(f\"图像保存在: {os.path.join(output_dir, 'JPEGImages')}\")\n",
500
    "    print(f\"掩码保存在: {os.path.join(output_dir, 'Annotations')}\")\n",
501
    "    print(f\"可视化结果保存在: {os.path.join(output_dir, 'Visualization')}\")\n",
502
    "    print(f\"训练CSV文件: {csv_path}\")\n",
503
    "\n",
504
    "# 示例用法\n",
505
    "if __name__ == \"__main__\":\n",
506
    "    # 如果你有JSON文件目录\n",
507
    "    json_dir = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/sam2/sam2/data/heart_chambers_dataset\"  # 替换为你的JSON文件目录\n",
508
    "    output_dir = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train\"\n",
509
    "    create_dataset_structure(json_dir, output_dir)\n",
510
    "    \n"
511
   ]
512
  },
513
  {
514
   "cell_type": "code",
515
   "execution_count": null,
516
   "id": "aa78b4c5",
517
   "metadata": {},
518
   "outputs": [],
519
   "source": []
520
  }
521
 ],
522
 "metadata": {
523
  "kernelspec": {
524
   "display_name": "Python 3",
525
   "language": "python",
526
   "name": "python3"
527
  },
528
  "language_info": {
529
   "codemirror_mode": {
530
    "name": "ipython",
531
    "version": 3
532
   },
533
   "file_extension": ".py",
534
   "mimetype": "text/x-python",
535
   "name": "python",
536
   "nbconvert_exporter": "python",
537
   "pygments_lexer": "ipython3",
538
   "version": "3.10.12"
539
  }
540
 },
541
 "nbformat": 4,
542
 "nbformat_minor": 5
543
}