Diff of /training.py [000000] .. [d255cc]

Switch to side-by-side view

--- a
+++ b/training.py
@@ -0,0 +1,265 @@
+import os
+import cv2
+import numpy as np
+import pandas as pd
+import torch
+from torch.cuda.amp import GradScaler, autocast
+
+# 导入SAM2相关模块
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+def initialize_sam2_predictor(model_cfg_path, checkpoint_path, device="cuda"):
+    """
+    初始化SAM2预测器
+    
+    Args:
+        model_cfg_path: SAM2模型配置文件路径
+        checkpoint_path: 预训练模型检查点路径
+        device: 使用的设备 ('cuda' 或 'cpu')
+        
+    Returns:
+        SAM2ImagePredictor: 初始化好的预测器实例
+    """
+    # 构建SAM2模型
+    sam_model = build_sam2(
+        model_cfg_path,
+        checkpoint_path,
+        device="cuda"
+    )
+    
+    # 创建预测器
+    predictor = SAM2ImagePredictor(sam_model)
+    
+    # 冻结图像编码器参数以防止过拟合
+    for param in predictor.model.image_encoder.parameters():
+        param.requires_grad = False
+    
+    print("图像编码器已冻结,仅掩码解码器和提示编码器将进行微调")
+    
+    return predictor
+
+def read_batch(df, images_dir, masks_dir):
+    """
+    读取和预处理一批数据
+    
+    Args:
+        df: 包含图像和掩码信息的DataFrame
+        images_dir: 图像目录路径
+        masks_dir: 掩码目录路径
+    
+    Returns:
+        tuple: (图像, 掩码, 点坐标, 点标签)
+    """
+    # 随机选择一个图像
+    image_ids = df['ImageId'].unique()
+    selected_image = np.random.choice(image_ids)
+    
+    # 获取该图像的所有掩码
+    image_masks = df[df['ImageId'] == selected_image]
+    
+    # 读取图像
+    image_path = os.path.join(images_dir, selected_image)
+    img = cv2.imread(image_path)[...,::-1]  # BGR to RGB
+    
+    # 调整图像大小
+    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
+    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
+    
+    masks = []
+    points = []
+    
+    # 处理每个掩码
+    for _, row in image_masks.iterrows():
+        mask_path = os.path.join(masks_dir, row['MaskId'])
+        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+        
+        # 调整掩码大小,与图像一致
+        mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), 
+                          interpolation=cv2.INTER_NEAREST)
+        
+        # 二值化掩码(如果不是二值的)
+        binary_mask = (mask > 0).astype(np.uint8)
+        masks.append(binary_mask)
+        
+        # 使用CSV中提供的点坐标,并根据缩放调整
+        point_x = int(row['PointX'] * r)
+        point_y = int(row['PointY'] * r)
+        points.append([[point_x, point_y]])
+    
+    return img, np.array(masks), np.array(points), np.ones([len(masks), 1])
+
+def train_sam2_model(predictor, df, images_dir, masks_dir, max_iterations=50000, learning_rate=1e-5, weight_decay=4e-5):
+    """
+    训练SAM2模型
+    
+    Args:
+        predictor: SAM2图像预测器实例
+        df: 包含图像和掩码信息的DataFrame
+        images_dir: 图像目录路径
+        masks_dir: 掩码目录路径
+        max_iterations: 最大迭代次数
+        learning_rate: 学习率
+        weight_decay: 权重衰减率
+    """
+    # 启用训练模式
+    predictor.model.sam_mask_decoder.train(True)  # 启用掩码解码器的训练
+    predictor.model.sam_prompt_encoder.train(True)  # 启用提示编码器的训练
+    
+    # 设置优化器
+    optimizer = torch.optim.AdamW(
+        params=[p for p in predictor.model.parameters() if p.requires_grad],
+        lr=learning_rate,
+        weight_decay=weight_decay
+    )
+    
+    # 设置混合精度训练
+    scaler = GradScaler()
+    mean_iou = 0
+    
+    # 创建保存模型的目录
+    os.makedirs("models", exist_ok=True)
+    
+    print(f"开始训练,共{max_iterations}次迭代...")
+    
+    for itr in range(max_iterations):
+        # 使用混合精度训练
+        with torch.amp.autocast('cuda'):
+            # 加载数据批次
+            image, masks, points, labels = read_batch(df, images_dir, masks_dir)
+            
+            # 忽略空批次
+            if len(masks) == 0:
+                continue
+            
+            # 确保数据格式正确
+            gt_masks = torch.tensor(masks, dtype=torch.float32, device=predictor.device)
+            
+            # 对图像应用SAM图像编码器
+            predictor.set_image(image)
+            
+            # 对每个点/掩码对进行处理
+            batch_loss = 0
+            batch_iou = 0
+            
+            for i in range(len(points)):
+                point = points[i:i+1]
+                gt_mask = gt_masks[i:i+1]
+                label = labels[i:i+1]
+                
+                # 准备提示
+                mask_input, unnorm_coords, point_labels, unnorm_box = predictor._prep_prompts(
+                    point, 
+                    label, 
+                    box=None, 
+                    mask_logits=None, 
+                    normalize_coords=True
+                )
+                
+                # 生成嵌入
+                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
+                    points=(unnorm_coords, point_labels),
+                    boxes=None,
+                    masks=None,
+                )
+                batched_mode = unnorm_coords.shape[0] > 1
+                # 准备高分辨率特征
+                high_res_features = [
+                    feat_level[-1].unsqueeze(0) 
+                    for feat_level in predictor._features["high_res_feats"]
+                ]
+                
+                # 生成掩码
+                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
+                    image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
+                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
+                    sparse_prompt_embeddings=sparse_embeddings,
+                    dense_prompt_embeddings=dense_embeddings,
+                    multimask_output=True,
+                    repeat_image=batched_mode,  # 添加这个参数
+                    high_res_features=high_res_features,
+                )
+                
+                # 后处理掩码到原始图像分辨率
+                prd_masks = predictor._transforms.postprocess_masks(
+                    low_res_masks, 
+                    predictor._orig_hw[-1]
+                )
+                
+                # 将logit图转换为概率图
+                prd_mask = torch.sigmoid(prd_masks[:, 0])
+                
+                # 计算交叉熵损失
+                seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) - 
+                          (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-5)).mean()
+                
+                # 计算IoU
+                inter = (gt_mask * (prd_mask > 0.5)).sum()
+                union = gt_mask.sum() + (prd_mask > 0.5).sum() - inter
+                iou = inter / (union + 1e-8)  # 添加小值防止除零
+                
+                # 计算得分损失
+                score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
+                
+                # 混合损失
+                mask_loss = seg_loss + score_loss * 0.05
+                batch_loss += mask_loss
+                batch_iou += iou.item()
+            
+            # 计算平均损失和IoU
+            if len(masks) > 0:
+                avg_loss = batch_loss / len(masks)
+                avg_iou = batch_iou / len(masks)
+            else:
+                continue
+        
+        # 清空梯度
+        optimizer.zero_grad()
+        
+        # 反向传播(使用混合精度)
+        scaler.scale(avg_loss).backward()
+        
+        # 更新权重
+        scaler.step(optimizer)
+        scaler.update()
+        
+        # 定期保存模型
+        if itr % 1000 == 0 and itr > 0:
+            torch.save(predictor.model.state_dict(), f"models/model_{itr}.torch")
+        
+        # 更新平均IoU(使用指数移动平均)
+        if itr == 0:
+            mean_iou = avg_iou
+        else:
+            mean_iou = mean_iou * 0.99 + 0.01 * avg_iou
+        
+        # 打印训练进度
+        if itr % 100 == 0:
+            print(f"步骤 {itr}, 准确率 (IoU) = {mean_iou:.4f}, 损失 = {avg_loss.item():.4f}")
+    
+    # 训练结束,保存最终模型
+    torch.save(predictor.model.state_dict(), "models/model_final.torch")
+    print(f"训练完成。最终准确率 (IoU) = {mean_iou:.4f}")
+
+
+# 主程序
+if __name__ == "__main__":
+    # 数据路径
+    csv_path = "/media/ps/data/zhy/Sam2_new/sam2/data_train/train.csv"
+    images_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train/JPEGImages"
+    masks_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train/Annotations"
+    
+    # SAM2模型路径
+    model_cfg_path = "configs/sam2.1/sam2.1_hiera_l.yaml"  # 修改为实际的配置文件路径
+    checkpoint_path = "checkpoints/sam2.1_hiera_large.pt"  # 修改为实际的检查点路径
+    
+    # 加载数据
+    df = pd.read_csv(csv_path)
+    print(f"加载了{len(df)}条训练数据")
+    
+    # 初始化SAM2预测器
+    predictor = initialize_sam2_predictor(model_cfg_path, checkpoint_path)
+    print("SAM2预测器初始化完成")
+    
+    # 开始训练
+    train_sam2_model(predictor, df, images_dir, masks_dir, max_iterations=50000)
\ No newline at end of file