# Source file: Evidential_segmentation/TRAINING-RBF_step1.py
#!/usr/bin/env python
# coding: utf-8

# In[]:

###########################  IMPORTS   #############################################

# --- standard library ---
import csv
import os
import random
import time
from copy import copy
from glob import glob
from os.path import splitext, basename

# --- third-party ---
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import SimpleITK as sitk
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import datasets, models, transforms

import monai
from monai.data.utils import list_data_collate, worker_init_fn
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric, compute_meandice
from monai.networks.nets import UNet, VNet, DynUNet, UNet_RBF, UNet_RBF_KMEANS
from monai.networks.utils import one_hot
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    AsChannelFirstd,
    Compose,
    LoadNiftid,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    ToTensord,
)
from monai.utils import first, set_determinism
from monai.visualize import plot_2d_or_3d_image

# --- local project modules ---
#from global_tools.tools import display_loading_bar
from class_modalities.transforms import LoadNifti, Roi2Mask, ResampleReshapeAlign, Sitk2Numpy, ConcatModality
##################

def _build_transforms():
    """Build the PET/CT preprocessing pipeline for one patient dict.

    Keys "pet_img", "ct_img", "mask_img" hold NIfTI paths; the pipeline loads
    them, converts SimpleITK images to numpy, stacks PET+CT into a 2-channel
    "image", adds a channel axis to the mask, and converts both to tensors.
    """
    return Compose(
        [  # read img + meta info
            LoadNifti(keys=["pet_img", "ct_img", "mask_img"]),
            Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']),
            # Prepare for neural network
            ConcatModality(keys=['pet_img', 'ct_img']),
            AddChanneld(keys=["mask_img"]),  # Add channel to the first axis
            ToTensord(keys=["image", "mask_img"]),
        ])

# NOTE(review): the training and validation pipelines are identical — no
# random augmentation is applied at training time. Add random transforms
# to a separate training pipeline here if augmentation is desired.
train_transforms = _build_transforms()
# without data augmentation for validation
val_transforms = _build_transforms()
##################loading data###############################

# Device selection; falls back to CPU when CUDA is unavailable.
# NOTE(review): GPU index 5 is hard-coded — adjust for your machine.
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

base_path = "/home/lab/hualing/2.5_SUV_dilation"  # the path you put the pre_processed data.
pet_path = base_path + '/' + 'pet_test'
ct_path = base_path + '/' + 'ct_test'
mask_path = base_path + '/' + 'pet_test_mask'

# Sorted globs keep PET / CT / mask files aligned per patient
# (assumes matching filename prefixes across the three folders).
PET_ids = sorted(glob(os.path.join(pet_path, '*pet.nii')))
CT_ids = sorted(glob(os.path.join(ct_path, '*ct.nii')))
MASK_ids = sorted(glob(os.path.join(mask_path, '*mask.nii')))
files = list(zip(PET_ids, CT_ids, MASK_ids))

# Fixed patient-level split: 138 train / 18 validation / remainder test.
N_TRAIN = 138
N_VAL_END = 156
train_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[:N_TRAIN]]
val_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[N_TRAIN:N_VAL_END]]
test_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[N_VAL_END:]]

train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
test_ds = monai.data.Dataset(data=test_files, transform=val_transforms)

train_loader = DataLoader(
    train_ds,
    batch_size=5,
    shuffle=True,
    num_workers=4,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
###################################### defining model#######################################

trained_model_path = "./pre-trained_model/best_metric_model_unet.pth"  # path to the pre-trained UNET model

model = UNet_RBF_KMEANS(
        dimensions=3,  # 3D
        in_channels=2,   # 2-channel input (PET + CT, per the transform pipeline)
        out_channels=2,
        kernel_size=5,
        channels=(8, 16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,).to(device)

# Warm-start: copy every pre-trained UNet parameter whose name also exists in
# the RBF model; layers unique to the RBF/ENN head keep their fresh init.
# map_location=device avoids a load failure when the checkpoint was saved on a
# different GPU than the one used here.
model_dict = model.state_dict()
pre_dict = torch.load(trained_model_path, map_location=device)
pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
model_dict.update(pre_dict)

model.load_state_dict(model_dict)
# Log which parameters will receive gradient updates.
for name, param in model.named_parameters():
    if param.requires_grad == True:
        print(name)
# Optimize only the parameters that still require gradients.
# NOTE(review): intent is that only the ENN/RBF head is optimized, but nothing
# here sets requires_grad=False on the backbone — confirm UNet_RBF_KMEANS
# freezes the pre-trained layers internally, otherwise ALL parameters train.
params = filter(lambda p: p.requires_grad, model.parameters())  # code to make sure only the parameters from ENN are optimized

optimizer = torch.optim.Adam(params, 1e-2)
dice_metric = monai.metrics.DiceMetric(include_background=False, reduction="mean")
# Halve-style LR decay when training loss stops improving for 10 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
# softmax=False: the model already outputs probabilities (pignistic output).
loss_function = monai.losses.DiceLoss(include_background=False, softmax=False, squared_pred=True, to_onehot_y=True)

# TODO : generate a learning rate scheduler

val_interval = 1        # run validation every epoch
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()

writer = SummaryWriter()
# Post-processing transforms (not referenced by the training loop below).
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
post_label = AsDiscrete(to_onehot=True, n_classes=2)
#############################################   training and validation#############################################

max_epochs = 100
# Iterations per epoch — loop-invariant, so computed once (was recomputed per batch).
epoch_len = len(train_ds) // train_loader.batch_size

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training ----
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["mask_img"].to(device)
        optimizer.zero_grad()
        # Evidential model returns a pair; pm is used for the Dice loss.
        # NOTE(review): presumably (pignistic probabilities, mass functions) — confirm.
        pm, mass = model(inputs)

        loss = loss_function(pm, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    scheduler.step(epoch_loss)  # ReduceLROnPlateau keys off the training loss

    # ---- validation ----
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(device), val_data["mask_img"].to(device)
                pm, mass = model(val_images)
                value = dice_metric(y_pred=pm, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)

            metric = metric_sum / metric_count
            metric_values.append(metric)

            # Keep the checkpoint with the best validation mean Dice.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "RBF_best_metric_model_segmentation3d_dict_step1.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()