Diff of /U-Net/train_blood.py [000000] .. [6f3ba0]

Switch to unified view

a b/U-Net/train_blood.py
1
import argparse
2
import logging
3
import os
4
import random
5
import sys
6
import copy
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
import torchvision.transforms as transforms
11
import torchvision.transforms.functional as TF
12
import torchvision.models as models
13
14
from pathlib import Path
15
from torch import optim
16
from torch.utils.data import DataLoader, random_split
17
from tqdm import tqdm
18
19
from evaluate import evaluate
20
from unet.unet_model import UNet
21
from utils.data_loading import BasicDataset, CarvanaDataset
22
from utils.dice_score import dice_loss
23
24
import segmentation_models_pytorch as smp
25
26
def _pick_device(preferred_index: int = 4) -> torch.device:
    """Return the torch device to train on.

    Prefers ``cuda:<preferred_index>``. If CUDA is available but that index
    is not among the visible devices, falls back to ``cuda:0`` instead of
    returning a device that would fail on first use. Without CUDA, returns
    the CPU device.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if 0 <= preferred_index < torch.cuda.device_count():
        return torch.device(f'cuda:{preferred_index}')
    return torch.device('cuda:0')


# Device shared by the training entry point below.
device = _pick_device()

# Input images, their segmentation masks, and where per-epoch checkpoints go.
dir_img = Path('./data/train/imgs/')
dir_mask = Path('./data/train/masks/')
dir_checkpoint = Path('./checkpoints')
31
32
def _segmentation_loss(model, criterion, masks_pred, true_masks):
    """Return ``criterion`` loss plus Dice loss for one batch of predictions.

    Uses the binary formulation (sigmoid, channel squeezed) when
    ``model.n_classes == 1`` and the multiclass formulation (softmax over
    dim 1, one-hot targets) otherwise.
    """
    if model.n_classes == 1:
        logits = masks_pred.squeeze(1)
        target = true_masks.float()
        loss = criterion(logits, target)
        loss += dice_loss(F.sigmoid(logits), target, multiclass=False)
        return loss

    loss = criterion(masks_pred, true_masks)
    loss += dice_loss(
        F.softmax(masks_pred, dim=1).float(),
        F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
        multiclass=True
    )
    return loss


def train_model(
        model, device, epochs, batch_size, learning_rate,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.5,
        gradient_clipping: float = 1.0
    ):
    """Train ``model`` on the images in ``dir_img`` / ``dir_mask``.

    The dataset is split into train/validation parts with a fixed seed (0),
    optimised with Adam and a ``ReduceLROnPlateau`` scheduler driven by the
    validation score returned by ``evaluate``. The loss is cross-entropy
    (or BCE-with-logits when ``model.n_classes == 1``) plus Dice loss.

    Side effects:
        * When ``save_checkpoint`` is true, writes
          ``checkpoint_epoch<N>.pth`` into ``dir_checkpoint`` after every
          epoch; that file is the model state dict with an extra
          ``'mask_values'`` entry.
        * Always writes ``epoch_<E>_acc_<A>_best_val_acc.pth`` (weights only,
          no ``'mask_values'`` entry) into the current working directory.

    Args:
        model: Network exposing ``n_channels`` and ``n_classes``.
        device: ``torch.device`` the batches are moved to.
        epochs: Number of passes over the training split.
        batch_size: Batch size for both data loaders.
        learning_rate: Initial Adam learning rate.
        val_percent: Fraction (0-1) of the dataset held out for validation.
        save_checkpoint: Whether to write a checkpoint after each epoch.
        img_scale: Scale factor forwarded to the dataset constructor.
        amp: Enable autocast and gradient scaling.
        weight_decay: Adam weight decay.
        momentum: Unused; kept so existing callers keep working.
        gradient_clipping: Max gradient norm passed to ``clip_grad_norm_``.
    """
    best_model_params = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0

    # 1. Create dataset
    data_transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        # NOTE(review): only this fallback receives data_transform - confirm
        # that is intended.
        dataset = BasicDataset(dir_img, dir_mask, img_scale, data_transform)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val

    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    # Validate roughly five times per epoch. This is 0 when the training
    # split holds fewer than 5 * batch_size samples; in that case validation
    # runs once at the end of each epoch instead (see below).
    division_step = n_train // (5 * batch_size)
    val_score = None

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    loss = _segmentation_loss(model, criterion, masks_pred, true_masks)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale here;
                # unscale them so the clipping threshold applies to the true
                # gradient norm. This is a no-op when amp is disabled.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                if division_step > 0 and global_step % division_step == 0:
                    val_score = evaluate(model, val_loader, device, amp)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))

            if division_step == 0:
                # No in-epoch evaluation was possible; without this the
                # best-model check below would read an unassigned val_score.
                val_score = evaluate(model, val_loader, device, amp)
                scheduler.step(val_score)

                logging.info('Validation Dice score: {}'.format(val_score))

            # Check best accuracy model ( but not the best on test )
            if val_score is not None and val_score > best_acc:
                best_acc = val_score
                best_epoch = epoch
                best_model_params = copy.deepcopy(model.state_dict())
            logging.info("Best model: [" + f'epoch: {best_epoch}, acc: {best_acc:.4f}]')

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict['mask_values'] = dataset.mask_values
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')

    # only weight
    torch.save(best_model_params, f'epoch_{best_epoch}_acc_{best_acc:.2f}_best_val_acc.pth')
    logging.info("Best model name : " + f'epoch_{best_epoch}_acc_{best_acc:.2f}_best_val_acc.pth')
156
157
158
def _validation_percent(value):
    """Parse the ``--validation`` value and require ``0 <= value < 100``.

    The percentage is later divided by 100 and used to size the validation
    split, so 100 or more would leave no training samples and a negative
    value would produce a negative split size.
    """
    try:
        percent = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f'invalid float value: {value!r}') from None
    if not 0.0 <= percent < 100.0:
        raise argparse.ArgumentTypeError(
            f'validation percent must be >= 0 and < 100, got {value!r}')
    return percent


def get_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with ``epochs``, ``batch_size``, ``lr``, ``load``,
        ``scale``, ``val``, ``amp``, ``bilinear`` and ``classes``.
        ``load`` is ``False`` when no checkpoint path was given.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=16, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.001,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    # Out-of-range percentages are rejected here, at parse time, rather than
    # surfacing later as a data-split error.
    parser.add_argument('--validation', '-v', dest='val', type=_validation_percent, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()
173
174
175
if __name__ == '__main__':
    # Script entry point: parse the CLI, build the network, optionally
    # restore weights, then run training on the module-level `device`.
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data:
    #   n_channels=3 for RGB images
    #   n_classes is the number of probabilities you want to get per pixel
    # NOTE(review): the --classes and --bilinear options are parsed but not
    # used here; the network shape is fixed by the literals below.
    model = UNet(n_channels=1, n_classes=5, bilinear=True)

    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        # Per-epoch checkpoints carry an extra 'mask_values' entry, but the
        # best-weights file written at the end of train_model does not;
        # drop the key only if it is present so both kinds can be loaded.
        state_dict.pop('mask_values', None)
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(device=device)

    train_model(
        model=model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        device=device,
        img_scale=args.scale,
        # The CLI takes a percentage; train_model expects a 0-1 fraction.
        val_percent=args.val / 100,
        amp=args.amp
    )
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243