a b/Evidential_segmentation/TRAINING-ENN_step2.py
1
#!/usr/bin/env python
2
# coding: utf-8
3
4
# In[]:
5
###########################  IMPORTS   ############################################# 
6
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
import monai
11
from monai.networks.nets import UNet,VNet,DynUNet,UNet_ENN_KMEANS
12
from monai.networks.utils import one_hot
13
from monai.transforms import (
14
    AsDiscrete,
15
    AddChanneld,
16
    AsChannelFirstd,
17
    Compose,
18
    LoadNiftid,
19
    RandCropByPosNegLabeld,
20
    RandRotate90d,
21
    ScaleIntensityd,
22
    ToTensord,
23
)
24
from monai.visualize import plot_2d_or_3d_image
25
from monai.data.utils import list_data_collate, worker_init_fn
26
from monai.inferers import sliding_window_inference
27
from monai.metrics import DiceMetric
28
from monai.metrics import compute_meandice
29
from torch.autograd import Variable
30
import torch.optim as optim
31
from torch.optim import lr_scheduler
32
from torch.utils.data import DataLoader
33
from torch.utils.data import Dataset
34
import torch.backends.cudnn as cudnn
35
import torchvision
36
from torchvision import datasets, models, transforms
37
import csv
38
import time
39
import SimpleITK as sitk
40
from os.path import splitext,basename
41
import random
42
from glob import glob
43
44
import matplotlib.pyplot as plt
45
from matplotlib.backends.backend_pdf import PdfPages
46
from copy import copy
47
import os
48
import numpy as np
49
from torch.utils.tensorboard import SummaryWriter
50
51
#from global_tools.tools import display_loading_bar
52
from class_modalities.transforms import LoadNifti, Roi2Mask, ResampleReshapeAlign, Sitk2Numpy, ConcatModality
53
from monai.utils import first, set_determinism
54
55
56
def _build_preprocessing():
    """Return a fresh preprocessing pipeline for one (PET, CT, mask) sample dict.

    Steps, in order:
      1. ``LoadNifti`` reads the files named under "pet_img", "ct_img" and "mask_img".
      2. ``Sitk2Numpy`` converts those three entries (presumably SimpleITK images)
         to arrays.
      3. ``ConcatModality`` combines PET and CT; the result is presumably stored
         under "image", since ``ToTensord`` and the training loop read that key.
      4. ``AddChanneld`` prepends a channel axis to the mask.
      5. ``ToTensord`` converts "image" and "mask_img" to tensors.

    No random augmentation is applied, so the same steps serve training and
    validation. Each call builds new transform instances.
    """
    return Compose(
        [  # read img + meta info
            LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
            Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
            ConcatModality(keys=['pet_img', 'ct_img']),
            AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
            ToTensord(keys=["image", "mask_img"]),
        ]
    )


# Training pipeline (currently identical to validation: no data augmentation).
train_transforms = _build_preprocessing()
# Validation/test pipeline: same steps, separate instance, no data augmentation.
val_transforms = _build_preprocessing()
73
74
75
# Run on the second CUDA device ("cuda:1") when a GPU is available, else on the CPU.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

base_path = "/home/lab/hualing/2.5_SUV_dilation"
pet_path = os.path.join(base_path, 'pet_test')
ct_path = os.path.join(base_path, 'ct_test')
mask_path = os.path.join(base_path, 'pet_test_mask')

# PET, CT and mask files are paired purely by their sorted filename order.
PET_ids = sorted(glob(os.path.join(pet_path, '*pet.nii')))
CT_ids = sorted(glob(os.path.join(ct_path, '*ct.nii')))
MASK_ids = sorted(glob(os.path.join(mask_path, '*mask.nii')))
if not PET_ids:
    raise RuntimeError(f"No '*pet.nii' volumes found in {pet_path}")
if not (len(PET_ids) == len(CT_ids) == len(MASK_ids)):
    # zip() would silently drop the surplus files, and a missing file can shift
    # the sorted pairing so that volumes of different patients are combined.
    raise RuntimeError(
        f"Mismatched number of volumes: {len(PET_ids)} PET, "
        f"{len(CT_ids)} CT, {len(MASK_ids)} masks"
    )
data_dicts = zip(PET_ids, CT_ids, MASK_ids)
files = list(data_dicts)

# Fixed split by position in the sorted list: train, then validation, then test.
n_train = 138
n_val = 18
train_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[:n_train]]
val_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[n_train:n_train + n_val]]
test_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[n_train + n_val:]]

train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
test_ds = monai.data.Dataset(data=test_files, transform=val_transforms)

train_loader = DataLoader(
    train_ds,
    batch_size=6,
    shuffle=True,
    num_workers=4,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)

# Evaluation loaders keep dataset order (no shuffle), one volume per batch.
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
106
107
108
# Checkpoint written by step 1 of the ENN training (path relative to the CWD).
trained_model_path = "./ENN_best_metric_model_segmentation3d_dict_step1.pth"
model = UNet_ENN_KMEANS(
    dimensions=3,  # 3D
    in_channels=2,  # presumably PET + CT stacked by ConcatModality — confirm
    out_channels=2,
    kernel_size=5,
    channels=(8, 16, 32, 64, 128),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)
