--- a +++ b/data_process.ipynb @@ -0,0 +1,543 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "id": "30cc0acc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "找到的CSV文件: ['train.csv']\n", + "\n", + "读取文件: train.csv\n", + "CSV文件列名: ['ImageId', 'MaskId', 'Label', 'PointX', 'PointY']\n", + "CSV文件前5行:\n", + " ImageId MaskId Label \\\n", + "0 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_LV_0.png LV \n", + "1 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_M_1.png M \n", + "2 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_LA_2.png LA \n", + "3 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_RV_3.png RV \n", + "4 a4c_PatientA0030_a4c_20S_0.jpg a4c_PatientA0030_a4c_20S_0_RA_4.png RA \n", + "\n", + " PointX PointY \n", + "0 373 164 \n", + "1 522 296 \n", + "2 454 404 \n", + "3 289 287 \n", + "4 334 457 \n", + "CSV文件共有 6620 行\n", + "\n", + "图像文件夹存在: True\n", + "图像文件夹中的文件数量: 1808\n", + "图像文件夹中的前5个文件: ['a4c_PatientA0195_a4c_1D_141.jpg', 'a2c_PatientA0107_a2c_20S_959.jpg', 'a4c_PatientA0141_a4c_45D_508.jpg', 'a4c_PatientA0047_a4c_41D_63.jpg', 'a2c_PatientD0059_a2c_27.jsonS_996.jpg']\n", + "\n", + "标注文件夹存在: True\n", + "标注文件夹中的文件数量: 6620\n", + "标注文件夹中的前5个文件: ['a3c_PatientA0138_a3c_27D_1769_M_2.png', 'a4c_PatientA0157_a4c_28S_549_LA_1.png', 'a4c_PatientC0035_a4c_83D_35_M_1.png', 'a4c_PatientA0089_a4c_52D_512_LV_3.png', 'a4c_PatientC0038_a4c_43D_563_LV_0.png']\n" + ] + } + ], + "source": [ + "import os\n", + "import pandas as pd\n", + "\n", + "# 您的数据路径\n", + "data_path = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train\"\n", + "\n", + "# 查找CSV文件\n", + "csv_files = []\n", + "for file in os.listdir(data_path):\n", + " if file.endswith('.csv'):\n", + " csv_files.append(file)\n", + "\n", + "print(f\"找到的CSV文件: {csv_files}\")\n", + "\n", + "# 如果存在CSV文件,读取并显示其内容\n", + "if csv_files:\n", + " for csv_file in csv_files:\n", + " file_path = os.path.join(data_path, csv_file)\n", + " print(f\"\\n读取文件: {csv_file}\")\n", + " try:\n", + " df = pd.read_csv(file_path)\n", + " print(f\"CSV文件列名: {df.columns.tolist()}\")\n", + " print(f\"CSV文件前5行:\")\n", + " print(df.head(5))\n", + " print(f\"CSV文件共有 {len(df)} 行\")\n", + " except Exception as e:\n", + " print(f\"读取文件出错: {e}\")\n", + "else:\n", + " print(\"数据目录中没有找到CSV文件\")\n", + "\n", + "# 检查图像和标注文件夹\n", + "images_dir = os.path.join(data_path, \"JPEGImages\")\n", + "annot_dir = os.path.join(data_path, \"Annotations\")\n", + "\n", + "print(f\"\\n图像文件夹存在: {os.path.exists(images_dir)}\")\n", + "if os.path.exists(images_dir):\n", + " print(f\"图像文件夹中的文件数量: {len(os.listdir(images_dir))}\")\n", + " print(f\"图像文件夹中的前5个文件: {os.listdir(images_dir)[:5]}\")\n", + "\n", + "print(f\"\\n标注文件夹存在: {os.path.exists(annot_dir)}\")\n", + "if os.path.exists(annot_dir):\n", + " print(f\"标注文件夹中的文件数量: {len(os.listdir(annot_dir))}\")\n", + " print(f\"标注文件夹中的前5个文件: {os.listdir(annot_dir)[:5]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "79042bdd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "找到 1808 个JSON文件在目录: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/sam2/sam2/data/heart_chambers_dataset\n", + "子目录: ['a4c', 'a2c', 'a3c']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "处理文件: 0%| | 3/1808 [00:00<01:03, 28.51it/s]/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21407 (\\N{CJK UNIFIED IDEOGRAPH-539F}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 22987 (\\N{CJK UNIFIED IDEOGRAPH-59CB}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 22270 (\\N{CJK UNIFIED IDEOGRAPH-56FE}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 20687 (\\N{CJK UNIFIED IDEOGRAPH-50CF}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 25513 (\\N{CJK UNIFIED IDEOGRAPH-63A9}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 30721 (\\N{CJK UNIFIED IDEOGRAPH-7801}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 19981 (\\N{CJK UNIFIED IDEOGRAPH-4E0D}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21516 (\\N{CJK UNIFIED IDEOGRAPH-540C}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 39068 (\\N{CJK UNIFIED IDEOGRAPH-989C}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 33394 (\\N{CJK UNIFIED IDEOGRAPH-8272}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 24102 (\\N{CJK UNIFIED IDEOGRAPH-5E26}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 26377 (\\N{CJK UNIFIED IDEOGRAPH-6709}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 21644 (\\N{CJK UNIFIED IDEOGRAPH-548C}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 28857 (\\N{CJK UNIFIED IDEOGRAPH-70B9}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:151: UserWarning: Glyph 30340 (\\N{CJK UNIFIED IDEOGRAPH-7684}) missing from font(s) DejaVu Sans.\n", + " plt.tight_layout()\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21407 (\\N{CJK UNIFIED IDEOGRAPH-539F}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 22987 (\\N{CJK UNIFIED IDEOGRAPH-59CB}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 22270 (\\N{CJK UNIFIED IDEOGRAPH-56FE}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 20687 (\\N{CJK UNIFIED IDEOGRAPH-50CF}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 25513 (\\N{CJK UNIFIED IDEOGRAPH-63A9}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 30721 (\\N{CJK UNIFIED IDEOGRAPH-7801}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 19981 (\\N{CJK UNIFIED IDEOGRAPH-4E0D}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21516 (\\N{CJK UNIFIED IDEOGRAPH-540C}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 39068 (\\N{CJK UNIFIED IDEOGRAPH-989C}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 33394 (\\N{CJK UNIFIED IDEOGRAPH-8272}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 24102 (\\N{CJK UNIFIED IDEOGRAPH-5E26}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 26377 (\\N{CJK UNIFIED IDEOGRAPH-6709}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 21644 (\\N{CJK UNIFIED IDEOGRAPH-548C}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 28857 (\\N{CJK UNIFIED IDEOGRAPH-70B9}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "/tmp/ipykernel_2248352/1538416663.py:152: UserWarning: Glyph 30340 (\\N{CJK UNIFIED IDEOGRAPH-7684}) missing from font(s) DejaVu Sans.\n", + " plt.savefig(output_path)\n", + "处理文件: 1%| | 18/1808 [00:01<01:33, 19.19it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "处理文件: 100%|██████████| 1808/1808 [02:45<00:00, 10.89it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "数据集创建完成! 总共处理了 1808 个文件,生成了 6620 个掩码。\n", + "图像保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/JPEGImages\n", + "掩码保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/Annotations\n", + "可视化结果保存在: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/Visualization\n", + "训练CSV文件: /media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train/train.csv\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "\n", + "import json\n", + "import base64\n", + "import numpy as np\n", + "import cv2\n", + "from PIL import Image\n", + "import io\n", + "import shutil\n", + "from shapely.geometry import Polygon\n", + "import random\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def ensure_dir(directory):\n", + " \"\"\"确保目录存在,如果不存在则创建\"\"\"\n", + " if not os.path.exists(directory):\n", + " os.makedirs(directory)\n", + "\n", + "def decode_base64_to_image(base64_string):\n", + " \"\"\"将Base64编码的图像数据解码为PIL图像\"\"\"\n", + " if not base64_string or len(base64_string) < 100: # 简单检查以避免处理截断的字符串\n", + " return None\n", + " \n", + " try:\n", + " # 尝试解码Base64字符串\n", + " image_data = base64.b64decode(base64_string)\n", + " image = Image.open(io.BytesIO(image_data))\n", + " return image\n", + " except Exception as e:\n", + " print(f\"解码Base64图像时出错: {e}\")\n", + " return None\n", + "\n", + "def draw_polygon_mask(width, height, points, label):\n", + " \"\"\"根据多边形点创建二值掩码\"\"\"\n", + " # 创建空白掩码\n", + " mask = np.zeros((height, width), dtype=np.uint8)\n", + " \n", + " # 将点格式转换为OpenCV需要的格式\n", + " points_array = np.array(points, dtype=np.int32)\n", + " \n", + " # 绘制填充多边形\n", + " cv2.fillPoly(mask, [points_array], 255)\n", + " \n", + " return mask\n", + "\n", + "def generate_point_in_mask(mask, num_points=1):\n", + " \"\"\"在掩码内生成随机点\"\"\"\n", + " # 找到掩码中值为255的像素位置\n", + " y_indices, x_indices = np.where(mask == 255)\n", + " \n", + " if len(y_indices) == 0:\n", + " return [] # 如果掩码为空,返回空列表\n", + " \n", + " points = []\n", + " for _ in range(num_points):\n", + " # 随机选择一个位置\n", + " idx = random.randint(0, len(y_indices) - 1)\n", + " x, y = int(x_indices[idx]), int(y_indices[idx])\n", + " points.append((x, y))\n", + " \n", + " return points\n", + "\n", + "def visualize_masks(image, masks, labels, points, output_path):\n", + " \"\"\"\n", + " 可视化图像、掩码和点\n", + " \n", + " 参数:\n", + " - image: 原始图像 (PIL图像对象或numpy数组)\n", + " - masks: 掩码列表 (每个掩码是numpy数组)\n", + " - labels: 标签列表\n", + " - points: 每个掩码中的点列表 (每个点是(x,y)元组)\n", + " - output_path: 输出图像路径\n", + " \"\"\"\n", + " # 确保image是numpy数组\n", + " if isinstance(image, Image.Image):\n", + " img_np = np.array(image)\n", + " else:\n", + " img_np = image.copy()\n", + " \n", + " # 创建可视化图像\n", + " vis_img = img_np.copy()\n", + " \n", + " # 颜色映射 (标签到RGB颜色)\n", + " color_map = {\n", + " 'LV': (255, 0, 0), # 红色\n", + " 'LA': (0, 255, 0), # 绿色\n", + " 'RV': (0, 0, 255), # 蓝色\n", + " 'RA': (0, 255, 255), # 黄色\n", + " 'M': (255, 0, 255), # 紫色\n", + " 'unknown': (128, 128, 128) # 灰色\n", + " }\n", + " \n", + " # 创建图形和子图\n", + " fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n", + " \n", + " # 显示原始图像\n", + " axes[0].imshow(img_np)\n", + " axes[0].set_title('原始图像')\n", + " axes[0].axis('off')\n", + " \n", + " # 创建合并掩码的彩色图像\n", + " color_mask = np.zeros((*img_np.shape[:2], 3), dtype=np.uint8)\n", + " \n", + " # 将掩码叠加到图像上,使用不同颜色\n", + " for mask, label, point in zip(masks, labels, points):\n", + " # 获取标签对应的颜色\n", + " color = color_map.get(label, color_map['unknown'])\n", + " \n", + " # 添加掩码颜色\n", + " for i in range(3):\n", + " color_mask[:, :, i][mask > 0] = color[i]\n", + " \n", + " # 显示彩色掩码\n", + " axes[1].imshow(color_mask)\n", + " axes[1].set_title('掩码 (不同颜色)')\n", + " axes[1].axis('off')\n", + " \n", + " # 将掩码和点叠加到图像上\n", + " overlay = img_np.copy()\n", + " for mask, label, point in zip(masks, labels, points):\n", + " # 获取标签对应的颜色\n", + " color = color_map.get(label, color_map['unknown'])\n", + " \n", + " # 叠加掩码\n", + " for i in range(3):\n", + " overlay[:, :, i] = np.where(mask > 0, \n", + " (overlay[:, :, i] * 0.7 + color[i] * 0.3).astype(np.uint8), \n", + " overlay[:, :, i])\n", + " \n", + " # 在图像上标记点\n", + " if point:\n", + " # 绘制点\n", + " cv2.circle(overlay, point, 5, (255, 255, 255), -1) # 白色实心圆\n", + " # 确保点在图像内\n", + " y, x = point[1], point[0]\n", + " if 0 <= y < overlay.shape[0] and 0 <= x < overlay.shape[1]:\n", + " cv2.circle(overlay, point, 5, (0, 0, 0), 1) # 黑色圆边框\n", + " \n", + " # 显示带有掩码和点的图像\n", + " axes[2].imshow(overlay)\n", + " axes[2].set_title('带有掩码和点的图像')\n", + " for i, (label, color) in enumerate(color_map.items()):\n", + " if label in labels:\n", + " # 将RGB颜色转换为0-1范围\n", + " normalized_color = [c/255 for c in color]\n", + " axes[2].plot([], [], 'o', color=normalized_color, label=label)\n", + " axes[2].legend(loc='upper right')\n", + " axes[2].axis('off')\n", + " \n", + " # 保存图像\n", + " plt.tight_layout()\n", + " plt.savefig(output_path)\n", + " plt.close(fig)\n", + "\n", + "def process_json_file(json_file, output_dir, index, subdir=\"\"):\n", + " \"\"\"处理单个JSON文件,生成图像和掩码\"\"\"\n", + " try:\n", + " with open(json_file, 'r') as f:\n", + " data = json.load(f)\n", + " \n", + " # 提取文件名(不含扩展名)\n", + " base_name = os.path.splitext(os.path.basename(json_file))[0]\n", + " \n", + " # 使用子目录和索引创建唯一文件名\n", + " if subdir:\n", + " unique_id = f\"{subdir}_{base_name}_{index}\"\n", + " else:\n", + " unique_id = f\"{base_name}_{index}\"\n", + " \n", + " # 解码Base64图像\n", + " image_data = data.get('imageData', '')\n", + " image = decode_base64_to_image(image_data)\n", + " \n", + " if image is None:\n", + " print(f\"无法解码图像: {json_file}\")\n", + " return None\n", + " \n", + " # 获取图像尺寸\n", + " width, height = image.size\n", + " \n", + " # 保存原始图像\n", + " image_path = os.path.join(output_dir, 'JPEGImages', f\"{unique_id}.jpg\")\n", + " image.save(image_path)\n", + " \n", + " shapes = data.get('shapes', [])\n", + " \n", + " # 用于存储每个掩码的信息\n", + " mask_info = []\n", + " \n", + " # 存储可视化相关的数据\n", + " all_masks = []\n", + " all_labels = []\n", + " all_points = []\n", + " \n", + " # 处理每个形状/掩码\n", + " for i, shape in enumerate(shapes):\n", + " label = shape.get('label', 'unknown')\n", + " points = shape.get('points', [])\n", + " \n", + " if not points:\n", + " continue\n", + " \n", + " # 创建掩码\n", + " mask = draw_polygon_mask(width, height, points, label)\n", + " \n", + " # 生成掩码文件名\n", + " mask_filename = f\"{unique_id}_{label}_{i}.png\"\n", + " mask_path = os.path.join(output_dir, 'Annotations', mask_filename)\n", + " \n", + " # 保存掩码\n", + " cv2.imwrite(mask_path, mask)\n", + " \n", + " # 在掩码内生成一个随机点\n", + " random_points = generate_point_in_mask(mask)\n", + " \n", + " if random_points:\n", + " mask_info.append({\n", + " 'ImageId': f\"{unique_id}.jpg\",\n", + " 'MaskId': mask_filename,\n", + " 'Label': label,\n", + " 'PointX': random_points[0][0],\n", + " 'PointY': random_points[0][1]\n", + " })\n", + " \n", + " # 存储可视化数据\n", + " all_masks.append(mask)\n", + " all_labels.append(label)\n", + " all_points.append(random_points[0])\n", + " \n", + " # 如果有掩码,则随机可视化一个\n", + " if all_masks:\n", + " # 创建可视化目录\n", + " vis_dir = os.path.join(output_dir, 'Visualization')\n", + " ensure_dir(vis_dir)\n", + " \n", + " # 随机选择可视化\n", + " if random.random() < 0.2: # 20%的概率进行可视化\n", + " vis_path = os.path.join(vis_dir, f\"{unique_id}_visualization.png\")\n", + " visualize_masks(image, all_masks, all_labels, all_points, vis_path)\n", + " \n", + " return mask_info\n", + " \n", + " except Exception as e:\n", + " print(f\"处理文件 {json_file} 时出错: {e}\")\n", + " import traceback\n", + " traceback.print_exc()\n", + " return None\n", + "\n", + "def get_all_json_files(root_dir):\n", + " \"\"\"递归获取目录及其子目录中的所有JSON文件\"\"\"\n", + " json_files = []\n", + " subdirs = []\n", + " \n", + " # 遍历目录\n", + " for dirpath, dirnames, filenames in os.walk(root_dir):\n", + " rel_path = os.path.relpath(dirpath, root_dir)\n", + " if rel_path == '.':\n", + " rel_path = ''\n", + " \n", + " # 获取子目录名称(仅一级子目录)\n", + " if dirpath == root_dir:\n", + " subdirs.extend(dirnames)\n", + " \n", + " # 添加JSON文件和它们的相对路径\n", + " for filename in filenames:\n", + " if filename.endswith('.json'):\n", + " json_files.append((os.path.join(dirpath, filename), rel_path))\n", + " \n", + " return json_files, subdirs\n", + "\n", + "def create_dataset_structure(json_dir, output_dir):\n", + " \"\"\"创建SAM模型训练所需的数据集结构\"\"\"\n", + " # 清空并创建输出目录\n", + " if os.path.exists(output_dir):\n", + " shutil.rmtree(output_dir)\n", + " \n", + " # 创建必要的子目录\n", + " ensure_dir(os.path.join(output_dir, 'JPEGImages'))\n", + " ensure_dir(os.path.join(output_dir, 'Annotations'))\n", + " ensure_dir(os.path.join(output_dir, 'Visualization'))\n", + " \n", + " # 获取所有JSON文件和子目录\n", + " json_files, subdirs = get_all_json_files(json_dir)\n", + " \n", + " print(f\"找到 {len(json_files)} 个JSON文件在目录: {json_dir}\")\n", + " print(f\"子目录: {subdirs}\")\n", + " \n", + " all_mask_info = []\n", + " \n", + " # 处理每个JSON文件\n", + " for i, (json_file, subdir) in enumerate(tqdm(json_files, desc=\"处理文件\")):\n", + " mask_info = process_json_file(json_file, output_dir, i, subdir)\n", + " if mask_info:\n", + " all_mask_info.extend(mask_info)\n", + " \n", + " # 创建CSV文件\n", + " df = pd.DataFrame(all_mask_info)\n", + " csv_path = os.path.join(output_dir, 'train.csv')\n", + " df.to_csv(csv_path, index=False)\n", + " \n", + " print(f\"数据集创建完成! 总共处理了 {len(json_files)} 个文件,生成了 {len(all_mask_info)} 个掩码。\")\n", + " print(f\"图像保存在: {os.path.join(output_dir, 'JPEGImages')}\")\n", + " print(f\"掩码保存在: {os.path.join(output_dir, 'Annotations')}\")\n", + " print(f\"可视化结果保存在: {os.path.join(output_dir, 'Visualization')}\")\n", + " print(f\"训练CSV文件: {csv_path}\")\n", + "\n", + "# 示例用法\n", + "if __name__ == \"__main__\":\n", + " # 如果你有JSON文件目录\n", + " json_dir = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/sam2/sam2/data/heart_chambers_dataset\" # 替换为你的JSON文件目录\n", + " output_dir = \"/media/ps/data/Datasets/300例心脏分割/300例勾画图像/FineTune_SAM2/Sam2_new/sam2/data_train\"\n", + " create_dataset_structure(json_dir, output_dir)\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa78b4c5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}