# from segment_anything import SamPredictor, sam_model_registry
from models.sam import SamPredictor, sam_model_registry
from models.sam.utils.transforms import ResizeLongestSide
from models.sam.modeling.prompt_encoder import auto_cls_emb
from models.sam.modeling.prompt_encoder import attention_fusion
from skimage.measure import label
# Scientific computing
import numpy as np
import os
# Pytorch packages
import torch
from torch import nn
import torch.optim as optim
from einops import rearrange
import torchvision
from torchvision import datasets
from tensorboardX import SummaryWriter
# Visualization
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
# Others
from torch.utils.data import DataLoader, Subset
from torch.autograd import Variable
import copy
from dataset_bone import MRI_dataset_multicls
import torch.nn.functional as F
from torch.nn.functional import one_hot
from pathlib import Path
from tqdm import tqdm
from losses import DiceLoss
from dsc import dice_coeff, dice_coeff_multi_class
import cv2
import monai
from utils import vis_image
import random

import cfg

# Parse the project CLI arguments, then override the settings used for this
# specific fine-tuning run (decoder-adapter-only fine-tuning, 2 classes).
args = cfg.parse_args()
# NOTE(review): setting CUDA_VISIBLE_DEVICES after `import torch` only takes
# effect if no CUDA context has been created yet — confirm this holds here.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
args.if_mask_decoder_adapter = True
args.if_encoder_adapter = True
args.lr = 5e-4
args.decoder_adapt_depth = 2
args.if_warmup = True
args.initial_path = '/mnt/largeDrives/sevenTBTwo/bone_proj/codes_for_data/'
args.pretrain_weight = os.path.join('/mnt/largeDrives/sevenTBTwo/bone_proj/codes_for_data/588/fine-tune-sam/Medical-SAM-Adapter','2D-MobileSAM-onlyfusion-adapter_Bone_0107_paired_attentionpredicted','checkpoint_best.pth')
args.num_classes = 2
args.targets = 'multi_all'
def train_model(trainloader, valloader, dir_checkpoint, epochs):
    """Fine-tune the mask-decoder adapters of a MobileSAM model.

    Freezes the whole model except the decoder "Adapter" layers and the
    output hypernetwork MLPs, trains with a Dice + cross-entropy loss,
    evaluates every 2nd epoch, and saves the weights with the best
    validation multi-class Dice score. Training stops early when the
    validation Dice has not improved for more than 20 epochs.

    Args:
        trainloader: DataLoader yielding dicts with keys 'image' and 'mask'.
        valloader: validation DataLoader with the same item structure.
        dir_checkpoint: directory for TensorBoard logs and the best checkpoint.
        epochs: maximum number of training epochs.
    """
    iter_num = 0
    max_iterations = epochs * len(trainloader)
    writer = SummaryWriter(dir_checkpoint + '/log')

    # Build the model; the registry already receives the checkpoint path, and
    # the weights are additionally loaded non-strictly (adapter keys may be new).
    sam = sam_model_registry["vit_t"](args, checkpoint=args.pretrain_weight,
                                      num_classes=args.num_classes)
    sam.load_state_dict(torch.load(os.path.join(args.pretrain_weight)), strict=False)
    print(sam)

    # Freeze everything first ...
    for n, value in sam.named_parameters():
        value.requires_grad = False
    # ... then unfreeze only the decoder adapters and output hypernetwork MLPs.
    for n, value in sam.mask_decoder.named_parameters():
        if "Adapter" in n:  # only update parameters in decoder adapter
            value.requires_grad = True
        if 'output_hypernetworks_mlps' in n:
            value.requires_grad = True

    print('if image encoder adapter:', args.if_encoder_adapter)
    print('if mask decoder adapter:', args.if_mask_decoder_adapter)
    sam.to('cuda')

    optimizer = optim.AdamW(sam.parameters(), lr=args.lr, betas=(0.9, 0.999),
                            eps=1e-08, weight_decay=0, amsgrad=False)
    # NOTE(review): scheduler is created but .step() is never called; the lr is
    # instead driven manually by the warmup/decay logic below — confirm intended.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # learning rate decay
    criterion1 = monai.losses.DiceLoss(sigmoid=True, squared_pred=True,
                                       to_onehot_y=True, reduction='mean')
    criterion2 = nn.CrossEntropyLoss()

    pbar = tqdm(range(epochs))
    val_largest_dsc = 0
    last_update_epoch = 0
    # BUGFIX: initialize lr_ so the TensorBoard logging below does not raise a
    # NameError when args.if_warmup is False (no branch would assign it).
    lr_ = args.lr
    for epoch in pbar:
        sam.train()
        train_loss = 0
        for i, data in enumerate(trainloader):
            imgs = data['image'].cuda()
            img_emb = sam.image_encoder(imgs)
            # Ground-truth masks resized to the decoder output resolution.
            msks = torchvision.transforms.Resize((args.out_size, args.out_size))(data['mask'])
            msks = msks.cuda()
            # No point/box/mask prompts: rely on the learned default embeddings.
            sparse_emb, dense_emb = sam.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
            )
            pred, _ = sam.mask_decoder(
                image_embeddings=img_emb,
                image_pe=sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_emb,
                dense_prompt_embeddings=dense_emb,
                multimask_output=True,
            )
            loss_dice = criterion1(pred, msks.float())
            loss_ce = criterion2(pred, torch.squeeze(msks.long(), 1))
            loss = loss_dice + loss_ce

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # Linear warmup followed by polynomial (power 0.9) decay.
            if args.if_warmup and iter_num < args.warmup_period:
                lr_ = args.lr * ((iter_num + 1) / args.warmup_period)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            elif args.if_warmup:
                shift_iter = iter_num - args.warmup_period
                assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
                lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.9  # learning rate adjustment depends on the max iterations
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_

            train_loss += loss.item()

            iter_num += 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)

        train_loss /= (i + 1)
        pbar.set_description('Epoch num {}| train loss {} \n'.format(epoch, train_loss))

        # Validate every other epoch.
        if epoch % 2 == 0:
            eval_loss = 0
            dsc = 0
            sam.eval()
            with torch.no_grad():
                for i, data in enumerate(valloader):
                    imgs = data['image'].cuda()
                    img_emb = sam.image_encoder(imgs)
                    msks = torchvision.transforms.Resize((args.out_size, args.out_size))(data['mask'])
                    msks = msks.cuda()
                    sparse_emb, dense_emb = sam.prompt_encoder(
                        points=None,
                        boxes=None,
                        masks=None,
                    )
                    pred, _ = sam.mask_decoder(
                        image_embeddings=img_emb,
                        image_pe=sam.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_emb,
                        dense_prompt_embeddings=dense_emb,
                        multimask_output=True,
                    )
                    loss = criterion1(pred, msks.float()) + criterion2(pred, torch.squeeze(msks.long(), 1))
                    eval_loss += loss.item()
                    # NOTE(review): class count is hard-coded to 5 here while
                    # args.num_classes is 2 — confirm which is correct.
                    dsc_batch = dice_coeff_multi_class(pred.argmax(dim=1).cpu(),
                                                       torch.squeeze(msks.long(), 1).cpu().long(), 5)
                    dsc += dsc_batch

                eval_loss /= (i + 1)
                dsc /= (i + 1)

                writer.add_scalar('eval/loss', eval_loss, epoch)
                writer.add_scalar('eval/dice', dsc, epoch)

                print('Eval Epoch num {} | val loss {} | dsc {} \n'.format(epoch, eval_loss, dsc))
                if dsc > val_largest_dsc:
                    val_largest_dsc = dsc
                    last_update_epoch = epoch
                    print('largest DSC now: {}'.format(dsc))
                    Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
                    torch.save(sam.state_dict(), dir_checkpoint + '/checkpoint_best.pth')
                elif (epoch - last_update_epoch) > 20:
                    # the network hasn't been updated for 20 epochs: early stop
                    print('Training finished###########')
                    break
    writer.close()
