Diff of /train.py [000000] .. [748954]

b/train.py
import json
from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

from argparse import Namespace
from data import GIImage, GIImageDataset, GIImageDataLoader, train_valid_split_cases
from utils import load_yaml, set_seed
from model import UNet
from evaluation import evaluate

def train(config: Namespace):
    exp_path = Path(config.save_path) / config.exp_name
    exp_path.mkdir(parents=True, exist_ok=True)

    set_seed(config.seed)

    model = UNet(n_classes=len(GIImage.organs))

    # split cases into training/validation subsets and build a dataset for each
    train_cases, valid_cases = train_valid_split_cases(config.input_path, config.valid_size)
    train_set = GIImageDataset(image_path=config.input_path, label_path=config.label_path, cases=train_cases)
    valid_set = GIImageDataset(image_path=config.input_path, label_path=config.label_path, cases=valid_cases)

    train_loader = GIImageDataLoader(
        model=model,
        dataset=train_set,
        batch_size=config.batch_size,
        shuffle=True,
        input_resolution=config.input_resolution,
        padding_mode=config.padding_mode
    ).get_data_loader()

    valid_loader = GIImageDataLoader(
        model=model,
        dataset=valid_set,
        batch_size=config.batch_size,
        shuffle=False,
        input_resolution=config.input_resolution,
        padding_mode=config.padding_mode
    ).get_data_loader()

    # TODO: use other loss functions and optimizer
    scaler = torch.cuda.amp.GradScaler(enabled=config.use_fp16)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=config.lr)

    model = model.to(config.device)

    best_valid_loss = evaluate(model, valid_loader, criterion, config.device, config.use_fp16)  # TODO: track model performance with other metrics
    valid_losses = [best_valid_loss]
    (exp_path / "valid_losses.json").write_text(json.dumps(valid_losses, indent=4))
    for epoch in range(config.nepochs):
        model.train()
        pbar = tqdm(train_loader)
        pbar.set_description(f"Epoch {epoch + 1}")
        for i, (inputs, labels) in enumerate(pbar):
            inputs = inputs.to(config.device)
            labels = labels.to(config.device)

            # forward pass under mixed precision when use_fp16 is enabled
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.use_fp16):
                preds = model(inputs)
                loss = criterion(preds, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            # run validation every valid_steps batches and at the end of the epoch
            if (i != 0 and i % config.valid_steps == 0) or (i == len(train_loader) - 1):
                total_valid_loss = evaluate(model, valid_loader, criterion, config.device, config.use_fp16)
                valid_losses.append(total_valid_loss)
                print(f"Valid loss: {total_valid_loss:.4f}")
                # save the model if the validation loss improved
                if total_valid_loss < best_valid_loss:
                    best_valid_loss = total_valid_loss
                    torch.save(model.state_dict(), exp_path / "model.pth")
                    (exp_path / "valid_losses.json").write_text(json.dumps(valid_losses, indent=4))
                    print(f"Model saved to {exp_path / 'model.pth'}")
                # evaluate() is assumed to switch the model to eval mode, so restore training mode
                model.train()

if __name__ == "__main__":
    config = load_yaml("config.yml")
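    # NOTE: the config keys referenced in train() (save_path, exp_name, seed,
    # input_path, label_path, valid_size, batch_size, input_resolution,
    # padding_mode, use_fp16, lr, device, nepochs, valid_steps) are inferred
    # from usage above; the exact config.yml schema is an assumption.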
    train(config)