import os
import cv2
import numpy as np
import pandas as pd
import torch
from torch.cuda.amp import GradScaler, autocast

# SAM2 modules
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def initialize_sam2_predictor(model_cfg_path, checkpoint_path, device="cuda"):
    """
    Build a SAM2 image predictor with a frozen image encoder.

    Args:
        model_cfg_path: Path to the SAM2 model configuration file.
        checkpoint_path: Path to the pretrained model checkpoint.
        device: Device to build the model on ('cuda' or 'cpu').

    Returns:
        SAM2ImagePredictor: The initialized predictor instance.
    """
    # Honor the caller's device choice (it was previously hard-coded to "cuda",
    # which silently ignored the `device` argument).
    sam_model = build_sam2(
        model_cfg_path,
        checkpoint_path,
        device=device,
    )

    predictor = SAM2ImagePredictor(sam_model)

    # Freeze the image encoder so that only the remaining parameters
    # (those still requiring grad) are picked up by the optimizer.
    for param in predictor.model.image_encoder.parameters():
        param.requires_grad = False

    print("图像编码器已冻结,仅掩码解码器和提示编码器将进行微调")

    return predictor


def read_batch(df, images_dir, masks_dir):
    """
    Read and preprocess one randomly chosen image together with all its masks.

    Args:
        df: DataFrame with the columns 'ImageId', 'MaskId', 'PointX', 'PointY'.
        images_dir: Directory containing the image files.
        masks_dir: Directory containing the mask files.

    Returns:
        tuple: (image, masks, points, labels) where
            image is the RGB image scaled so its longer side is 1024 pixels,
            masks is an array of binary uint8 masks (one per row of the image),
            points is an array of shape (num_masks, 1, 2) holding scaled (x, y),
            labels is an array of ones with shape (num_masks, 1).

    Raises:
        FileNotFoundError: If the image or one of its masks cannot be read.
    """
    # Pick one image at random.
    image_ids = df['ImageId'].unique()
    selected_image = np.random.choice(image_ids)

    # All mask rows that belong to that image.
    image_masks = df[df['ImageId'] == selected_image]

    image_path = os.path.join(images_dir, selected_image)
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = bgr[..., ::-1]  # BGR to RGB

    # Scale factor that maps the longer side to 1024 pixels.
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    target_size = (img.shape[1], img.shape[0])  # (width, height) for cv2

    masks = []
    points = []

    for _, row in image_masks.iterrows():
        mask_path = os.path.join(masks_dir, row['MaskId'])
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Resize to the resized image's exact dimensions so that mask and
        # image always agree; nearest-neighbor keeps the mask values discrete.
        mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        # Binarize: any non-zero pixel counts as foreground.
        masks.append((mask > 0).astype(np.uint8))

        # Prompt point from the CSV, moved into resized-image coordinates.
        point_x = int(row['PointX'] * r)
        point_y = int(row['PointY'] * r)
        points.append([[point_x, point_y]])

    return img, np.array(masks), np.array(points), np.ones([len(masks), 1])


