# finetune_segmentanybone_wo_atten.py
|
1 |
#from segment_anything import SamPredictor, sam_model_registry |
|
|
2 |
from models.sam import SamPredictor, sam_model_registry |
|
|
3 |
from models.sam.utils.transforms import ResizeLongestSide |
|
|
4 |
from models.sam.modeling.prompt_encoder import auto_cls_emb |
|
|
5 |
from models.sam.modeling.prompt_encoder import attention_fusion |
|
|
6 |
from skimage.measure import label |
|
|
7 |
#Scientific computing |
|
|
8 |
import numpy as np |
|
|
9 |
import os |
|
|
10 |
#Pytorch packages |
|
|
11 |
import torch |
|
|
12 |
from torch import nn |
|
|
13 |
import torch.optim as optim |
|
|
14 |
from einops import rearrange |
|
|
15 |
import torchvision |
|
|
16 |
from torchvision import datasets |
|
|
17 |
from tensorboardX import SummaryWriter |
|
|
18 |
#Visulization |
|
|
19 |
import matplotlib.pyplot as plt |
|
|
20 |
from torchvision import transforms |
|
|
21 |
from PIL import Image |
|
|
22 |
#Others |
|
|
23 |
from torch.utils.data import DataLoader, Subset |
|
|
24 |
from torch.autograd import Variable |
|
|
25 |
import matplotlib.pyplot as plt |
|
|
26 |
import copy |
|
|
27 |
from dataset_bone import MRI_dataset_multicls |
|
|
28 |
import torch.nn.functional as F |
|
|
29 |
from torch.nn.functional import one_hot |
|
|
30 |
from pathlib import Path |
|
|
31 |
from tqdm import tqdm |
|
|
32 |
from losses import DiceLoss |
|
|
33 |
from dsc import dice_coeff,dice_coeff_multi_class |
|
|
34 |
import cv2 |
|
|
35 |
import monai |
|
|
36 |
from utils import vis_image |
|
|
37 |
import random |
|
|
38 |
|
|
|
39 |
import cfg |
|
|
40 |
args = cfg.parse_args() |
|
|
41 |
os.environ["CUDA_VISIBLE_DEVICES"] = "1" |
|
|
42 |
args.if_mask_decoder_adapter=True |
|
|
43 |
args.if_encoder_adapter = True |
|
|
44 |
args.lr = 5e-4 |
|
|
45 |
args.decoder_adapt_depth = 2 |
|
|
46 |
args.if_warmup = True |
|
|
47 |
args.initial_path = '/mnt/largeDrives/sevenTBTwo/bone_proj/codes_for_data/' |
|
|
48 |
args.pretrain_weight = os.path.join('/mnt/largeDrives/sevenTBTwo/bone_proj/codes_for_data/588/fine-tune-sam/Medical-SAM-Adapter','2D-MobileSAM-onlyfusion-adapter_Bone_0107_paired_attentionpredicted','checkpoint_best.pth') |
|
|
49 |
args.num_classes = 2 |
|
|
50 |
args.targets = 'multi_all' |
|
|
51 |
|
|
|
52 |
|
|
|
53 |
def train_model(trainloader,valloader,dir_checkpoint,epochs): |
|
|
54 |
# Set up model |
|
|
55 |
|
|
|
56 |
if args.if_warmup: |
|
|
57 |
b_lr = args.lr / args.warmup_period |
|
|
58 |
else: |
|
|
59 |
b_lr = args.lr |
|
|
60 |
|
|
|
61 |
|
|
|
62 |
iter_num = 0 |
|
|
63 |
max_iterations = epochs * len(trainloader) |
|
|
64 |
writer = SummaryWriter(dir_checkpoint + '/log') |
|
|
65 |
|
|
|
66 |
sam = sam_model_registry["vit_t"](args,checkpoint=args.pretrain_weight,num_classes=args.num_classes) |
|
|
67 |
sam.load_state_dict(torch.load(os.path.join(args.pretrain_weight)), strict = False) |
|
|
68 |
print(sam) |
|
|
69 |
|
|
|
70 |
for n, value in sam.named_parameters(): |
|
|
71 |
value.requires_grad = False |
|
|
72 |
|
|
|
73 |
for n, value in sam.mask_decoder.named_parameters(): |
|
|
74 |
if "Adapter" in n: # only update parameters in decoder adapter |
|
|
75 |
value.requires_grad = True |
|
|
76 |
if 'output_hypernetworks_mlps' in n: |
|
|
77 |
value.requires_grad = True |
|
|
78 |
|
|
|
79 |
print('if image encoder adapter:',args.if_encoder_adapter) |
|
|
80 |
print('if mask decoder adapter:',args.if_mask_decoder_adapter) |
|
|
81 |
sam.to('cuda') |
|
|
82 |
|
|
|
83 |
optimizer = optim.AdamW(sam.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False) |
|
|
84 |
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) #learning rate decay |
|
|
85 |
criterion1 = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, to_onehot_y=True,reduction='mean') |
|
|
86 |
criterion2 = nn.CrossEntropyLoss() |
|
|
87 |
|
|
|
88 |
pbar = tqdm(range(epochs)) |
|
|
89 |
val_largest_dsc = 0 |
|
|
90 |
last_update_epoch = 0 |
|
|
91 |
for epoch in pbar: |
|
|
92 |
sam.train() |
|
|
93 |
train_loss = 0 |
|
|
94 |
for i,data in enumerate(trainloader): |
|
|
95 |
imgs = data['image'].cuda() |
|
|
96 |
img_emb= sam.image_encoder(imgs) |
|
|
97 |
alpha = random.random() |
|
|
98 |
# automatic masks contaning all muscles |
|
|
99 |
msks = torchvision.transforms.Resize((args.out_size,args.out_size))(data['mask']) |
|
|
100 |
#print('mask unique value:',msks.unique()) |
|
|
101 |
msks = msks.cuda() |
|
|
102 |
sparse_emb, dense_emb = sam.prompt_encoder( |
|
|
103 |
points=None, |
|
|
104 |
boxes=None, |
|
|
105 |
masks=None, |
|
|
106 |
) |
|
|
107 |
pred, _ = sam.mask_decoder( |
|
|
108 |
image_embeddings=img_emb, |
|
|
109 |
image_pe=sam.prompt_encoder.get_dense_pe(), |
|
|
110 |
sparse_prompt_embeddings=sparse_emb, |
|
|
111 |
dense_prompt_embeddings=dense_emb, |
|
|
112 |
multimask_output=True, |
|
|
113 |
) |
|
|
114 |
loss_dice = criterion1(pred,msks.float()) |
|
|
115 |
loss_ce = criterion2(pred,torch.squeeze(msks.long(),1)) |
|
|
116 |
loss = loss_dice + loss_ce |
|
|
117 |
|
|
|
118 |
loss.backward() |
|
|
119 |
optimizer.step() |
|
|
120 |
optimizer.zero_grad(set_to_none=True) |
|
|
121 |
|
|
|
122 |
if args.if_warmup and iter_num < args.warmup_period: |
|
|
123 |
lr_ = args.lr * ((iter_num + 1) / args.warmup_period) |
|
|
124 |
for param_group in optimizer.param_groups: |
|
|
125 |
param_group['lr'] = lr_ |
|
|
126 |
|
|
|
127 |
else: |
|
|
128 |
if args.if_warmup: |
|
|
129 |
shift_iter = iter_num - args.warmup_period |
|
|
130 |
assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero' |
|
|
131 |
lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.9 # learning rate adjustment depends on the max iterations |
|
|
132 |
for param_group in optimizer.param_groups: |
|
|
133 |
param_group['lr'] = lr_ |
|
|
134 |
|
|
|
135 |
train_loss += loss.item() |
|
|
136 |
|
|
|
137 |
iter_num+=1 |
|
|
138 |
writer.add_scalar('info/lr', lr_, iter_num) |
|
|
139 |
writer.add_scalar('info/total_loss', loss, iter_num) |
|
|
140 |
writer.add_scalar('info/loss_ce', loss_ce, iter_num) |
|
|
141 |
writer.add_scalar('info/loss_dice', loss_dice, iter_num) |
|
|
142 |
|
|
|
143 |
train_loss /= (i+1) |
|
|
144 |
pbar.set_description('Epoch num {}| train loss {} \n'.format(epoch,train_loss)) |
|
|
145 |
|
|
|
146 |
if epoch%2==0: |
|
|
147 |
eval_loss=0 |
|
|
148 |
dsc = 0 |
|
|
149 |
sam.eval() |
|
|
150 |
with torch.no_grad(): |
|
|
151 |
for i,data in enumerate(valloader): |
|
|
152 |
imgs = data['image'].cuda() |
|
|
153 |
img_emb= sam.image_encoder(imgs) |
|
|
154 |
alpha = random.random() |
|
|
155 |
msks = torchvision.transforms.Resize((args.out_size,args.out_size))(data['mask']) |
|
|
156 |
msks = msks.cuda() |
|
|
157 |
sparse_emb, dense_emb = sam.prompt_encoder( |
|
|
158 |
points=None, |
|
|
159 |
boxes=None, |
|
|
160 |
masks=None, |
|
|
161 |
) |
|
|
162 |
pred, _ = sam.mask_decoder( |
|
|
163 |
image_embeddings=img_emb, |
|
|
164 |
image_pe=sam.prompt_encoder.get_dense_pe(), |
|
|
165 |
sparse_prompt_embeddings=sparse_emb, |
|
|
166 |
dense_prompt_embeddings=dense_emb, |
|
|
167 |
multimask_output=True, |
|
|
168 |
) |
|
|
169 |
loss = criterion1(pred,msks.float()) + criterion2(pred,torch.squeeze(msks.long(),1)) |
|
|
170 |
eval_loss +=loss.item() |
|
|
171 |
dsc_batch = dice_coeff_multi_class(pred.argmax(dim=1).cpu(), torch.squeeze(msks.long(),1).cpu().long(), 5) |
|
|
172 |
dsc+=dsc_batch |
|
|
173 |
|
|
|
174 |
|
|
|
175 |
eval_loss /= (i+1) |
|
|
176 |
dsc /= (i+1) |
|
|
177 |
|
|
|
178 |
writer.add_scalar('eval/loss', eval_loss, epoch) |
|
|
179 |
writer.add_scalar('eval/dice', dsc, epoch) |
|
|
180 |
|
|
|
181 |
print('Eval Epoch num {} | val loss {} | dsc {} \n'.format(epoch,eval_loss,dsc)) |
|
|
182 |
if dsc>val_largest_dsc: |
|
|
183 |
val_largest_dsc = dsc |
|
|
184 |
last_update_epoch = epoch |
|
|
185 |
print('largest DSC now: {}'.format(dsc)) |
|
|
186 |
Path(dir_checkpoint).mkdir(parents=True,exist_ok = True) |
|
|
187 |
torch.save(sam.state_dict(),dir_checkpoint + '/checkpoint_best.pth') |
|
|
188 |
elif (epoch-last_update_epoch)>20: |
|
|
189 |
# the network haven't been updated for 20 epochs |
|
|
190 |
print('Training finished###########') |
|
|
191 |
break |
|
|
192 |
writer.close() |
|
|
193 |
|
|
|
194 |
|
|
|
195 |
if __name__ == "__main__": |
|
|
196 |
bodypart = 'hip' |
|
|
197 |
dataset_name = 'Bone_0820_cls' |
|
|
198 |
img_folder = args.initial_path +'2D-slices/images' |
|
|
199 |
mask_folder = args.initial_path + '2D-slices/masks' |
|
|
200 |
train_img_list = args.initial_path + 'datalist_body_parts/img_list_12_12_train_' + bodypart + '_annotate_paired_2dslices.txt' |
|
|
201 |
val_img_list = args.initial_path + 'datalist_body_parts/img_list_12_12_val_' + bodypart + '_annotate_paired_2dslices.txt' |
|
|
202 |
dir_checkpoint = '2D-MobileSAM-onlyfusion-adapter_'+dataset_name+'_attentionpredicted' |
|
|
203 |
num_workers = 1 |
|
|
204 |
if_vis = True |
|
|
205 |
epochs = 200 |
|
|
206 |
|
|
|
207 |
label_mapping = args.initial_path + 'segment_names_to_labels.pickle' |
|
|
208 |
train_dataset = MRI_dataset_multicls(args,img_folder, mask_folder, train_img_list,phase='train',targets=[args.targets],delete_empty_masks='subsample',label_mapping=label_mapping,if_prompt=False) |
|
|
209 |
eval_dataset = MRI_dataset_multicls(args,img_folder, mask_folder, val_img_list,phase='val',targets=[args.targets],delete_empty_masks='subsample',label_mapping=label_mapping,if_prompt=False) |
|
|
210 |
trainloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=num_workers) |
|
|
211 |
valloader = DataLoader(eval_dataset, batch_size=16, shuffle=False, num_workers=num_workers) |
|
|
212 |
train_model(trainloader,valloader,dir_checkpoint,epochs) |