# NOTE(review): the training code indexes output channel 2, so this network
# appears to emit out_channels + 1 channels (ENN masses) — confirm in the model.

# Warm-start from step 1: copy every checkpoint tensor whose name exists in this
# model. map_location lets a checkpoint saved on another device load here.
model_dict = model.state_dict()
pre_dict = torch.load(trained_model_path, map_location=device)
skipped_keys = [k for k in pre_dict if k not in model_dict]
if skipped_keys:
    print(f"Ignoring {len(skipped_keys)} checkpoint entries not present in the model: {skipped_keys}")
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
model_dict.update(pre_dict)
model.load_state_dict(model_dict)

# Step 2 fine-tunes the whole network: every parameter requiring grad is optimised.
params = [p for p in model.parameters() if p.requires_grad]
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)  # sanity check: list the parameters that will be optimised

optimizer = torch.optim.Adam(params, 1e-4)

dice_metric = monai.metrics.DiceMetric(include_background=False, reduction="mean")
# Lowers the learning rate once the monitored value ('min' mode) stops improving
# for `patience` epochs; the training loop steps it with the epoch training loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
# Foreground-only Dice loss. softmax=False because the loop feeds it already
# combined two-class scores rather than logits.
loss_function = monai.losses.DiceLoss(include_background=False, softmax=False, squared_pred=True, to_onehot_y=True)

val_interval = 1  # validate after every epoch
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()

writer = SummaryWriter()
# NOTE(review): post_pred / post_label are never applied in this script; the
# Dice metric is computed on the soft two-class outputs instead — confirm.
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
post_label = AsDiscrete(to_onehot=True, n_classes=2)
145
146
147
148
max_epochs = 100  # number of fine-tuning epochs
# Batches per epoch. len(train_loader) also counts a final partial batch, so the
# TensorBoard step (epoch_len * epoch + step) never overlaps between consecutive
# epochs; len(train_ds) // batch_size undercounts whenever the division is inexact.
epoch_len = len(train_loader)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["mask_img"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)

        # Channels 0-1 presumably hold the two class masses and channel 2 the mass
        # on the whole frame (ignorance); sharing it equally between the classes
        # gives pignistic two-class probabilities. TODO confirm against the ENN head.
        output = outputs[:, :2, :, :, :] + 0.5 * outputs[:, 2, :, :, :].unsqueeze(1)
        loss = loss_function(output, labels)

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # The plateau scheduler monitors the average training loss.
    scheduler.step(epoch_loss)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(device), val_data["mask_img"].to(device)
                output = model(val_images)
                # Same mass -> two-class combination as in training.
                val_outputs = output[:, :2, :, :, :] + 0.5 * output[:, 2, :, :, :].unsqueeze(1)

                value = dice_metric(y_pred=val_outputs, y=val_labels)
                # Weight each batch's mean Dice by the number of entries it covers.
                metric_count += len(value)
                metric_sum += value.item() * len(value)

            metric = metric_sum / metric_count
            metric_values.append(metric)
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "ENN_best_metric_model_segmentation3d_dict_step2.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()
211
212
213
214
# LOAD THE BEST STEP-2 MODEL FOR INFERENCE
215
216
217
# Reload the step-2 checkpoint saved by the training loop (best validation Dice).
# map_location keeps this working when the checkpoint was saved on another device.
model.load_state_dict(torch.load("ENN_best_metric_model_segmentation3d_dict_step2.pth", map_location=device))
model.eval()


###########################  RULES   #############################################

# When True, PREDICT_MASK is run on the held-out split at the end of the script.
PREDICTION_VALIDATION_SET = True
path_results = '/home/lab/hualing/(IJAR_new)hl_medical-segmentation-master/result_enn_kmeans'

