Switch to side-by-side view

--- a
+++ b/finetune_segmentanybone_wo_atten.py
@@ -0,0 +1,212 @@
+#from segment_anything import SamPredictor, sam_model_registry
+from models.sam import SamPredictor, sam_model_registry
+from models.sam.utils.transforms import ResizeLongestSide
+from models.sam.modeling.prompt_encoder import auto_cls_emb
+from models.sam.modeling.prompt_encoder import attention_fusion
+from skimage.measure import label
+#Scientific computing 
+import numpy as np
+import os
+#Pytorch packages
+import torch
+from torch import nn
+import torch.optim as optim
+from einops import rearrange
+import torchvision
+from torchvision import datasets
+from tensorboardX import SummaryWriter
+#Visulization
+import matplotlib.pyplot as plt
+from torchvision import transforms
+from PIL import Image
+#Others
+from torch.utils.data import DataLoader, Subset
+from torch.autograd import Variable
+import matplotlib.pyplot as plt
+import copy
+from dataset_bone import MRI_dataset_multicls
+import torch.nn.functional as F
+from torch.nn.functional import one_hot
+from pathlib import Path
+from tqdm import tqdm
+from losses import DiceLoss
+from dsc import dice_coeff,dice_coeff_multi_class
+import cv2
+import monai
+from utils import vis_image
+import random
+
+import cfg
+args = cfg.parse_args()
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+args.if_mask_decoder_adapter=True
+args.if_encoder_adapter = True
+args.lr = 5e-4
+args.decoder_adapt_depth = 2
+args.if_warmup = True
+args.initial_path = '/mnt/largeDrives/sevenTBTwo/bone_proj/codes_for_data/'
+args.pretrain_weight = os.path.join('/mnt/largeDrives/sevenTBTwo/bone_proj/codes_for_data/588/fine-tune-sam/Medical-SAM-Adapter','2D-MobileSAM-onlyfusion-adapter_Bone_0107_paired_attentionpredicted','checkpoint_best.pth')
+args.num_classes = 2
+args.targets = 'multi_all'
+
+
+def train_model(trainloader,valloader,dir_checkpoint,epochs):
+    # Set up model
+    
+    if args.if_warmup:
+        b_lr = args.lr / args.warmup_period
+    else:
+        b_lr = args.lr
+    
+    
+    iter_num = 0
+    max_iterations = epochs * len(trainloader) 
+    writer = SummaryWriter(dir_checkpoint + '/log')
+    
+    sam = sam_model_registry["vit_t"](args,checkpoint=args.pretrain_weight,num_classes=args.num_classes) 
+    sam.load_state_dict(torch.load(os.path.join(args.pretrain_weight)), strict = False)
+    print(sam)
+    
+    for n, value in sam.named_parameters():
+        value.requires_grad = False
+    
+    for n, value in sam.mask_decoder.named_parameters():
+        if "Adapter" in n: # only update parameters in decoder adapter
+            value.requires_grad = True
+        if 'output_hypernetworks_mlps' in n:
+            value.requires_grad = True
+            
+    print('if image encoder adapter:',args.if_encoder_adapter)
+    print('if mask decoder adapter:',args.if_mask_decoder_adapter)
+    sam.to('cuda')
+    
+    optimizer = optim.AdamW(sam.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
+    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) #learning rate decay
+    criterion1 = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, to_onehot_y=True,reduction='mean')
+    criterion2 = nn.CrossEntropyLoss()
+    
+    pbar = tqdm(range(epochs))
+    val_largest_dsc = 0
+    last_update_epoch = 0
+    for epoch in pbar:
+        sam.train()
+        train_loss = 0
+        for i,data in enumerate(trainloader):
+            imgs = data['image'].cuda()
+            img_emb= sam.image_encoder(imgs)
+            alpha = random.random()
+            # automatic masks contaning all muscles
+            msks = torchvision.transforms.Resize((args.out_size,args.out_size))(data['mask'])
+            #print('mask unique value:',msks.unique())
+            msks = msks.cuda()
+            sparse_emb, dense_emb = sam.prompt_encoder(
+                points=None,
+                boxes=None,
+                masks=None,
+            )
+            pred, _ = sam.mask_decoder(
+                            image_embeddings=img_emb,
+                            image_pe=sam.prompt_encoder.get_dense_pe(), 
+                            sparse_prompt_embeddings=sparse_emb,
+                            dense_prompt_embeddings=dense_emb, 
+                            multimask_output=True,
+                          )
+            loss_dice = criterion1(pred,msks.float()) 
+            loss_ce = criterion2(pred,torch.squeeze(msks.long(),1))
+            loss =  loss_dice + loss_ce
+            
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+
+            if args.if_warmup and iter_num < args.warmup_period:
+                lr_ = args.lr * ((iter_num + 1) / args.warmup_period)
+                for param_group in optimizer.param_groups:
+                    param_group['lr'] = lr_
+
+            else:
+                if args.if_warmup:
+                    shift_iter = iter_num - args.warmup_period
+                    assert shift_iter >= 0, f'Shift iter is {shift_iter}, smaller than zero'
+                    lr_ = args.lr * (1.0 - shift_iter / max_iterations) ** 0.9  # learning rate adjustment depends on the max iterations
+                    for param_group in optimizer.param_groups:
+                        param_group['lr'] = lr_
+                        
+            train_loss += loss.item()
+            
+            iter_num+=1
+            writer.add_scalar('info/lr', lr_, iter_num)
+            writer.add_scalar('info/total_loss', loss, iter_num)
+            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
+            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
+
+        train_loss /= (i+1)
+        pbar.set_description('Epoch num {}| train loss {} \n'.format(epoch,train_loss))
+
+        if epoch%2==0:
+            eval_loss=0
+            dsc = 0
+            sam.eval()
+            with torch.no_grad():
+                for i,data in enumerate(valloader):
+                    imgs = data['image'].cuda()
+                    img_emb= sam.image_encoder(imgs)
+                    alpha = random.random()
+                    msks = torchvision.transforms.Resize((args.out_size,args.out_size))(data['mask'])
+                    msks = msks.cuda()
+                    sparse_emb, dense_emb = sam.prompt_encoder(
+                        points=None,
+                        boxes=None,
+                        masks=None,
+                    )
+                    pred, _ = sam.mask_decoder(
+                                    image_embeddings=img_emb,
+                                    image_pe=sam.prompt_encoder.get_dense_pe(), 
+                                    sparse_prompt_embeddings=sparse_emb,
+                                    dense_prompt_embeddings=dense_emb, 
+                                    multimask_output=True,
+                                  )
+                    loss = criterion1(pred,msks.float()) + criterion2(pred,torch.squeeze(msks.long(),1))
+                    eval_loss +=loss.item()
+                    dsc_batch = dice_coeff_multi_class(pred.argmax(dim=1).cpu(), torch.squeeze(msks.long(),1).cpu().long(), 5)
+                    dsc+=dsc_batch
+
+                    
+                eval_loss /= (i+1)
+                dsc /= (i+1)
+                
+                writer.add_scalar('eval/loss', eval_loss, epoch)
+                writer.add_scalar('eval/dice', dsc, epoch)
+                
+                print('Eval Epoch num {} | val loss {} | dsc {} \n'.format(epoch,eval_loss,dsc))
+                if dsc>val_largest_dsc:
+                    val_largest_dsc = dsc
+                    last_update_epoch = epoch
+                    print('largest DSC now: {}'.format(dsc))
+                    Path(dir_checkpoint).mkdir(parents=True,exist_ok = True)
+                    torch.save(sam.state_dict(),dir_checkpoint + '/checkpoint_best.pth')
+                elif (epoch-last_update_epoch)>20:
+                    # the network haven't been updated for 20 epochs
+                    print('Training finished###########')
+                    break
+    writer.close()                                 
+                
+                
+if __name__ == "__main__":
+    bodypart = 'hip'
+    dataset_name = 'Bone_0820_cls'
+    img_folder = args.initial_path +'2D-slices/images'
+    mask_folder = args.initial_path + '2D-slices/masks'
+    train_img_list = args.initial_path + 'datalist_body_parts/img_list_12_12_train_' + bodypart + '_annotate_paired_2dslices.txt'
+    val_img_list = args.initial_path + 'datalist_body_parts/img_list_12_12_val_' + bodypart + '_annotate_paired_2dslices.txt'
+    dir_checkpoint = '2D-MobileSAM-onlyfusion-adapter_'+dataset_name+'_attentionpredicted'
+    num_workers = 1
+    if_vis = True
+    epochs = 200
+    
+    label_mapping = args.initial_path  + 'segment_names_to_labels.pickle'
+    train_dataset = MRI_dataset_multicls(args,img_folder, mask_folder, train_img_list,phase='train',targets=[args.targets],delete_empty_masks='subsample',label_mapping=label_mapping,if_prompt=False)
+    eval_dataset = MRI_dataset_multicls(args,img_folder, mask_folder, val_img_list,phase='val',targets=[args.targets],delete_empty_masks='subsample',label_mapping=label_mapping,if_prompt=False)
+    trainloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=num_workers)
+    valloader = DataLoader(eval_dataset, batch_size=16, shuffle=False, num_workers=num_workers)
+    train_model(trainloader,valloader,dir_checkpoint,epochs)
\ No newline at end of file