#!/usr/bin/env python
# coding: utf-8
# Evidential_segmentation/TRAINING-ENN_step1.py

# In[]:

###########################  IMPORTS  #############################################

import torch
import torch.nn as nn
import torch.nn.functional as F
import monai
from monai.networks.nets import UNet, VNet, DynUNet, UNet_ENN_KMEANS
from monai.networks.utils import one_hot
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    AsChannelFirstd,
    Compose,
    LoadNiftid,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    ToTensord,
)
from monai.visualize import plot_2d_or_3d_image
from monai.data.utils import list_data_collate, worker_init_fn
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.metrics import compute_meandice
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import datasets, models, transforms
import csv
import time
import SimpleITK as sitk
from os.path import splitext, basename
import random
from glob import glob

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from copy import copy
import os
import numpy as np
from torch.utils.tensorboard import SummaryWriter

# from global_tools.tools import display_loading_bar
from class_modalities.transforms import LoadNifti, Roi2Mask, ResampleReshapeAlign, Sitk2Numpy, ConcatModality
from monai.utils import first, set_determinism


# preprocessing pipeline for training: load PET/CT/mask, convert to numpy,
# stack PET and CT into a 2-channel "image", and add a channel axis to the mask
train_transforms = Compose(
    [  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])
# same pipeline without data augmentation for validation (no augmentation is applied
# in the training pipeline either; a possible extension is sketched below)
val_transforms = Compose(
    [  # read img + meta info
        LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
        Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
        ConcatModality(keys=['pet_img', 'ct_img']),
        AddChanneld(keys=["mask_img"]),  # add channel to the first axis
        ToTensord(keys=["image", "mask_img"]),
    ])

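# If augmentation is wanted for the training pipeline only, the random transforms that
# are already imported above (e.g. RandRotate90d) could be appended before ToTensord.
# Minimal, commented-out sketch; the probability and axes are illustrative and were not
# part of the original configuration:
#
# train_transforms = Compose(
#     [
#         LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
#         Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
#         ConcatModality(keys=['pet_img', 'ct_img']),
#         AddChanneld(keys=["mask_img"]),
#         RandRotate90d(keys=["image", "mask_img"], prob=0.5, spatial_axes=(0, 1)),
#         ToTensord(keys=["image", "mask_img"]),
#     ])
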
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
base_path = "/home/lab/hualing/2.5_SUV_dilation"  # path where the pre-processed data is stored
pet_path = base_path + '/' + 'pet_test'
ct_path = base_path + '/' + 'ct_test'
mask_path = base_path + '/' + 'pet_test_mask'
PET_ids = sorted(glob(os.path.join(pet_path, '*pet.nii')))
CT_ids = sorted(glob(os.path.join(ct_path, '*ct.nii')))
MASK_ids = sorted(glob(os.path.join(mask_path, '*mask.nii')))
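
# Expected data layout, inferred from the glob patterns above (PET, CT and mask files
# are matched case-by-case purely through the sorted order of the three lists):
#   <base_path>/pet_test/*pet.nii
#   <base_path>/ct_test/*ct.nii
#   <base_path>/pet_test_mask/*mask.nii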

files = list(zip(PET_ids, CT_ids, MASK_ids))

# fixed split: the first 138 cases for training, the next 18 for validation, the rest for testing
train_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[:138]]
val_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[138:156]]
test_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[156:]]

train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
test_ds = monai.data.Dataset(data=test_files, transform=val_transforms)

train_loader = DataLoader(
    train_ds,
    batch_size=6,
    shuffle=True,
    num_workers=4,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)

val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)

trained_model_path = "./pre-trained_model/best_metric_model_unet.pth"  # path to the pre-trained UNet weights
model = UNet_ENN_KMEANS(
    dimensions=3,  # 3D
    in_channels=2,
    out_channels=2,
    kernel_size=5,
    channels=(8, 16, 32, 64, 128),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# initialise the UNet part of the model with the pre-trained weights;
# only the keys that exist in the ENN model are copied, so the evidential
# layers keep their own initialisation
model_dict = model.state_dict()
pre_dict = torch.load(trained_model_path)
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
model_dict.update(pre_dict)
model.load_state_dict(model_dict)

params = filter(lambda p: p.requires_grad, model.parameters())
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)  # sanity check: only the ENN parameters should appear here, i.e. only they are optimised in step 1

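# Optional sanity check (a minimal sketch using plain PyTorch introspection, in addition
# to the name printout above): report how many parameters are trainable vs. frozen, to
# verify that only the evidential (ENN) layers will be updated in this step.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable parameters: {n_trainable}, frozen parameters: {n_frozen}")
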
optimizer = torch.optim.Adam(params, 1e-2)  # optimise only the trainable (ENN) parameters
dice_metric = monai.metrics.DiceMetric(include_background=False, reduction="mean")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
# softmax is disabled: the combined mass outputs are used directly as class scores
loss_function = monai.losses.DiceLoss(include_background=False, softmax=False, squared_pred=True, to_onehot_y=True)

val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()

writer = SummaryWriter()
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)  # defined for optional post-processing (not used below)
post_label = AsDiscrete(to_onehot=True, n_classes=2)

max_epochs = 100

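# The evidential head is expected to produce three mass channels per voxel: one mass for
# each singleton class plus the mass assigned to the whole frame Omega (ignorance); this
# layout is assumed from the slicing used in the loops below. Training and validation turn
# the masses into two class scores by giving each class half of the ignorance mass, i.e.
# the pignistic transform for a two-class frame. A small helper equivalent to the inline
# expression used below:
def masses_to_scores(masses):
    """Pignistic-style scores from mass channels [m_class0, m_class1, m_Omega]."""
    return masses[:, :2] + 0.5 * masses[:, 2].unsqueeze(1)

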
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["mask_img"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # combine the three mass channels into two class scores:
        # half of the ignorance mass (channel 2) is redistributed to each class
        output = outputs[:, :2, :, :, :] + 0.5 * outputs[:, 2, :, :, :].unsqueeze(1)
        dice_loss = loss_function(output, labels)

        loss = dice_loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    scheduler.step(epoch_loss)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(device), val_data["mask_img"].to(device)
                output = model(val_images)
                # same mass-to-score conversion as in training
                val_outputs = output[:, :2, :, :, :] + 0.5 * output[:, 2, :, :, :].unsqueeze(1)
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)

            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "ENN_best_metric_model_segmentation3d_dict_step1.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
208
writer.close()
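
# The held-out test_loader defined above is not used in this training step. A minimal
# evaluation sketch (reusing the best checkpoint, the same mass-to-score conversion and
# the same Dice metric as in validation) could look like the following; it is left
# commented out because test evaluation is not part of step 1:
# model.load_state_dict(torch.load("ENN_best_metric_model_segmentation3d_dict_step1.pth"))
# model.eval()
# with torch.no_grad():
#     for test_data in test_loader:
#         test_images = test_data["image"].to(device)
#         test_labels = test_data["mask_img"].to(device)
#         masses = model(test_images)
#         scores = masses[:, :2, :, :, :] + 0.5 * masses[:, 2, :, :, :].unsqueeze(1)
#         print(dice_metric(y_pred=scores, y=test_labels))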