b/training.py

import os
import cv2
import numpy as np
import pandas as pd
import torch
from torch.cuda.amp import GradScaler

# Import SAM2 modules
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

def initialize_sam2_predictor(model_cfg_path, checkpoint_path, device="cuda"):
    """
    Initialize the SAM2 predictor.

    Args:
        model_cfg_path: path to the SAM2 model config file
        checkpoint_path: path to the pretrained model checkpoint
        device: device to use ('cuda' or 'cpu')

    Returns:
        SAM2ImagePredictor: the initialized predictor instance
    """
    # Build the SAM2 model
    sam_model = build_sam2(
        model_cfg_path,
        checkpoint_path,
        device=device
    )

    # Create the predictor
    predictor = SAM2ImagePredictor(sam_model)

    # Freeze the image encoder parameters to prevent overfitting
    for param in predictor.model.image_encoder.parameters():
        param.requires_grad = False

    print("Image encoder frozen; only the mask decoder and prompt encoder will be fine-tuned")

    return predictor

def read_batch(df, images_dir, masks_dir):
    """
    Read and preprocess one batch of data.

    Args:
        df: DataFrame with image and mask information
        images_dir: directory containing the images
        masks_dir: directory containing the masks

    Returns:
        tuple: (image, masks, point coordinates, point labels)
    """
    # Randomly select one image
    image_ids = df['ImageId'].unique()
    selected_image = np.random.choice(image_ids)

    # Get all masks belonging to this image
    image_masks = df[df['ImageId'] == selected_image]

    # Read the image
    image_path = os.path.join(images_dir, selected_image)
    img = cv2.imread(image_path)[..., ::-1]  # BGR to RGB

    # Resize the image so that its longer side is at most 1024
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))

    masks = []
    points = []

    # Process each mask
    for _, row in image_masks.iterrows():
        mask_path = os.path.join(masks_dir, row['MaskId'])
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # Resize the mask to match the image
        mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),
                          interpolation=cv2.INTER_NEAREST)

        # Binarize the mask (in case it is not already binary)
        binary_mask = (mask > 0).astype(np.uint8)
        masks.append(binary_mask)

        # Use the point coordinates from the CSV, rescaled along with the image
        point_x = int(row['PointX'] * r)
        point_y = int(row['PointY'] * r)
        points.append([[point_x, point_y]])

    return img, np.array(masks), np.array(points), np.ones([len(masks), 1])

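# Note: read_batch assumes the training CSV has one row per mask with at least the
# columns ImageId, MaskId, PointX and PointY (the names used above). PointX/PointY
# are pixel coordinates of a point inside the mask, given in the original image
# resolution and rescaled together with the image. A purely hypothetical example
# row: "0001.jpg,0001_mask_0.png,512,384".
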
def train_sam2_model(predictor, df, images_dir, masks_dir, max_iterations=50000, learning_rate=1e-5, weight_decay=4e-5):
    """
    Train the SAM2 model.

    Args:
        predictor: SAM2 image predictor instance
        df: DataFrame with image and mask information
        images_dir: directory containing the images
        masks_dir: directory containing the masks
        max_iterations: maximum number of iterations
        learning_rate: learning rate
        weight_decay: weight decay
    """
    # Enable training mode
    predictor.model.sam_mask_decoder.train(True)  # train the mask decoder
    predictor.model.sam_prompt_encoder.train(True)  # train the prompt encoder

    # Set up the optimizer
    optimizer = torch.optim.AdamW(
        params=[p for p in predictor.model.parameters() if p.requires_grad],
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Set up mixed-precision training
    scaler = GradScaler()
    mean_iou = 0

    # Create the directory for saving models
    os.makedirs("models", exist_ok=True)

    print(f"Starting training for {max_iterations} iterations...")

    for itr in range(max_iterations):
        # Use mixed-precision for the forward pass
        with torch.amp.autocast('cuda'):
            # Load a data batch
            image, masks, points, labels = read_batch(df, images_dir, masks_dir)

            # Skip empty batches
            if len(masks) == 0:
                continue

            # Move the ground-truth masks to the model device
            gt_masks = torch.tensor(masks, dtype=torch.float32, device=predictor.device)

            # Run the SAM image encoder on the image
            predictor.set_image(image)

            # Process each point/mask pair
            batch_loss = 0
            batch_iou = 0

            for i in range(len(points)):
                point = points[i:i+1]
                gt_mask = gt_masks[i:i+1]
                label = labels[i:i+1]

                # Prepare the prompts
                mask_input, unnorm_coords, point_labels, unnorm_box = predictor._prep_prompts(
                    point,
                    label,
                    box=None,
                    mask_logits=None,
                    normalize_coords=True
                )

                # Generate the prompt embeddings
                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                    points=(unnorm_coords, point_labels),
                    boxes=None,
                    masks=None,
                )
                batched_mode = unnorm_coords.shape[0] > 1
                # Prepare the high-resolution features
                high_res_features = [
                    feat_level[-1].unsqueeze(0)
                    for feat_level in predictor._features["high_res_feats"]
                ]

                # Generate the masks
                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                    image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=True,
                    repeat_image=batched_mode,
                    high_res_features=high_res_features,
                )

                # Upscale the masks to the original image resolution
                prd_masks = predictor._transforms.postprocess_masks(
                    low_res_masks,
                    predictor._orig_hw[-1]
                )

                # Convert the logit map to a probability map
                prd_mask = torch.sigmoid(prd_masks[:, 0])

                # Cross-entropy segmentation loss
                seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) -
                            (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-5)).mean()

                # Compute the IoU
                inter = (gt_mask * (prd_mask > 0.5)).sum()
                union = gt_mask.sum() + (prd_mask > 0.5).sum() - inter
                iou = inter / (union + 1e-8)  # small epsilon to avoid division by zero

                # Score loss (predicted score vs. actual IoU)
                score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

                # Combined loss
                mask_loss = seg_loss + score_loss * 0.05
                batch_loss += mask_loss
                batch_iou += iou.item()

        # Average loss and IoU over the batch
        if len(masks) > 0:
            avg_loss = batch_loss / len(masks)
            avg_iou = batch_iou / len(masks)
        else:
            continue

        # Clear the gradients
        optimizer.zero_grad()

        # Backpropagate (with mixed precision)
        scaler.scale(avg_loss).backward()

        # Update the weights
        scaler.step(optimizer)
        scaler.update()

        # Save the model periodically
        if itr % 1000 == 0 and itr > 0:
            torch.save(predictor.model.state_dict(), f"models/model_{itr}.torch")

        # Update the mean IoU (exponential moving average)
        if itr == 0:
            mean_iou = avg_iou
        else:
            mean_iou = mean_iou * 0.99 + 0.01 * avg_iou

        # Print training progress
        if itr % 100 == 0:
            print(f"Step {itr}, accuracy (IoU) = {mean_iou:.4f}, loss = {avg_loss.item():.4f}")

    # Training finished; save the final model
    torch.save(predictor.model.state_dict(), "models/model_final.torch")
    print(f"Training complete. Final accuracy (IoU) = {mean_iou:.4f}")


# Main program
if __name__ == "__main__":
    # Data paths
    csv_path = "/media/ps/data/zhy/Sam2_new/sam2/data_train/train.csv"
    images_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train/JPEGImages"
    masks_dir = "/media/ps/data/zhy/Sam2_new/sam2/data_train/Annotations"

    # SAM2 model paths
    model_cfg_path = "configs/sam2.1/sam2.1_hiera_l.yaml"  # change to the actual config file path
    checkpoint_path = "checkpoints/sam2.1_hiera_large.pt"  # change to the actual checkpoint path

    # Load the data
    df = pd.read_csv(csv_path)
    print(f"Loaded {len(df)} training records")

    # Initialize the SAM2 predictor
    predictor = initialize_sam2_predictor(model_cfg_path, checkpoint_path)
    print("SAM2 predictor initialized")

    # Start training
    train_sam2_model(predictor, df, images_dir, masks_dir, max_iterations=50000)
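
# A minimal inference sketch (not part of the training run) showing how the
# fine-tuned weights saved above could be loaded back into a predictor; the
# image path and click coordinates below are hypothetical examples:
#
#     predictor = initialize_sam2_predictor(model_cfg_path, checkpoint_path)
#     predictor.model.load_state_dict(torch.load("models/model_final.torch"))
#     predictor.model.eval()
#     test_img = cv2.imread("test.jpg")[..., ::-1]
#     predictor.set_image(test_img)
#     masks, scores, _ = predictor.predict(point_coords=np.array([[512, 384]]),
#                                          point_labels=np.array([1]))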