Diff of /data_process.ipynb [000000] .. [d255cc]

Switch to side-by-side view

--- a
+++ b/data_process.ipynb
@@ -0,0 +1,543 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "30cc0acc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "找到的CSV文件: ['train.csv']\n",
+      "\n",
+      "读取文件: train.csv\n",
+      "CSV文件列名: ['ImageId', 'MaskId', 'Label', 'PointX', 'PointY']\n",
+      "CSV文件前5行:\n",
+      "                          ImageId                               MaskId Label  \\\n",
+      "0  a4c_PatientA0030_a4c_20S_0.jpg  a4c_PatientA0030_a4c_20S_0_LV_0.png    LV   \n",
+      "1  a4c_PatientA0030_a4c_20S_0.jpg   a4c_PatientA0030_a4c_20S_0_M_1.png     M   \n",
+      "2  a4c_PatientA0030_a4c_20S_0.jpg  a4c_PatientA0030_a4c_20S_0_LA_2.png    LA   \n",
+      "3  a4c_PatientA0030_a4c_20S_0.jpg  a4c_PatientA0030_a4c_20S_0_RV_3.png    RV   \n",
+      "4  a4c_PatientA0030_a4c_20S_0.jpg  a4c_PatientA0030_a4c_20S_0_RA_4.png    RA   \n",
+      "\n",
+      "   PointX  PointY  \n",
+      "0     373     164  \n",
+      "1     522     296  \n",
+      "2     454     404  \n",
+      "3     289     287  \n",
+      "4     334     457  \n",
+      "CSV文件共有 6620 行\n",
+      "\n",
+      "图像文件夹存在: True\n",
+      "图像文件夹中的文件数量: 1808\n",
+      "图像文件夹中的前5个文件: ['a4c_PatientA0195_a4c_1D_141.jpg', 'a2c_PatientA0107_a2c_20S_959.jpg', 'a4c_PatientA0141_a4c_45D_508.jpg', 'a4c_PatientA0047_a4c_41D_63.jpg', 'a2c_PatientD0059_a2c_27.jsonS_996.jpg']\n",
+      "\n",
+      "标注文件夹存在: True\n",
+      "标注文件夹中的文件数量: 6620\n",
+      "标注文件夹中的前5个文件: ['a3c_PatientA0138_a3c_27D_1769_M_2.png', 'a4c_PatientA0157_a4c_28S_549_LA_1.png', 'a4c_PatientC0035_a4c_83D_35_M_1.png', 'a4c_PatientA0089_a4c_52D_512_LV_3.png', 'a4c_PatientC0038_a4c_43D_563_LV_0.png']\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import pandas as pd\n",
+    "\n",
+    "# 您的数据路径\n",
+    "data_path = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train\"\n",
+    "\n",
+    "# 查找CSV文件\n",
+    "csv_files = []\n",
+    "for file in os.listdir(data_path):\n",
+    "    if file.endswith('.csv'):\n",
+    "        csv_files.append(file)\n",
+    "\n",
+    "print(f\"找到的CSV文件: {csv_files}\")\n",
+    "\n",
+    "# 如果存在CSV文件,读取并显示其内容\n",
+    "if csv_files:\n",
+    "    for csv_file in csv_files:\n",
+    "        file_path = os.path.join(data_path, csv_file)\n",
+    "        print(f\"\\n读取文件: {csv_file}\")\n",
+    "        try:\n",
+    "            df = pd.read_csv(file_path)\n",
+    "            print(f\"CSV文件列名: {df.columns.tolist()}\")\n",
+    "            print(f\"CSV文件前5行:\")\n",
+    "            print(df.head(5))\n",
+    "            print(f\"CSV文件共有 {len(df)} 行\")\n",
+    "        except Exception as e:\n",
+    "            print(f\"读取文件出错: {e}\")\n",
+    "else:\n",
+    "    print(\"数据目录中没有找到CSV文件\")\n",
+    "\n",
+    "# 检查图像和标注文件夹\n",
+    "images_dir = os.path.join(data_path, \"JPEGImages\")\n",
+    "annot_dir = os.path.join(data_path, \"Annotations\")\n",
+    "\n",
+    "print(f\"\\n图像文件夹存在: {os.path.exists(images_dir)}\")\n",
+    "if os.path.exists(images_dir):\n",
+    "    print(f\"图像文件夹中的文件数量: {len(os.listdir(images_dir))}\")\n",
+    "    print(f\"图像文件夹中的前5个文件: {os.listdir(images_dir)[:5]}\")\n",
+    "\n",
+    "print(f\"\\n标注文件夹存在: {os.path.exists(annot_dir)}\")\n",
+    "if os.path.exists(annot_dir):\n",
+    "    print(f\"标注文件夹中的文件数量: {len(os.listdir(annot_dir))}\")\n",
+    "    print(f\"标注文件夹中的前5个文件: {os.listdir(annot_dir)[:5]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "79042bdd",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "找到 1808 个JSON文件在目录: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/sam2/sam2/data/heart_chambers_dataset\n",
+      "子目录: ['a4c', 'a2c', 'a3c']\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "处理文件:   0%|          | 3/1808 [00:00<01:03, 28.51it/s]/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21407 (\\N{CJK UNIFIED IDEOGRAPH-539F}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 22987 (\\N{CJK UNIFIED IDEOGRAPH-59CB}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 22270 (\\N{CJK UNIFIED IDEOGRAPH-56FE}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 20687 (\\N{CJK UNIFIED IDEOGRAPH-50CF}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 25513 (\\N{CJK UNIFIED IDEOGRAPH-63A9}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 30721 (\\N{CJK UNIFIED IDEOGRAPH-7801}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 19981 (\\N{CJK UNIFIED IDEOGRAPH-4E0D}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21516 (\\N{CJK UNIFIED IDEOGRAPH-540C}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 39068 (\\N{CJK UNIFIED IDEOGRAPH-989C}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 33394 (\\N{CJK UNIFIED IDEOGRAPH-8272}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 24102 (\\N{CJK UNIFIED IDEOGRAPH-5E26}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 26377 (\\N{CJK UNIFIED IDEOGRAPH-6709}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21644 (\\N{CJK UNIFIED IDEOGRAPH-548C}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 28857 (\\N{CJK UNIFIED IDEOGRAPH-70B9}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 30340 (\\N{CJK UNIFIED IDEOGRAPH-7684}) missing from font(s) DejaVu Sans.\n",
+      "  plt.tight_layout()\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21407 (\\N{CJK UNIFIED IDEOGRAPH-539F}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 22987 (\\N{CJK UNIFIED IDEOGRAPH-59CB}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 22270 (\\N{CJK UNIFIED IDEOGRAPH-56FE}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 20687 (\\N{CJK UNIFIED IDEOGRAPH-50CF}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 25513 (\\N{CJK UNIFIED IDEOGRAPH-63A9}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 30721 (\\N{CJK UNIFIED IDEOGRAPH-7801}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 19981 (\\N{CJK UNIFIED IDEOGRAPH-4E0D}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21516 (\\N{CJK UNIFIED IDEOGRAPH-540C}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 39068 (\\N{CJK UNIFIED IDEOGRAPH-989C}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 33394 (\\N{CJK UNIFIED IDEOGRAPH-8272}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 24102 (\\N{CJK UNIFIED IDEOGRAPH-5E26}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 26377 (\\N{CJK UNIFIED IDEOGRAPH-6709}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21644 (\\N{CJK UNIFIED IDEOGRAPH-548C}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 28857 (\\N{CJK UNIFIED IDEOGRAPH-70B9}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 30340 (\\N{CJK UNIFIED IDEOGRAPH-7684}) missing from font(s) DejaVu Sans.\n",
+      "  plt.savefig(output_path)\n",
+      "处理文件:   1%|          | 18/1808 [00:01<01:33, 19.19it/s]"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "处理文件: 100%|██████████| 1808/1808 [02:45<00:00, 10.89it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "数据集创建完成! 总共处理了 1808 个文件,生成了 6620 个掩码。\n",
+      "图像保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/JPEGImages\n",
+      "掩码保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/Annotations\n",
+      "可视化结果保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/Visualization\n",
+      "训练CSV文件: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/train.csv\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "import json\n",
+    "import base64\n",
+    "import numpy as np\n",
+    "import cv2\n",
+    "from PIL import Image\n",
+    "import io\n",
+    "import shutil\n",
+    "from shapely.geometry import Polygon\n",
+    "import random\n",
+    "import pandas as pd\n",
+    "from tqdm import tqdm\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "def ensure_dir(directory):\n",
+    "    \"\"\"确保目录存在,如果不存在则创建\"\"\"\n",
+    "    if not os.path.exists(directory):\n",
+    "        os.makedirs(directory)\n",
+    "\n",
+    "def decode_base64_to_image(base64_string):\n",
+    "    \"\"\"将Base64编码的图像数据解码为PIL图像\"\"\"\n",
+    "    if not base64_string or len(base64_string) < 100:  # 简单检查以避免处理截断的字符串\n",
+    "        return None\n",
+    "    \n",
+    "    try:\n",
+    "        # 尝试解码Base64字符串\n",
+    "        image_data = base64.b64decode(base64_string)\n",
+    "        image = Image.open(io.BytesIO(image_data))\n",
+    "        return image\n",
+    "    except Exception as e:\n",
+    "        print(f\"解码Base64图像时出错: {e}\")\n",
+    "        return None\n",
+    "\n",
+    "def draw_polygon_mask(width, height, points, label):\n",
+    "    \"\"\"根据多边形点创建二值掩码\"\"\"\n",
+    "    # 创建空白掩码\n",
+    "    mask = np.zeros((height, width), dtype=np.uint8)\n",
+    "    \n",
+    "    # 将点格式转换为OpenCV需要的格式\n",
+    "    points_array = np.array(points, dtype=np.int32)\n",
+    "    \n",
+    "    # 绘制填充多边形\n",
+    "    cv2.fillPoly(mask, [points_array], 255)\n",
+    "    \n",
+    "    return mask\n",
+    "\n",
+    "def generate_point_in_mask(mask, num_points=1):\n",
+    "    \"\"\"在掩码内生成随机点\"\"\"\n",
+    "    # 找到掩码中值为255的像素位置\n",
+    "    y_indices, x_indices = np.where(mask == 255)\n",
+    "    \n",
+    "    if len(y_indices) == 0:\n",
+    "        return []  # 如果掩码为空,返回空列表\n",
+    "    \n",
+    "    points = []\n",
+    "    for _ in range(num_points):\n",
+    "        # 随机选择一个位置\n",
+    "        idx = random.randint(0, len(y_indices) - 1)\n",
+    "        x, y = int(x_indices[idx]), int(y_indices[idx])\n",
+    "        points.append((x, y))\n",
+    "    \n",
+    "    return points\n",
+    "\n",
+    "def visualize_masks(image, masks, labels, points, output_path):\n",
+    "    \"\"\"\n",
+    "    可视化图像、掩码和点\n",
+    "    \n",
+    "    参数:\n",
+    "    - image: 原始图像 (PIL图像对象或numpy数组)\n",
+    "    - masks: 掩码列表 (每个掩码是numpy数组)\n",
+    "    - labels: 标签列表\n",
+    "    - points: 每个掩码中的点列表 (每个点是(x,y)元组)\n",
+    "    - output_path: 输出图像路径\n",
+    "    \"\"\"\n",
+    "    # 确保image是numpy数组\n",
+    "    if isinstance(image, Image.Image):\n",
+    "        img_np = np.array(image)\n",
+    "    else:\n",
+    "        img_np = image.copy()\n",
+    "    \n",
+    "    # 创建可视化图像\n",
+    "    vis_img = img_np.copy()\n",
+    "    \n",
+    "    # 颜色映射 (标签到RGB颜色)\n",
+    "    color_map = {\n",
+    "        'LV': (255, 0, 0),    # 红色\n",
+    "        'LA': (0, 255, 0),    # 绿色\n",
+    "        'RV': (0, 0, 255),    # 蓝色\n",
+    "        'RA': (0, 255, 255),  # 黄色\n",
+    "        'M': (255, 0, 255),   # 紫色\n",
+    "        'unknown': (128, 128, 128)  # 灰色\n",
+    "    }\n",
+    "    \n",
+    "    # 创建图形和子图\n",
+    "    fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
+    "    \n",
+    "    # 显示原始图像\n",
+    "    axes[0].imshow(img_np)\n",
+    "    axes[0].set_title('原始图像')\n",
+    "    axes[0].axis('off')\n",
+    "    \n",
+    "    # 创建合并掩码的彩色图像\n",
+    "    color_mask = np.zeros((*img_np.shape[:2], 3), dtype=np.uint8)\n",
+    "    \n",
+    "    # 将掩码叠加到图像上,使用不同颜色\n",
+    "    for mask, label, point in zip(masks, labels, points):\n",
+    "        # 获取标签对应的颜色\n",
+    "        color = color_map.get(label, color_map['unknown'])\n",
+    "        \n",
+    "        # 添加掩码颜色\n",
+    "        for i in range(3):\n",
+    "            color_mask[:, :, i][mask > 0] = color[i]\n",
+    "    \n",
+    "    # 显示彩色掩码\n",
+    "    axes[1].imshow(color_mask)\n",
+    "    axes[1].set_title('掩码 (不同颜色)')\n",
+    "    axes[1].axis('off')\n",
+    "    \n",
+    "    # 将掩码和点叠加到图像上\n",
+    "    overlay = img_np.copy()\n",
+    "    for mask, label, point in zip(masks, labels, points):\n",
+    "        # 获取标签对应的颜色\n",
+    "        color = color_map.get(label, color_map['unknown'])\n",
+    "        \n",
+    "        # 叠加掩码\n",
+    "        for i in range(3):\n",
+    "            overlay[:, :, i] = np.where(mask > 0, \n",
+    "                                      (overlay[:, :, i] * 0.7 + color[i] * 0.3).astype(np.uint8), \n",
+    "                                      overlay[:, :, i])\n",
+    "        \n",
+    "        # 在图像上标记点\n",
+    "        if point:\n",
+    "            # 绘制点\n",
+    "            cv2.circle(overlay, point, 5, (255, 255, 255), -1)  # 白色实心圆\n",
+    "            # 确保点在图像内\n",
+    "            y, x = point[1], point[0]\n",
+    "            if 0 <= y < overlay.shape[0] and 0 <= x < overlay.shape[1]:\n",
+    "                cv2.circle(overlay, point, 5, (0, 0, 0), 1)  # 黑色圆边框\n",
+    "    \n",
+    "    # 显示带有掩码和点的图像\n",
+    "    axes[2].imshow(overlay)\n",
+    "    axes[2].set_title('带有掩码和点的图像')\n",
+    "    for i, (label, color) in enumerate(color_map.items()):\n",
+    "        if label in labels:\n",
+    "            # 将RGB颜色转换为0-1范围\n",
+    "            normalized_color = [c/255 for c in color]\n",
+    "            axes[2].plot([], [], 'o', color=normalized_color, label=label)\n",
+    "    axes[2].legend(loc='upper right')\n",
+    "    axes[2].axis('off')\n",
+    "    \n",
+    "    # 保存图像\n",
+    "    plt.tight_layout()\n",
+    "    plt.savefig(output_path)\n",
+    "    plt.close(fig)\n",
+    "\n",
+    "def process_json_file(json_file, output_dir, index, subdir=\"\"):\n",
+    "    \"\"\"处理单个JSON文件,生成图像和掩码\"\"\"\n",
+    "    try:\n",
+    "        with open(json_file, 'r') as f:\n",
+    "            data = json.load(f)\n",
+    "        \n",
+    "        # 提取文件名(不含扩展名)\n",
+    "        base_name = os.path.splitext(os.path.basename(json_file))[0]\n",
+    "        \n",
+    "        # 使用子目录和索引创建唯一文件名\n",
+    "        if subdir:\n",
+    "            unique_id = f\"{subdir}_{base_name}_{index}\"\n",
+    "        else:\n",
+    "            unique_id = f\"{base_name}_{index}\"\n",
+    "        \n",
+    "        # 解码Base64图像\n",
+    "        image_data = data.get('imageData', '')\n",
+    "        image = decode_base64_to_image(image_data)\n",
+    "        \n",
+    "        if image is None:\n",
+    "            print(f\"无法解码图像: {json_file}\")\n",
+    "            return None\n",
+    "        \n",
+    "        # 获取图像尺寸\n",
+    "        width, height = image.size\n",
+    "        \n",
+    "        # 保存原始图像\n",
+    "        image_path = os.path.join(output_dir, 'JPEGImages', f\"{unique_id}.jpg\")\n",
+    "        image.save(image_path)\n",
+    "        \n",
+    "        shapes = data.get('shapes', [])\n",
+    "        \n",
+    "        # 用于存储每个掩码的信息\n",
+    "        mask_info = []\n",
+    "        \n",
+    "        # 存储可视化相关的数据\n",
+    "        all_masks = []\n",
+    "        all_labels = []\n",
+    "        all_points = []\n",
+    "        \n",
+    "        # 处理每个形状/掩码\n",
+    "        for i, shape in enumerate(shapes):\n",
+    "            label = shape.get('label', 'unknown')\n",
+    "            points = shape.get('points', [])\n",
+    "            \n",
+    "            if not points:\n",
+    "                continue\n",
+    "                \n",
+    "            # 创建掩码\n",
+    "            mask = draw_polygon_mask(width, height, points, label)\n",
+    "            \n",
+    "            # 生成掩码文件名\n",
+    "            mask_filename = f\"{unique_id}_{label}_{i}.png\"\n",
+    "            mask_path = os.path.join(output_dir, 'Annotations', mask_filename)\n",
+    "            \n",
+    "            # 保存掩码\n",
+    "            cv2.imwrite(mask_path, mask)\n",
+    "            \n",
+    "            # 在掩码内生成一个随机点\n",
+    "            random_points = generate_point_in_mask(mask)\n",
+    "            \n",
+    "            if random_points:\n",
+    "                mask_info.append({\n",
+    "                    'ImageId': f\"{unique_id}.jpg\",\n",
+    "                    'MaskId': mask_filename,\n",
+    "                    'Label': label,\n",
+    "                    'PointX': random_points[0][0],\n",
+    "                    'PointY': random_points[0][1]\n",
+    "                })\n",
+    "                \n",
+    "                # 存储可视化数据\n",
+    "                all_masks.append(mask)\n",
+    "                all_labels.append(label)\n",
+    "                all_points.append(random_points[0])\n",
+    "        \n",
+    "        # 如果有掩码,则随机可视化一个\n",
+    "        if all_masks:\n",
+    "            # 创建可视化目录\n",
+    "            vis_dir = os.path.join(output_dir, 'Visualization')\n",
+    "            ensure_dir(vis_dir)\n",
+    "            \n",
+    "            # 随机选择可视化\n",
+    "            if random.random() < 0.2:  # 20%的概率进行可视化\n",
+    "                vis_path = os.path.join(vis_dir, f\"{unique_id}_visualization.png\")\n",
+    "                visualize_masks(image, all_masks, all_labels, all_points, vis_path)\n",
+    "        \n",
+    "        return mask_info\n",
+    "    \n",
+    "    except Exception as e:\n",
+    "        print(f\"处理文件 {json_file} 时出错: {e}\")\n",
+    "        import traceback\n",
+    "        traceback.print_exc()\n",
+    "        return None\n",
+    "\n",
+    "def get_all_json_files(root_dir):\n",
+    "    \"\"\"递归获取目录及其子目录中的所有JSON文件\"\"\"\n",
+    "    json_files = []\n",
+    "    subdirs = []\n",
+    "    \n",
+    "    # 遍历目录\n",
+    "    for dirpath, dirnames, filenames in os.walk(root_dir):\n",
+    "        rel_path = os.path.relpath(dirpath, root_dir)\n",
+    "        if rel_path == '.':\n",
+    "            rel_path = ''\n",
+    "            \n",
+    "        # 获取子目录名称(仅一级子目录)\n",
+    "        if dirpath == root_dir:\n",
+    "            subdirs.extend(dirnames)\n",
+    "            \n",
+    "        # 添加JSON文件和它们的相对路径\n",
+    "        for filename in filenames:\n",
+    "            if filename.endswith('.json'):\n",
+    "                json_files.append((os.path.join(dirpath, filename), rel_path))\n",
+    "    \n",
+    "    return json_files, subdirs\n",
+    "\n",
+    "def create_dataset_structure(json_dir, output_dir):\n",
+    "    \"\"\"创建SAM模型训练所需的数据集结构\"\"\"\n",
+    "    # 清空并创建输出目录\n",
+    "    if os.path.exists(output_dir):\n",
+    "        shutil.rmtree(output_dir)\n",
+    "    \n",
+    "    # 创建必要的子目录\n",
+    "    ensure_dir(os.path.join(output_dir, 'JPEGImages'))\n",
+    "    ensure_dir(os.path.join(output_dir, 'Annotations'))\n",
+    "    ensure_dir(os.path.join(output_dir, 'Visualization'))\n",
+    "    \n",
+    "    # 获取所有JSON文件和子目录\n",
+    "    json_files, subdirs = get_all_json_files(json_dir)\n",
+    "    \n",
+    "    print(f\"找到 {len(json_files)} 个JSON文件在目录: {json_dir}\")\n",
+    "    print(f\"子目录: {subdirs}\")\n",
+    "    \n",
+    "    all_mask_info = []\n",
+    "    \n",
+    "    # 处理每个JSON文件\n",
+    "    for i, (json_file, subdir) in enumerate(tqdm(json_files, desc=\"处理文件\")):\n",
+    "        mask_info = process_json_file(json_file, output_dir, i, subdir)\n",
+    "        if mask_info:\n",
+    "            all_mask_info.extend(mask_info)\n",
+    "    \n",
+    "    # 创建CSV文件\n",
+    "    df = pd.DataFrame(all_mask_info)\n",
+    "    csv_path = os.path.join(output_dir, 'train.csv')\n",
+    "    df.to_csv(csv_path, index=False)\n",
+    "    \n",
+    "    print(f\"数据集创建完成! 总共处理了 {len(json_files)} 个文件,生成了 {len(all_mask_info)} 个掩码。\")\n",
+    "    print(f\"图像保存在: {os.path.join(output_dir, 'JPEGImages')}\")\n",
+    "    print(f\"掩码保存在: {os.path.join(output_dir, 'Annotations')}\")\n",
+    "    print(f\"可视化结果保存在: {os.path.join(output_dir, 'Visualization')}\")\n",
+    "    print(f\"训练CSV文件: {csv_path}\")\n",
+    "\n",
+    "# 示例用法\n",
+    "if __name__ == \"__main__\":\n",
+    "    # 如果你有JSON文件目录\n",
+    "    json_dir = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/sam2/sam2/data/heart_chambers_dataset\"  # 替换为你的JSON文件目录\n",
+    "    output_dir = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train\"\n",
+    "    create_dataset_structure(json_dir, output_dir)\n",
+    "    \n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aa78b4c5",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}