# stdlib
import argparse
import os
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from datetime import datetime

# third-party
import matplotlib.pyplot as plt
import numpy as np
import pytorch_ssim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric  # added: DiceMetric is used below but was never imported
from monai.transforms import (
    AsDiscrete,
)
from PIL import Image
from skimage import io
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

# local
import cfg
import models.sam.utils.transforms as samtrans
from conf import settings
from utils import *

# from dataset import *
# from models.discriminatorlayer import discriminator
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
# Weight the positive class 2x in BCE — foreground pixels are the minority
# in segmentation masks.
pos_weight = torch.ones([1]).cuda(device=GPUdevice) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Random integer seeds in [1, 10], one row of 7 per batch element.
# NOTE(review): purpose of the 7 columns is not visible here — presumably
# click-prompt sampling; confirm against the data pipeline.
seed = torch.randint(1, 11, (args.b, 7))

torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
# 14-class one-hot post-processing (BTCV-style organ labels — TODO confirm).
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
def train_sam(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=50):
    """Run one training epoch of a SAM-style network with click prompts.

    Only the image encoder's Adapter parameters are trained; the rest of the
    image encoder is frozen and the prompt encoder runs under ``no_grad``.

    Args:
        args: parsed config namespace (reads ``gpu_device``, ``thd``,
            ``image_size``, ``out_size``, ``path_helper``).
        net: SAM-like model exposing ``image_encoder``, ``prompt_encoder``
            and ``mask_decoder``.
        optimizer: optimizer over the trainable parameters.
        train_loader: yields dicts with ``'image'``, ``'label'``, optionally
            ``'pt'``/``'p_label'``, and ``'image_meta_dict'``.
        epoch: current epoch index (used in the progress bar and image names).
        writer: tensorboard writer (accepted for interface parity; unused here).
        schedulers: optional schedulers (accepted for interface parity; unused).
        vis: save a visualization every ``vis``-th batch; falsy disables.

    Returns:
        The loss tensor of the LAST batch (not the epoch mean; the running
        ``epoch_loss`` only feeds the progress bar).
    """
    hard = 0
    epoch_loss = 0
    ind = 0
    net.train()
    optimizer.zero_grad()

    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    # 3D ("thd") volumes use Dice+CE on logits; 2D uses the weighted BCE.
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    # Freeze everything in the image encoder except the Adapter layers.
    # Hoisted out of the batch loop: the requires_grad flags are invariant,
    # so setting them once per epoch is equivalent and cheaper.
    for n, value in net.image_encoder.named_parameters():
        if "Adapter" not in n:
            value.requires_grad = False

    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            if 'pt' not in pack:
                imgs, pt, masks = generate_click_prompt(imgs, masks)
                # generate_click_prompt returns no labels; default to
                # all-positive clicks. Fixes a latent NameError in the
                # original when args.thd was false and the loader supplied
                # no 'p_label'.
                point_labels = torch.ones(imgs.size(0))
            else:
                pt = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            if args.thd:
                # Fold depth into the batch axis: (b, c, h, w, d) -> (b*d, c, h, w).
                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w ')

                # Replicate the single channel to the 3 channels SAM expects.
                imgs = imgs.repeat(1, 3, 1, 1)
                point_labels = torch.ones(imgs.size(0))

                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            showp = pt  # keep the raw points for visualization

            mask_type = torch.float32
            ind += 1

            if point_labels[0] != -1:
                # -1 marks "no click"; otherwise pack coords/labels into the
                # (coords, labels) pair SAM's prompt encoder expects, with a
                # leading batch dimension.
                point_coords = pt
                coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
                pt = (coords_torch, labels_torch)

            '''init'''
            if hard:
                # Dead path in practice (hard == 0 above); kept for parity.
                true_mask_ave = (true_mask_ave > 0.5).float()
            imgs = imgs.to(dtype=mask_type, device=GPUdevice)

            '''Train'''
            imge = net.image_encoder(imgs)

            # The prompt encoder is fully frozen — no gradients needed.
            with torch.no_grad():
                se, de = net.prompt_encoder(
                    points=pt,
                    boxes=None,
                    masks=None,
                )
            pred, _ = net.mask_decoder(
                image_embeddings=imge,
                image_pe=net.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=se,
                dense_prompt_embeddings=de,
                multimask_output=False,
            )

            loss = lossfunc(pred, masks)

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            '''vis images'''
            if vis:
                if ind % vis == 0:
                    namecat = 'Train'
                    for na in name:
                        namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                    vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False, points=showp)

            pbar.update()

    return loss
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    """Evaluate a SAM-style network on a validation loader.

    Volumes may be evaluated in depth chunks of ``args.evl_chunk`` slices;
    otherwise the whole depth extent is processed in one pass.

    Args:
        args: parsed config namespace (reads ``gpu_device``, ``thd``,
            ``evl_chunk``, ``image_size``, ``out_size``, ``vis``,
            ``path_helper``).
        val_loader: yields dicts with ``'image'``, ``'label'``, optionally
            ``'pt'``/``'p_label'``, and ``'image_meta_dict'``.
        epoch: current epoch index (used in saved image names).
        net: SAM-like model exposing ``image_encoder``, ``prompt_encoder``
            and ``mask_decoder``.
        clean_dir: accepted for interface parity; unused here.

    Returns:
        Tuple ``(mean_loss, metrics)`` where ``mean_loss`` is the summed loss
        divided by the (chunk-adjusted) batch count and ``metrics`` is a
        4-tuple of averaged segmentation scores from ``eval_seg``.
    """
    net.eval()

    mask_type = torch.float32
    n_val = len(val_loader)  # number of batches
    mix_res = (0, 0, 0, 0)
    tot = 0
    hard = 0
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    # Same loss selection as training: Dice+CE for 3D, weighted BCE for 2D.
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            imgsw = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masksw = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            if 'pt' not in pack:
                imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
                # Default to all-positive click labels; fixes a latent
                # NameError in the original when args.thd was false and the
                # loader supplied no 'p_label'.
                point_labels = torch.ones(imgsw.size(0))
            else:
                ptw = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            # Walk the depth axis in chunks of evl_ch slices.
            buoy = 0
            if args.evl_chunk:
                evl_ch = int(args.evl_chunk)
            else:
                evl_ch = int(imgsw.size(-1))

            while (buoy + evl_ch) <= imgsw.size(-1):
                if args.thd:
                    pt = ptw[:, :, buoy: buoy + evl_ch]
                else:
                    pt = ptw

                imgs = imgsw[..., buoy:buoy + evl_ch]
                masks = masksw[..., buoy:buoy + evl_ch]
                buoy += evl_ch

                if args.thd:
                    # Fold depth into batch: (b, c, h, w, d) -> (b*d, c, h, w).
                    pt = rearrange(pt, 'b n d -> (b d) n')
                    imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                    masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                    imgs = imgs.repeat(1, 3, 1, 1)
                    point_labels = torch.ones(imgs.size(0))

                    imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                    masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

                showp = pt  # keep raw points for visualization

                mask_type = torch.float32
                # NOTE(review): this reuses the enumerate index and bumps it
                # per chunk, which drives the ind % args.vis cadence below —
                # preserved as-is.
                ind += 1

                if point_labels[0] != -1:
                    # -1 marks "no click"; otherwise build the (coords, labels)
                    # pair SAM's prompt encoder expects, batch dim first.
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
                    pt = (coords_torch, labels_torch)

                '''init'''
                if hard:
                    # Dead path in practice (hard == 0 above); kept for parity.
                    true_mask_ave = (true_mask_ave > 0.5).float()
                imgs = imgs.to(dtype=mask_type, device=GPUdevice)

                '''test'''
                with torch.no_grad():
                    imge = net.image_encoder(imgs)

                    se, de = net.prompt_encoder(
                        points=pt,
                        boxes=None,
                        masks=None,
                    )

                    pred, _ = net.mask_decoder(
                        image_embeddings=imge,
                        image_pe=net.prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=se,
                        dense_prompt_embeddings=de,
                        multimask_output=False,
                    )

                    tot += lossfunc(pred, masks)

                    '''vis images'''
                    if ind % args.vis == 0:
                        namecat = 'Test'
                        for na in name:
                            img_name = na.split('/')[-1].split('.')[0]
                            namecat = namecat + img_name + '+'
                        vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False, points=showp)

                    # Accumulate per-chunk segmentation scores element-wise.
                    temp = eval_seg(pred, masks, threshold)
                    mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

            pbar.update()

    if args.evl_chunk:
        # Each batch contributed one term per chunk; scale the divisor.
        n_val = n_val * (imgsw.size(-1) // evl_ch)

    return tot / n_val, tuple([a / n_val for a in mix_res])