b/U-Net/train_blood.py

import argparse
import logging
import os
import random
import sys
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.models as models

from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from evaluate import evaluate
from unet.unet_model import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss

import segmentation_models_pytorch as smp

# Pin training to GPU 4 when CUDA is available; fall back to CPU otherwise
device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')

dir_img = Path('./data/train/imgs/')
dir_mask = Path('./data/train/masks/')
dir_checkpoint = Path('./checkpoints')


def train_model(
        model, device, epochs, batch_size, learning_rate,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.5,  # currently unused: the Adam optimizer below takes no momentum argument
        gradient_clipping: float = 1.0
):
    # Track the best weights seen across validation rounds
    best_model_params = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0

    # 1. Create dataset
    data_transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale, data_transform)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        loss += dice_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # Unscale gradients before clipping; otherwise the clip threshold
                # is applied to the AMP-scaled gradients rather than the real ones
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round (roughly 5 times per epoch)
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        val_score = evaluate(model, val_loader, device, amp)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))

                        # Keep the best model on the validation split
                        # (not necessarily the best on a held-out test set)
                        if val_score > best_acc:
                            best_acc = val_score
                            best_epoch = epoch
                            best_model_params = copy.deepcopy(model.state_dict())
                            logging.info(f'Best model: [epoch: {best_epoch}, acc: {best_acc:.4f}]')

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict['mask_values'] = dataset.mask_values
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')

    # Save the weights (only) of the best validation model
    torch.save(best_model_params, f'epoch_{best_epoch}_acc_{best_acc:.2f}_best_val_acc.pth')
    logging.info('Best model name: ' + f'epoch_{best_epoch}_acc_{best_acc:.2f}_best_val_acc.pth')


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=16, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.001,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    """
    Change here to adapt to your data
    n_channels=3 for RGB images
    n_classes is the number of probabilities you want to get per pixel
    """
    # Hardcoded for the blood dataset: grayscale input (n_channels=1), 5 classes.
    # Note that this overrides the --classes and --bilinear command-line arguments.
    model = UNet(n_channels=1, n_classes=5, bilinear=True)
    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        # Checkpoints store the dataset's mask values alongside the weights;
        # drop that entry before loading the weights into the model
        del state_dict['mask_values']
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(device=device)

    train_model(
        model=model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        device=device,
        img_scale=args.scale,
        val_percent=args.val / 100,
        amp=args.amp
    )
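
# Example invocation (a sketch; the flag values are illustrative, and the data
# is assumed to live under ./data/train/imgs and ./data/train/masks as
# configured at the top of this file):
#
#   python train_blood.py --epochs 50 --batch-size 16 --learning-rate 0.001 --scale 0.5 --amp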
|