# Evidential_segmentation/TRAINING-ENN.py
#!/usr/bin/env python
2
# coding: utf-8
3
4
# In[]:
5
6
###########################  IMPORTS   ############################################# 
7
8
import torch
9
import torch.nn as nn
10
import torch.nn.functional as F
11
import monai
12
from monai.networks.nets import UNet,VNet,DynUNet,UNet_ENN
13
from monai.networks.utils import one_hot
14
from monai.transforms import (
15
    AsDiscrete,
16
    AddChanneld,
17
    AsChannelFirstd,
18
    Compose,
19
    LoadNiftid,
20
    RandCropByPosNegLabeld,
21
    RandRotate90d,
22
    ScaleIntensityd,
23
    ToTensord,
24
)
25
from monai.visualize import plot_2d_or_3d_image
26
from monai.data.utils import list_data_collate, worker_init_fn
27
from monai.inferers import sliding_window_inference
28
from monai.metrics import DiceMetric
29
from monai.metrics import compute_meandice
30
from torch.autograd import Variable
31
import torch.optim as optim
32
from torch.optim import lr_scheduler
33
from torch.utils.data import DataLoader
34
from torch.utils.data import Dataset
35
import torch.backends.cudnn as cudnn
36
import torchvision
37
from torchvision import datasets, models, transforms
38
import csv
39
import time
40
import SimpleITK as sitk
41
from os.path import splitext,basename
42
import random
43
from glob import glob
44
45
import matplotlib.pyplot as plt
46
from matplotlib.backends.backend_pdf import PdfPages
47
from copy import copy
48
import os
49
import numpy as np
50
from torch.utils.tensorboard import SummaryWriter
51
52
#from global_tools.tools import display_loading_bar
53
from class_modalities.transforms import LoadNifti, Roi2Mask, ResampleReshapeAlign, Sitk2Numpy, ConcatModality
54
from monai.utils import first, set_determinism
55
56
57
# Training pipeline: read the NIfTI volumes (plus meta info), convert them to
# numpy, fuse PET+CT into a single 2-channel "image", and tensorize.
_train_steps = [
    LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),   # read img + meta info
    Sitk2Numpy(keys=["pet_img", "ct_img", "mask_img"]),
    # random/augmentation transforms could be inserted here
    ConcatModality(keys=["pet_img", "ct_img"]),          # -> "image" (2 channels)
    AddChanneld(keys=["mask_img"]),                      # add leading channel axis
    ToTensord(keys=["image", "mask_img"]),
]
train_transforms = Compose(_train_steps)
# Validation pipeline: same preprocessing as training, with no data
# augmentation.
val_transforms = Compose([
    LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),   # read img + meta info
    Sitk2Numpy(keys=["pet_img", "ct_img", "mask_img"]),
    ConcatModality(keys=["pet_img", "ct_img"]),          # fuse PET + CT -> "image"
    AddChanneld(keys=["mask_img"]),                      # add leading channel axis
    ToTensord(keys=["image", "mask_img"]),
])
# Use GPU 2 when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# Root folder of the pre-processed data (PET / CT / mask NIfTI files).
base_path = "/home/lab/hualing/2.5_SUV_dilation"
pet_path = os.path.join(base_path, "pet_test")
ct_path = os.path.join(base_path, "ct_test")
mask_path = os.path.join(base_path, "pet_test_mask")

# Sorting keeps the three modality lists aligned by patient file name.
PET_ids = sorted(glob(os.path.join(pet_path, "*pet.nii")))
CT_ids = sorted(glob(os.path.join(ct_path, "*ct.nii")))
MASK_ids = sorted(glob(os.path.join(mask_path, "*mask.nii")))

# zip() silently truncates to the shortest list; if one folder is missing a
# file that would mis-pair modalities across patients. Fail loudly instead.
if not (len(PET_ids) == len(CT_ids) == len(MASK_ids)):
    raise ValueError(
        f"modality counts differ: {len(PET_ids)} PET, "
        f"{len(CT_ids)} CT, {len(MASK_ids)} masks"
    )

files = list(zip(PET_ids, CT_ids, MASK_ids))