def train_sam2_model(predictor, df, images_dir, masks_dir, max_iterations=50000, learning_rate=1e-5, weight_decay=4e-5):
    """
    Fine-tune the SAM2 model on point-prompted masks.

    Each iteration loads one image via read_batch, runs one forward pass per
    (point, mask) pair, averages the per-mask losses and takes one optimizer
    step. Checkpoints are written to the "models" directory every 1000
    iterations and once more at the end.

    Args:
        predictor: SAM2 image predictor instance.
        df: DataFrame describing images, masks and prompt points.
        images_dir: Directory containing the image files.
        masks_dir: Directory containing the mask files.
        max_iterations: Number of training iterations.
        learning_rate: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
    """
    # Put the mask decoder and the prompt encoder into training mode.
    predictor.model.sam_mask_decoder.train(True)
    predictor.model.sam_prompt_encoder.train(True)

    # Optimize only the parameters that still require gradients.
    optimizer = torch.optim.AdamW(
        params=[p for p in predictor.model.parameters() if p.requires_grad],
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Gradient scaler for mixed-precision training.
    scaler = GradScaler()
    mean_iou = 0

    os.makedirs("models", exist_ok=True)

    print(f"开始训练,共{max_iterations}次迭代...")

    for itr in range(max_iterations):
        # Data loading does not need to run under autocast.
        image, masks, points, labels = read_batch(df, images_dir, masks_dir)

        # Skip empty batches.
        if len(masks) == 0:
            continue

        with torch.amp.autocast('cuda'):
            gt_masks = torch.tensor(masks, dtype=torch.float32, device=predictor.device)

            # Run the image encoder; features are cached on the predictor.
            predictor.set_image(image)

            batch_loss = 0
            batch_iou = 0

            # One forward pass per (point, mask) pair.
            for i in range(len(points)):
                point = points[i:i+1]
                gt_mask = gt_masks[i:i+1]
                label = labels[i:i+1]

                # Prepare the prompt; the mask-input and box outputs are unused.
                _, unnorm_coords, point_labels, _ = predictor._prep_prompts(
                    point,
                    label,
                    box=None,
                    mask_logits=None,
                    normalize_coords=True
                )

                # Embed the point prompt.
                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                    points=(unnorm_coords, point_labels),
                    boxes=None,
                    masks=None,
                )
                batched_mode = unnorm_coords.shape[0] > 1

                # High-resolution features of the most recently set image.
                high_res_features = [
                    feat_level[-1].unsqueeze(0)
                    for feat_level in predictor._features["high_res_feats"]
                ]

                # Decode masks and predicted quality scores.
                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                    image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=True,
                    repeat_image=batched_mode,
                    high_res_features=high_res_features,
                )

                # Upscale the masks to the original image resolution.
                prd_masks = predictor._transforms.postprocess_masks(
                    low_res_masks,
                    predictor._orig_hw[-1]
                )

                # Logits -> probabilities for the first returned mask.
                prd_mask = torch.sigmoid(prd_masks[:, 0])

                # Binary cross-entropy; the 1e-5 keeps log() finite.
                seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) -
                            (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-5)).mean()

                # IoU of the thresholded prediction against the ground truth.
                inter = (gt_mask * (prd_mask > 0.5)).sum()
                union = gt_mask.sum() + (prd_mask > 0.5).sum() - inter
                iou = inter / (union + 1e-8)  # small epsilon avoids division by zero

                # Train the predicted score to match the measured IoU.
                score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

                # Combined loss.
                mask_loss = seg_loss + score_loss * 0.05
                batch_loss += mask_loss
                batch_iou += iou.item()

            # Average over the masks of this image (non-empty, checked above).
            avg_loss = batch_loss / len(masks)
            avg_iou = batch_iou / len(masks)

        optimizer.zero_grad()

        # Backward pass through the gradient scaler.
        scaler.scale(avg_loss).backward()

        # Optimizer step and scaler update.
        scaler.step(optimizer)
        scaler.update()

        # Periodic checkpoint.
        if itr % 1000 == 0 and itr > 0:
            torch.save(predictor.model.state_dict(), f"models/model_{itr}.torch")

        # Exponential moving average of the IoU.
        if itr == 0:
            mean_iou = avg_iou
        else:
            mean_iou = mean_iou * 0.99 + 0.01 * avg_iou

        # Progress report.
        if itr % 100 == 0:
            print(f"步骤 {itr}, 准确率 (IoU) = {mean_iou:.4f}, 损失 = {avg_loss.item():.4f}")

    # Training finished: save the final model.
    torch.save(predictor.model.state_dict(), "models/model_final.torch")
    print(f"训练完成。最终准确率 (IoU) = {mean_iou:.4f}")


# Script entry point
if __name__ == "__main__":
    # Data paths
    csv_path = "/media/ps/data/zhy/Sam2_new/sam2/data_train/train.csv"
    images_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train/JPEGImages"
    masks_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train/Annotations"

    # SAM2 model paths
    model_cfg_path = "configs/sam2.1/sam2.1_hiera_l.yaml"  # change to the actual config file path
    checkpoint_path = "checkpoints/sam2.1_hiera_large.pt"  # change to the actual checkpoint path

    # Load the training table
    df = pd.read_csv(csv_path)
    print(f"加载了{len(df)}条训练数据")

    # Initialize the SAM2 predictor
    predictor = initialize_sam2_predictor(model_cfg_path, checkpoint_path)
    print("SAM2预测器初始化完成")

    # Start training
    train_sam2_model(predictor, df, images_dir, masks_dir, max_iterations=50000)