# Create the results folder; no error if it already exists.
os.makedirs(path_results, exist_ok=True)
230
231
232
##################
233
234
def PREDICT_MASK(data_set_ids, path_predictions, model, path_masses=None):
    """Run ``model`` on every sample of a dataset, save its outputs and print metrics.

    For each sample (batch size 1, dataset order) this function:

    * saves the raw network output as ``<name>.npy`` in ``path_masses``;
    * saves the channel-wise argmax of that output as ``pred_<name>.nii`` in
      ``path_predictions``;
    * accumulates Dice on the two-class scores ``out[:, :2] + 0.5 * out[:, 2]``
      and sensitivity / specificity / precision on the argmax map, then prints
      their means.

    ``<name>`` is the ground-truth mask filename without its last extension.

    Args:
        data_set_ids: dataset whose items are dicts holding ``"image"`` and
            ``"mask_img"`` tensors (the script passes ``test_ds``). When it has a
            ``data`` list of file dicts (``monai.data.Dataset`` keeps its input
            list there), output names come from that list; otherwise from the
            global ``test_files``.
        path_predictions: folder for the ``.nii`` predictions; created if missing.
        model: network to evaluate; it is switched to eval mode here.
        path_masses: folder for the ``.npy`` outputs; created if missing.
            Defaults to the fixed ``result_enn_kmeans`` folder below.

    Returns:
        The list of written ``.nii`` paths, in dataset order.
    """
    os.makedirs(path_predictions, exist_ok=True)
    if path_masses is None:
        path_masses = '/home/lab/hualing/(IJAR_new)hl_medical-segmentation-master/result_enn_kmeans'
    # Files are written by explicit path instead of os.chdir, so the process
    # working directory is left untouched.
    os.makedirs(path_masses, exist_ok=True)

    # Name outputs after the dataset actually being predicted; the global
    # test_files only matches when data_set_ids is test_ds.
    sample_files = getattr(data_set_ids, "data", None)
    if sample_files is None:
        sample_files = test_files

    # batch_size=1 and no shuffling, so loader index i matches sample_files[i].
    loader = DataLoader(data_set_ids, batch_size=1, num_workers=4, collate_fn=list_data_collate)

    filenames_predicted_masks = []
    metric_count = 0
    metric_sum = 0.0
    metric_sum_sen = 0.0
    metric_sum_spe = 0.0
    metric_sum_pre = 0.0

    model.eval()
    # Inference only: without no_grad every forward pass would also record an
    # autograd graph over the full 3D volume.
    with torch.no_grad():
        for i, val_data in enumerate(loader):
            val_images = val_data["image"].to(device)
            val_labels = val_data["mask_img"].to(device)
            masses = model(val_images)
            name = splitext(basename(sample_files[i]["mask_img"]))[0]

            # Raw network output -> <path_masses>/<name>.npy
            np.save(os.path.join(path_masses, name), masses.cpu().numpy())

            # Channels 0-1 presumably hold the two class masses and channel 2 the
            # mass on the whole frame; sharing it equally yields two-class
            # (pignistic) scores, as in training. TODO confirm against the ENN head.
            pignistic = masses[:, :2, :, :, :] + 0.5 * masses[:, 2, :, :, :].unsqueeze(1)

            # NOTE(review): this argmax runs over ALL output channels, so a voxel
            # whose largest value is channel 2 is labelled 2 both in the saved mask
            # and in the confusion metrics. Confirm this is intended rather than an
            # argmax over `pignistic`.
            hard_labels = torch.argmax(masses, dim=1)

            # (B, d1, d2, d3) -> (B, d3, d1, d2), then drop size-1 dims before
            # writing. uint8 keeps the stored mask small.
            volume = hard_labels.permute(0, 3, 1, 2).squeeze().cpu().numpy()
            mask = np.asarray(volume, dtype=np.uint8)
            new_filename = path_predictions + "/pred_" + name + '.nii'
            filenames_predicted_masks.append(new_filename)
            # NOTE(review): GetImageFromArray uses default origin/spacing/direction;
            # the source volume's geometry is not copied into the prediction.
            sitk.WriteImage(sitk.GetImageFromArray(mask), new_filename)

            # Dice on the soft two-class scores; confusion metrics on the label map.
            value = dice_metric(y_pred=pignistic, y=val_labels)
            hard_pred = hard_labels.unsqueeze(1)
            sensitivity = monai.metrics.compute_confusion_metric(y_pred=hard_pred, y=val_labels, to_onehot_y=False,
                                                                 metric_name='sensitivity')
            specificity = monai.metrics.compute_confusion_metric(y_pred=hard_pred, y=val_labels, to_onehot_y=False,
                                                                 metric_name='specificity')
            precision = monai.metrics.compute_confusion_metric(y_pred=hard_pred, y=val_labels, to_onehot_y=False,
                                                               metric_name='precision')

            # Weight every metric by the number of entries the Dice value covers.
            n_entries = len(value)
            metric_count += n_entries
            metric_sum += value.item() * n_entries
            metric_sum_sen += sensitivity.item() * n_entries
            metric_sum_spe += specificity.item() * n_entries
            metric_sum_pre += precision.item() * n_entries

    if metric_count == 0:
        print("PREDICT_MASK: empty dataset, no metrics computed")
        return filenames_predicted_masks

    print("dice:", metric_sum / metric_count)
    print("sen:", metric_sum_sen / metric_count)
    print("spe", metric_sum_spe / metric_count)
    print("pre", metric_sum_pre / metric_count)

    return filenames_predicted_masks
311
312
###########################  Testing and visualization   #############################################
313
314
# Final inference pass: predict, save and score the held-out split.
# NOTE(review): despite the "validation" wording (flag, message, output folder),
# this runs on test_ds (files after the validation slice) — confirm the intent.
if PREDICTION_VALIDATION_SET:
    print("Prediction on validation set :")
    # use to fine tune and evaluate model performances
    print("Generating predictions :")
    prediction_dir = path_results + '/valid_predictions'
    valid_prediction_ids = PREDICT_MASK(
        data_set_ids=test_ds,
        path_predictions=prediction_dir,
        model=model,
    )
    print("fini")
323