Diff of /training.py [000000] .. [d255cc]

import os

import cv2
import numpy as np
import pandas as pd
import torch
from torch.amp import GradScaler

# Import the SAM2 modules
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def initialize_sam2_predictor(model_cfg_path, checkpoint_path, device="cuda"):
    """
    Initialize the SAM2 predictor.

    Args:
        model_cfg_path: path to the SAM2 model config file
        checkpoint_path: path to the pretrained checkpoint
        device: device to use ('cuda' or 'cpu')

    Returns:
        SAM2ImagePredictor: the initialized predictor instance
    """
    # Build the SAM2 model on the requested device
    sam_model = build_sam2(
        model_cfg_path,
        checkpoint_path,
        device=device,
    )

    # Create the predictor
    predictor = SAM2ImagePredictor(sam_model)

    # Freeze the image encoder parameters to prevent overfitting
    for param in predictor.model.image_encoder.parameters():
        param.requires_grad = False

    print("Image encoder frozen; only the mask decoder and prompt encoder will be fine-tuned")

    return predictor
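

# --- Illustrative sketch, not part of the original script ---
# A quick sanity check that the freezing above worked as intended: it counts
# trainable vs. frozen parameters on the predictor returned by
# initialize_sam2_predictor(). The helper name is an assumption; calling it
# is optional.
def summarize_trainable_parameters(predictor):
    trainable = sum(p.numel() for p in predictor.model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in predictor.model.parameters() if not p.requires_grad)
    print(f"Trainable parameters: {trainable:,} | Frozen parameters: {frozen:,}")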


def read_batch(df, images_dir, masks_dir):
    """
    Read and preprocess one batch of data.

    Args:
        df: DataFrame with the image and mask information
        images_dir: path to the image directory
        masks_dir: path to the mask directory

    Returns:
        tuple: (image, masks, point coordinates, point labels)
    """
    # Randomly select one image
    image_ids = df['ImageId'].unique()
    selected_image = np.random.choice(image_ids)

    # Get all masks belonging to that image
    image_masks = df[df['ImageId'] == selected_image]

    # Read the image
    image_path = os.path.join(images_dir, selected_image)
    img = cv2.imread(image_path)[..., ::-1]  # BGR to RGB

    # Resize the image so that its longer side becomes 1024 pixels
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))

    masks = []
    points = []

    # Process each mask
    for _, row in image_masks.iterrows():
        mask_path = os.path.join(masks_dir, row['MaskId'])
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # Resize the mask with the same scale factor as the image
        mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),
                          interpolation=cv2.INTER_NEAREST)

        # Binarize the mask (in case it is not binary already)
        binary_mask = (mask > 0).astype(np.uint8)
        masks.append(binary_mask)

        # Use the point coordinates from the CSV, scaled to the resized image
        point_x = int(row['PointX'] * r)
        point_y = int(row['PointY'] * r)
        points.append([[point_x, point_y]])

    return img, np.array(masks), np.array(points), np.ones([len(masks), 1])
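

# --- Illustrative sketch, not part of the original script ---
# read_batch() above assumes train.csv has one row per mask with at least the
# columns ImageId, MaskId, PointX and PointY. The values below are made-up
# placeholders that only illustrate the expected schema; the helper name is
# an assumption.
def example_training_dataframe():
    """Tiny placeholder table showing the schema read_batch() expects."""
    return pd.DataFrame({
        "ImageId": ["0001.jpg", "0001.jpg", "0002.jpg"],          # image file name inside images_dir
        "MaskId": ["0001_m0.png", "0001_m1.png", "0002_m0.png"],  # mask file name inside masks_dir
        "PointX": [120, 340, 87],                                 # x of a positive click (original-image pixels)
        "PointY": [215, 98, 402],                                 # y of a positive click (original-image pixels)
    })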


def train_sam2_model(predictor, df, images_dir, masks_dir, max_iterations=50000, learning_rate=1e-5, weight_decay=4e-5):
    """
    Train the SAM2 model.

    Args:
        predictor: SAM2 image predictor instance
        df: DataFrame with the image and mask information
        images_dir: path to the image directory
        masks_dir: path to the mask directory
        max_iterations: maximum number of iterations
        learning_rate: learning rate
        weight_decay: weight decay
    """
    # Enable training mode
    predictor.model.sam_mask_decoder.train(True)  # train the mask decoder
    predictor.model.sam_prompt_encoder.train(True)  # train the prompt encoder

    # Set up the optimizer over the trainable parameters only
    optimizer = torch.optim.AdamW(
        params=[p for p in predictor.model.parameters() if p.requires_grad],
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Set up mixed-precision training
    scaler = GradScaler("cuda")
    mean_iou = 0

    # Create the directory for saving models
    os.makedirs("models", exist_ok=True)

    print(f"Starting training: {max_iterations} iterations...")

    for itr in range(max_iterations):
        # Use mixed-precision training
        with torch.amp.autocast('cuda'):
            # Load a data batch
            image, masks, points, labels = read_batch(df, images_dir, masks_dir)

            # Skip empty batches
            if len(masks) == 0:
                continue

            # Make sure the data is in the right format
            gt_masks = torch.tensor(masks, dtype=torch.float32, device=predictor.device)

            # Run the SAM image encoder on the image
            predictor.set_image(image)

            # Process each point/mask pair
            batch_loss = 0
            batch_iou = 0

            for i in range(len(points)):
                point = points[i:i+1]
                gt_mask = gt_masks[i:i+1]
                label = labels[i:i+1]

                # Prepare the prompts
                mask_input, unnorm_coords, point_labels, unnorm_box = predictor._prep_prompts(
                    point,
                    label,
                    box=None,
                    mask_logits=None,
                    normalize_coords=True
                )

                # Generate the prompt embeddings
                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                    points=(unnorm_coords, point_labels),
                    boxes=None,
                    masks=None,
                )
                batched_mode = unnorm_coords.shape[0] > 1
                # Prepare the high-resolution features
                high_res_features = [
                    feat_level[-1].unsqueeze(0)
                    for feat_level in predictor._features["high_res_feats"]
                ]

                # Generate the masks
                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                    image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=True,
                    repeat_image=batched_mode,  # repeat the image embedding when prompts are batched
                    high_res_features=high_res_features,
                )

                # Post-process the masks back to the original image resolution
                prd_masks = predictor._transforms.postprocess_masks(
                    low_res_masks,
                    predictor._orig_hw[-1]
                )

                # Turn the logit map into a probability map
                prd_mask = torch.sigmoid(prd_masks[:, 0])

                # Compute the cross-entropy loss
                seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) -
                            (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-5)).mean()

                # Compute the IoU
                inter = (gt_mask * (prd_mask > 0.5)).sum()
                union = gt_mask.sum() + (prd_mask > 0.5).sum() - inter
                iou = inter / (union + 1e-8)  # small epsilon to avoid division by zero

                # Compute the score loss
                score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

                # Combined loss
                mask_loss = seg_loss + score_loss * 0.05
                batch_loss += mask_loss
                batch_iou += iou.item()

            # Average loss and IoU over the masks of this image
            avg_loss = batch_loss / len(masks)
            avg_iou = batch_iou / len(masks)

        # Clear the gradients
        optimizer.zero_grad()

        # Backward pass (with mixed precision)
        scaler.scale(avg_loss).backward()

        # Update the weights
        scaler.step(optimizer)
        scaler.update()

        # Save the model periodically
        if itr % 1000 == 0 and itr > 0:
            torch.save(predictor.model.state_dict(), f"models/model_{itr}.torch")

        # Update the running IoU (exponential moving average)
        if itr == 0:
            mean_iou = avg_iou
        else:
            mean_iou = mean_iou * 0.99 + 0.01 * avg_iou

        # Print training progress
        if itr % 100 == 0:
            print(f"Step {itr}, accuracy (IoU) = {mean_iou:.4f}, loss = {avg_loss.item():.4f}")

    # Training finished, save the final model
    torch.save(predictor.model.state_dict(), "models/model_final.torch")
    print(f"Training complete. Final accuracy (IoU) = {mean_iou:.4f}")
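

# --- Illustrative sketch, not part of the original script ---
# The per-mask loss used inside the training loop above, restated as a
# standalone function: pixel-wise binary cross-entropy on the predicted
# probability map plus a small L1 term that pushes the predicted score
# towards the actual IoU. The function name and signature are assumptions.
def sam2_mask_loss(prd_mask, prd_score, gt_mask, score_weight=0.05, eps=1e-5):
    """prd_mask: sigmoid probabilities; gt_mask: binary ground truth of the same shape."""
    seg_loss = (-gt_mask * torch.log(prd_mask + eps)
                - (1 - gt_mask) * torch.log(1 - prd_mask + eps)).mean()
    hard_mask = (prd_mask > 0.5)
    inter = (gt_mask * hard_mask).sum()
    union = gt_mask.sum() + hard_mask.sum() - inter
    iou = inter / (union + 1e-8)
    score_loss = torch.abs(prd_score - iou).mean()
    return seg_loss + score_weight * score_loss, iou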


# Main program
if __name__ == "__main__":
    # Data paths
    csv_path = "/media/ps/data/zhy/Sam2_new/sam2/data_train/train.csv"
    images_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train/JPEGImages"
    masks_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train/Annotations"

    # SAM2 model paths
    model_cfg_path = "configs/sam2.1/sam2.1_hiera_l.yaml"  # change to the actual config file path
    checkpoint_path = "checkpoints/sam2.1_hiera_large.pt"  # change to the actual checkpoint path

    # Load the data
    df = pd.read_csv(csv_path)
    print(f"Loaded {len(df)} training annotations")

    # Initialize the SAM2 predictor
    predictor = initialize_sam2_predictor(model_cfg_path, checkpoint_path)
    print("SAM2 predictor initialized")

    # Start training
    train_sam2_model(predictor, df, images_dir, masks_dir, max_iterations=50000)
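

# --- Illustrative sketch, not part of the original script ---
# How one might reload the fine-tuned weights saved by train_sam2_model() into
# a fresh predictor for inference. The helper name and arguments are
# assumptions; only the "models/model_final.torch" default mirrors the save
# call above.
def load_finetuned_predictor(cfg_path, base_checkpoint_path,
                             weights_path="models/model_final.torch", device="cuda"):
    model = build_sam2(cfg_path, base_checkpoint_path, device=device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    return SAM2ImagePredictor(model)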