if __name__ == "__main__":
    # Run configuration: hip body part, paired 2D slices dataset.
    bodypart = 'hip'
    dataset_name = 'Bone_0820_cls'
    img_folder = args.initial_path + '2D-slices/images'
    mask_folder = args.initial_path + '2D-slices/masks'
    train_img_list = args.initial_path + 'datalist_body_parts/img_list_12_12_train_' + bodypart + '_annotate_paired_2dslices.txt'
    val_img_list = args.initial_path + 'datalist_body_parts/img_list_12_12_val_' + bodypart + '_annotate_paired_2dslices.txt'
    dir_checkpoint = '2D-MobileSAM-onlyfusion-adapter_' + dataset_name + '_attentionpredicted'
    num_workers = 1
    if_vis = True
    epochs = 200

    # Pickled mapping from segment names to integer labels used by the dataset.
    label_mapping = args.initial_path + 'segment_names_to_labels.pickle'
    train_dataset = MRI_dataset_multicls(args, img_folder, mask_folder, train_img_list, phase='train', targets=[args.targets], delete_empty_masks='subsample', label_mapping=label_mapping, if_prompt=False)
    eval_dataset = MRI_dataset_multicls(args, img_folder, mask_folder, val_img_list, phase='val', targets=[args.targets], delete_empty_masks='subsample', label_mapping=label_mapping, if_prompt=False)
    trainloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=num_workers)
    valloader = DataLoader(eval_dataset, batch_size=16, shuffle=False, num_workers=num_workers)
    train_model(trainloader, valloader, dir_checkpoint, epochs)