# Fixed split: first 138 patients train, next 18 validate, remainder test.
train_files = [{"pet_img": pet, "ct_img": ct, "mask_img": mask} for pet, ct, mask in files[:138]]
val_files = [{"pet_img": pet, "ct_img": ct, "mask_img": mask} for pet, ct, mask in files[138:156]]
test_files = [{"pet_img": pet, "ct_img": ct, "mask_img": mask} for pet, ct, mask in files[156:]]
# Wrap each split in a MONAI Dataset bound to its transform pipeline
# (validation transforms are reused for the held-out test set).
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
test_ds = monai.data.Dataset(data=test_files, transform=val_transforms)

# Training loader: shuffled mini-batches of 3, pinned memory when on GPU.
train_loader = DataLoader(
    train_ds,
    batch_size=3,
    shuffle=True,
    num_workers=4,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)


def _eval_loader(ds):
    # Evaluation loaders stream one volume per batch, no shuffling.
    return DataLoader(ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)


val_loader = _eval_loader(val_ds)
test_loader = _eval_loader(test_ds)
# Path to the pretrained plain-UNet checkpoint used to warm-start the ENN model.
trained_model_path = "./pre-trained_model/best_metric_model_unet.pth"

# 3D residual UNet with an evidential (ENN) head: input is the 2-channel
# PET+CT volume; judging by the channel slicing in the training loop, the
# network emits 2 class channels plus an extra ignorance channel.
model = UNet_ENN(
    dimensions=3,              # 3D convolutions
    in_channels=2,             # PET + CT
    out_channels=2,
    kernel_size=5,
    channels=(8, 16, 32, 64, 128),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

# Warm-start: copy every pretrained tensor whose name AND shape match the
# current model; ENN-specific parameters keep their fresh initialization.
model_dict = model.state_dict()
# map_location keeps the load working even when the checkpoint was saved on
# a GPU that is not present on this machine.
pre_dict = torch.load(trained_model_path, map_location=device)
pre_dict = {
    k: v
    for k, v in pre_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(pre_dict)
model.load_state_dict(model_dict)
# Print which parameters will receive gradients.
# NOTE(review): the original comments say "only the parameters from ENN are
# optimized", but nothing in this script sets requires_grad=False, so unless
# UNet_ENN freezes the UNet weights internally, every parameter is listed
# here and optimized below -- confirm against the UNet_ENN implementation.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

# Optimize only the trainable parameters.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.Adam(trainable_params, 1e-2)

# Mean Dice on the foreground class only.
dice_metric = monai.metrics.DiceMetric(include_background=False, reduction="mean")
# Lower the learning rate when the monitored loss plateaus for 10 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=10)
# Dice loss on the foreground class; labels are one-hot encoded internally.
loss_function = monai.losses.DiceLoss(
    include_background=False, softmax=False, squared_pred=True, to_onehot_y=True
)
# Training-loop bookkeeping.
val_interval = 1            # validate after every epoch
best_metric = -1            # best validation mean Dice so far
best_metric_epoch = -1      # epoch at which best_metric was reached
epoch_loss_values = []      # mean training loss per epoch
metric_values = []          # validation mean Dice per evaluation

writer = SummaryWriter()    # TensorBoard logging
# ----------------------------- training loop ------------------------------
max_epochs = 100
# Optimizer steps per epoch; loop-invariant, so hoisted out of the batch
# loop (it was previously recomputed for every batch).
steps_per_epoch = len(train_ds) // train_loader.batch_size

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs = batch_data["image"].to(device)
        labels = batch_data["mask_img"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)

        # Evidential output: channels 0-1 are the class masses, channel 2 is
        # the ignorance mass, which is split equally between the two classes
        # before computing the Dice loss.
        output = outputs[:, :2, :, :, :] + 0.5 * outputs[:, 2, :, :, :].unsqueeze(1)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{steps_per_epoch}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), steps_per_epoch * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # Plateau scheduler keyed on the mean training loss.
    scheduler.step(epoch_loss)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0

            for val_data in val_loader:
                val_images = val_data["image"].to(device)
                val_labels = val_data["mask_img"].to(device)
                output = model(val_images)
                # Same class-mass + half-ignorance combination as in training.
                val_outputs = output[:, :2, :, :, :] + 0.5 * output[:, 2, :, :, :].unsqueeze(1)
                # NOTE(review): .item() assumes the metric returns a single
                # scalar per batch (val batch_size is 1) -- confirm against
                # the installed MONAI DiceMetric API version.
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)

            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "ENN_best_metric_model_segmentation3d_dict.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )


print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()