--- /dev/null
+++ b/finetune_segmentanybone_wo_atten.py
@@ -0,0 +1,212 @@
# from segment_anything import SamPredictor, sam_model_registry
from models.sam import SamPredictor, sam_model_registry
from models.sam.utils.transforms import ResizeLongestSide
from models.sam.modeling.prompt_encoder import auto_cls_emb
from models.sam.modeling.prompt_encoder import attention_fusion
from skimage.measure import label
# Scientific computing
import numpy as np
import os
# PyTorch packages
import torch
from torch import nn
import torch.optim as optim
from einops import rearrange
import torchvision
from torchvision import datasets
from tensorboardX import SummaryWriter
# Visualization
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
# Others
from torch.utils.data import DataLoader, Subset
from torch.autograd import Variable
import copy
from dataset_bone import MRI_dataset_multicls
import torch.nn.functional as F
from torch.nn.functional import one_hot
from pathlib import Path
from tqdm import tqdm
from losses import DiceLoss
from dsc import dice_coeff, dice_coeff_multi_class
import cv2
import monai
from utils import vis_image
import random

import cfg

args = cfg.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
args.if_mask_decoder_adapter = True
args.if_encoder_adapter = True
args.lr = 5e-4
args.decoder_adapt_depth = 2
args.if_warmup = True
args.initial_path = '/mnt/largeDrives/sevenTBTwo/bone_proj/codes_for_data/'
args.pretrain_weight = os.path.join('/mnt/largeDrives/sevenTBTwo/bone_proj/codes_for_data/588/fine-tune-sam/Medical-SAM-Adapter',
                                    '2D-MobileSAM-onlyfusion-adapter_Bone_0107_paired_attentionpredicted',
                                    'checkpoint_best.pth')
args.num_classes = 2
args.targets = 'multi_all'
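
# The training loop in train_model below adjusts the learning rate inline:
# linear warmup for args.warmup_period iterations, then polynomial decay
# with power 0.9 over the total iteration budget. The helper below merely
# restates that schedule as a standalone function for clarity; it is an
# illustrative sketch and is not called anywhere in this script.
def warmup_poly_lr(base_lr, iter_num, warmup_period, max_iterations, power=0.9):
    """Linear warmup followed by polynomial decay, as applied inline below."""
    if iter_num < warmup_period:
        return base_lr * (iter_num + 1) / warmup_period
    shift_iter = iter_num - warmup_period
    return base_lr * (1.0 - shift_iter / max_iterations) ** power

# e.g. with base_lr=5e-4 and warmup_period=250:
#   warmup_poly_lr(5e-4, 0, 250, 10_000)   -> 2e-6  (start of warmup)
#   warmup_poly_lr(5e-4, 249, 250, 10_000) -> 5e-4  (end of warmup)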

def train_model(trainloader, valloader, dir_checkpoint, epochs):
    # Set up model
    if args.if_warmup:
        b_lr = args.lr / args.warmup_period
    else:
        b_lr = args.lr

    iter_num = 0
    max_iterations = epochs * len(trainloader)
    writer = SummaryWriter(dir_checkpoint + '/log')

    sam = sam_model_registry["vit_t"](args, checkpoint=args.pretrain_weight, num_classes=args.num_classes)
    # Load the pretrained weights on top of the registry build; strict=False
    # skips keys that do not match (e.g. the newly added adapter layers).
    sam.load_state_dict(torch.load(args.pretrain_weight), strict=False)
    print(sam)

    # Freeze everything, then unfreeze only the decoder adapters and the
    # output hypernetwork MLPs.
    for n, value in sam.named_parameters():
        value.requires_grad = False
    for n, value in sam.mask_decoder.named_parameters():
        if "Adapter" in n:  # only update parameters in the decoder adapters
            value.requires_grad = True
        if 'output_hypernetworks_mlps' in n:
            value.requires_grad = True

    print('if image encoder adapter:', args.if_encoder_adapter)
    print('if mask decoder adapter:', args.if_mask_decoder_adapter)
    sam.to('cuda')

    optimizer = optim.AdamW(sam.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    # Note: this scheduler is never stepped; the learning rate is set
    # manually below (warmup followed by polynomial decay).
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    criterion1 = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, to_onehot_y=True, reduction='mean')
    criterion2 = nn.CrossEntropyLoss()

    pbar = tqdm(range(epochs))
    val_largest_dsc = 0
    last_update_epoch = 0
    for epoch in pbar:
        sam.train()
        train_loss = 0
        for i, data in enumerate(trainloader):
            imgs = data['image'].cuda()
            img_emb = sam.image_encoder(imgs)
            # Automatic masks containing all muscles; nearest-neighbour
            # interpolation keeps the integer label values intact.
            msks = transforms.Resize((args.out_size, args.out_size),
                                     interpolation=transforms.InterpolationMode.NEAREST)(data['mask'])
            # print('mask unique value:', msks.unique())
            msks = msks.cuda()
            # No points, boxes, or mask prompts: the prompt encoder runs in
            # its fully automatic (prompt-free) mode.
            sparse_emb, dense_emb = sam.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
            )
            pred, _ = sam.mask_decoder(
                image_embeddings=img_emb,
                image_pe=sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_emb,
                dense_prompt_embeddings=dense_emb,
                multimask_output=True,
            )
            loss_dice = criterion1(pred, msks.float())
            loss_ce = criterion2(pred, torch.squeeze(msks.long(), 1))
            loss = loss_dice + loss_ce

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            if args.if_warmup and iter_num < args.warmup_period:
                lr_ = args.lr * ((iter_num + 1) / args.warmup_period)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            elif args.if_warmup:
                shift_iter = iter_num - args.warmup_period
                assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
                lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.9  # polynomial decay over the remaining iterations
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            else:
                lr_ = optimizer.param_groups[0]['lr']  # keep lr_ defined for logging when warmup is off

            train_loss += loss.item()
            iter_num += 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss.item(), iter_num)
            writer.add_scalar('info/loss_ce', loss_ce.item(), iter_num)
            writer.add_scalar('info/loss_dice', loss_dice.item(), iter_num)

        train_loss /= (i + 1)
        pbar.set_description('Epoch num {} | train loss {}'.format(epoch, train_loss))

        if epoch % 2 == 0:
            eval_loss = 0
            dsc = 0
            sam.eval()
            with torch.no_grad():
                for i, data in enumerate(valloader):
                    imgs = data['image'].cuda()
                    img_emb = sam.image_encoder(imgs)
                    msks = transforms.Resize((args.out_size, args.out_size),
                                             interpolation=transforms.InterpolationMode.NEAREST)(data['mask'])
                    msks = msks.cuda()
                    sparse_emb, dense_emb = sam.prompt_encoder(
                        points=None,
                        boxes=None,
                        masks=None,
                    )
                    pred, _ = sam.mask_decoder(
                        image_embeddings=img_emb,
                        image_pe=sam.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_emb,
                        dense_prompt_embeddings=dense_emb,
                        multimask_output=True,
                    )
                    loss = criterion1(pred, msks.float()) + criterion2(pred, torch.squeeze(msks.long(), 1))
                    eval_loss += loss.item()
                    dsc_batch = dice_coeff_multi_class(pred.argmax(dim=1).cpu(), torch.squeeze(msks.long(), 1).cpu().long(), 5)
                    dsc += dsc_batch

            eval_loss /= (i + 1)
            dsc /= (i + 1)

            writer.add_scalar('eval/loss', eval_loss, epoch)
            writer.add_scalar('eval/dice', dsc, epoch)

            print('Eval Epoch num {} | val loss {} | dsc {}'.format(epoch, eval_loss, dsc))
            if dsc > val_largest_dsc:
                val_largest_dsc = dsc
                last_update_epoch = epoch
                print('largest DSC now: {}'.format(dsc))
                Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
                torch.save(sam.state_dict(), dir_checkpoint + '/checkpoint_best.pth')
            elif (epoch - last_update_epoch) > 20:
                # the validation DSC has not improved for 20 epochs; stop early
                print('Training finished')
                break
    writer.close()
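
# dice_coeff_multi_class is imported from the project's dsc module and its
# implementation is not shown in this file. As a rough sketch of what a
# multi-class Dice with the same call signature (predicted label map,
# ground-truth label map, class count) could compute -- illustrative only,
# not the project's actual implementation:
def dice_coeff_multi_class_sketch(pred_labels, true_labels, n_classes, eps=1e-6):
    """Mean Dice over classes 1..n_classes-1, treating class 0 as background."""
    dices = []
    for c in range(1, n_classes):
        p = (pred_labels == c).float()
        t = (true_labels == c).float()
        intersection = (p * t).sum()
        # eps guards against empty masks (0/0) for absent classes
        dices.append(((2 * intersection + eps) / (p.sum() + t.sum() + eps)).item())
    return sum(dices) / len(dices)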

if __name__ == "__main__":
    bodypart = 'hip'
    dataset_name = 'Bone_0820_cls'
    img_folder = args.initial_path + '2D-slices/images'
    mask_folder = args.initial_path + '2D-slices/masks'
    train_img_list = args.initial_path + 'datalist_body_parts/img_list_12_12_train_' + bodypart + '_annotate_paired_2dslices.txt'
    val_img_list = args.initial_path + 'datalist_body_parts/img_list_12_12_val_' + bodypart + '_annotate_paired_2dslices.txt'
    dir_checkpoint = '2D-MobileSAM-onlyfusion-adapter_' + dataset_name + '_attentionpredicted'
    num_workers = 1
    if_vis = True
    epochs = 200

    label_mapping = args.initial_path + 'segment_names_to_labels.pickle'
    train_dataset = MRI_dataset_multicls(args, img_folder, mask_folder, train_img_list, phase='train',
                                         targets=[args.targets], delete_empty_masks='subsample',
                                         label_mapping=label_mapping, if_prompt=False)
    eval_dataset = MRI_dataset_multicls(args, img_folder, mask_folder, val_img_list, phase='val',
                                        targets=[args.targets], delete_empty_masks='subsample',
                                        label_mapping=label_mapping, if_prompt=False)
    trainloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=num_workers)
    valloader = DataLoader(eval_dataset, batch_size=16, shuffle=False, num_workers=num_workers)
    train_model(trainloader, valloader, dir_checkpoint, epochs)
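
# Reusing the best checkpoint saved by train_model mirrors the model
# construction above. A minimal sketch with a hypothetical helper name,
# assuming the same cfg/args context as this script:
def load_best_checkpoint(dir_checkpoint):
    sam = sam_model_registry["vit_t"](args, checkpoint=args.pretrain_weight,
                                      num_classes=args.num_classes)
    # strict=False mirrors the loading style used above and tolerates any
    # keys that differ from the saved state dict
    sam.load_state_dict(torch.load(dir_checkpoint + '/checkpoint_best.pth',
                                   map_location='cpu'), strict=False)
    sam.eval()
    return sam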