"""train.py — 3D U-Net training script (originally listed as 75 lines, 2.5 kB).

NOTE(review): the web-viewer page header and gutter line numbers from the
original scrape were removed; the code below is the actual module content.
"""
import math
import os

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

from config import (
    TRAINING_EPOCH, NUM_CLASSES, IN_CHANNELS, BCE_WEIGHTS, BACKGROUND_AS_CLASS, TRAIN_CUDA
)
from dataset import get_train_val_test_Dataloaders
from transforms import (train_transform, train_transform_cuda,
                        val_transform, val_transform_cuda)
from unet3d import UNet3D
# Treat the background as an additional segmentation class when configured.
if BACKGROUND_AS_CLASS:
    NUM_CLASSES += 1

writer = SummaryWriter("runs")

model = UNet3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)
train_transforms = train_transform
val_transforms = val_transform

# Move model and transforms to GPU only when CUDA is both requested and available;
# otherwise warn and fall back to CPU.
if torch.cuda.is_available() and TRAIN_CUDA:
    model = model.cuda()
    train_transforms = train_transform_cuda
    val_transforms = val_transform_cuda
elif not torch.cuda.is_available() and TRAIN_CUDA:
    print('cuda not available! Training initialized on cpu ...')

# The test split is unused here; validation transforms double as test transforms.
train_dataloader, val_dataloader, _ = get_train_val_test_Dataloaders(
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    test_transforms=val_transforms,
)

# Class-weighted cross entropy; BCE_WEIGHTS supplies one weight per class.
criterion = CrossEntropyLoss(weight=torch.Tensor(BCE_WEIGHTS))
optimizer = Adam(params=model.parameters())

min_valid_loss = math.inf

# Ensure the checkpoint directory exists before the first torch.save call,
# which would otherwise raise FileNotFoundError.
os.makedirs('checkpoints', exist_ok=True)

for epoch in range(TRAINING_EPOCH):
    train_loss = 0.0
    model.train()
    for data in train_dataloader:
        image, ground_truth = data['image'], data['label']
        optimizer.zero_grad()
        target = model(image)
        loss = criterion(target, ground_truth)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    valid_loss = 0.0
    model.eval()
    # BUG FIX: the original assigned `valid_loss = loss.item()` inside the loop,
    # so only the LAST validation batch contributed to the reported loss and to
    # the checkpointing decision. Accumulate across batches instead.
    # Also run validation under no_grad: no gradients are needed, and building
    # autograd graphs here wastes memory and compute.
    with torch.no_grad():
        for data in val_dataloader:
            image, ground_truth = data['image'], data['label']
            target = model(image)
            loss = criterion(target, ground_truth)
            valid_loss += loss.item()

    writer.add_scalar("Loss/Train", train_loss / len(train_dataloader), epoch)
    writer.add_scalar("Loss/Validation", valid_loss / len(val_dataloader), epoch)
    print(f'Epoch {epoch+1} \t\t Training Loss: {train_loss / len(train_dataloader)} \t\t Validation Loss: {valid_loss / len(val_dataloader)}')

    # Checkpoint whenever the (total) validation loss improves.
    if min_valid_loss > valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), f'checkpoints/epoch{epoch}_valLoss{min_valid_loss}.pth')

writer.flush()
writer.close()