# train.py — training entry point for the UNet segmentation model.
import json
from pathlib import Path
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
from argparse import Namespace
from data import GIImage, GIImageDataset, GIImageDataLoader, train_valid_split_cases
from utils import load_yaml, set_seed
from model import UNet
from evaluation import evaluate
def train(config: Namespace):
    """Train a UNet segmentation model and checkpoint the best weights.

    Builds train/valid splits from ``config.input_path``, trains with
    mixed precision (optional via ``config.use_fp16``), periodically
    validates, and saves the model whenever validation loss improves.

    Args:
        config: Namespace with attributes ``save_path``, ``exp_name``,
            ``seed``, ``input_path``, ``label_path``, ``valid_size``,
            ``batch_size``, ``input_resolution``, ``padding_mode``,
            ``use_fp16``, ``lr``, ``device``, ``nepochs``, ``valid_steps``.

    Side effects:
        Writes ``model.pth`` and ``valid_losses.json`` under
        ``config.save_path / config.exp_name``.
    """
    exp_path = Path(config.save_path) / config.exp_name
    exp_path.mkdir(parents=True, exist_ok=True)
    set_seed(config.seed)

    model = UNet(n_classes=len(GIImage.organs))
    train_cases, valid_cases = train_valid_split_cases(config.input_path, config.valid_size)
    train_set = GIImageDataset(image_path=config.input_path, label_path=config.label_path, cases=train_cases)
    valid_set = GIImageDataset(image_path=config.input_path, label_path=config.label_path, cases=valid_cases)
    train_loader = GIImageDataLoader(
        model=model,
        dataset=train_set,
        batch_size=config.batch_size,
        shuffle=True,
        input_resolution=config.input_resolution,
        padding_mode=config.padding_mode,
    ).get_data_loader()
    valid_loader = GIImageDataLoader(
        model=model,
        dataset=valid_set,
        batch_size=config.batch_size,
        shuffle=False,
        input_resolution=config.input_resolution,
        padding_mode=config.padding_mode,
    ).get_data_loader()

    # Move the model to the target device BEFORE constructing the optimizer
    # so Adam's state tensors are created on the same device as the params.
    model = model.to(config.device)

    # TODO: use other loss functions and optimizer
    scaler = torch.cuda.amp.GradScaler(enabled=config.use_fp16)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=config.lr)

    # Baseline validation loss before any training step.
    # TODO: track model performance with other metrics
    best_valid_loss = evaluate(model, valid_loader, criterion, config.device, config.use_fp16)
    valid_losses = [best_valid_loss]
    (exp_path / "valid_losses.json").write_text(json.dumps(valid_losses, indent=4))

    for epoch in range(config.nepochs):
        model.train()
        pbar = tqdm(train_loader)
        pbar.set_description(f"Epoch {epoch + 1}")
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(config.device)
            labels = labels.to(config.device)
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.use_fp16):
                preds = model(inputs)
                loss = criterion(preds, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            pbar.update()
            # Validate every valid_steps batches and at the end of the epoch.
            if (i != 0 and i % config.valid_steps == 0) or (i == len(train_loader) - 1):
                total_valid_loss = evaluate(model, valid_loader, criterion, config.device, config.use_fp16)
                # NOTE(review): evaluate() presumably puts the model in eval
                # mode — restore training mode so dropout/batch-norm behave
                # correctly for the remainder of the epoch. TODO confirm.
                model.train()
                valid_losses.append(total_valid_loss)
                # Persist the full loss history after EVERY evaluation, not
                # only on improvement, so no data points are lost.
                (exp_path / "valid_losses.json").write_text(json.dumps(valid_losses, indent=4))
                print(f"Valid loss: {total_valid_loss:.4f}")
                # save model if validation loss is improved
                if total_valid_loss < best_valid_loss:
                    best_valid_loss = total_valid_loss
                    torch.save(model.state_dict(), exp_path / "model.pth")
                    print(f"Model saved at {config.save_path}")
        pbar.close()
if __name__ == "__main__":
    # Entry point: load hyperparameters from config.yml and launch training.
    train(load_yaml("config.yml"))