import shutil
import os
import time
from datetime import datetime
import argparse
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine
from torchvision import transforms
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from dataloader import MRDataset
import model
from sklearn import metrics
def train_model(model, train_loader, epoch, num_epochs, optimizer, writer, current_lr, log_every=100):
    """Train ``model`` for one epoch and log running metrics.

    Every batch is moved to the GPU when one is available, a weighted
    ``BCEWithLogitsLoss`` is back-propagated and the optimizer is stepped.
    After each batch the batch loss and the running ROC AUC (computed over
    all samples seen so far in this epoch) are written to ``writer``.

    Args:
        model: network to train; it is switched to training mode here.
        train_loader: iterable yielding ``(image, label, weight)`` batches;
            it must also support ``len()``.
        epoch: zero-based index of the current epoch.
        num_epochs: total number of epochs (only used in the progress line).
        optimizer: optimizer that is zeroed and stepped once per batch.
        writer: object exposing ``add_scalar(tag, value, step)``.
        current_lr: learning rate shown in the progress line.
        log_every: print a progress line every ``log_every`` batches.

    Returns:
        Tuple ``(mean_loss, auc)`` for the epoch, both rounded to 4 decimals.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()

    # Availability cannot change during the epoch, so query it once.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()

    y_preds = []
    y_trues = []
    losses = []
    # Chance-level placeholder reported while the AUC is still undefined.
    auc = 0.5
    num_batches = len(train_loader)

    for i, (image, label, weight) in enumerate(train_loader):
        optimizer.zero_grad()

        if use_cuda:
            image = image.cuda()
            label = label.cuda()
            weight = weight.cuda()

        # Drop the leading dimension of label/weight.
        # NOTE(review): this indexing assumes one sample per batch, which is
        # how run() configures the loaders (batch_size=1).
        label = label[0]
        weight = weight[0]

        prediction = model.forward(image.float())

        # The weight differs per sample, so the criterion is rebuilt each step.
        loss = torch.nn.BCEWithLogitsLoss(weight=weight)(prediction, label)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)

        probas = torch.sigmoid(prediction)

        # Index 1 of the first row is tracked for the AUC
        # (presumably the positive class - confirm against MRDataset).
        y_trues.append(int(label[0][1]))
        y_preds.append(probas[0][1].item())

        try:
            auc = metrics.roc_auc_score(y_trues, y_preds)
        except ValueError:
            # Raised while only one class has been seen so far.
            auc = 0.5
        else:
            # NOTE(review): newer scikit-learn releases reportedly return NaN
            # with a warning instead of raising - confirm; handled either way.
            if np.isnan(auc):
                auc = 0.5

        global_step = epoch * num_batches + i
        writer.add_scalar('Train/Loss', loss_value, global_step)
        writer.add_scalar('Train/AUC', auc, global_step)

        if i > 0 and i % log_every == 0:
            print('''[Epoch: {0} / {1} |Single batch number : {2} / {3} ]| avg train loss {4} | train auc : {5} | lr : {6}'''.
                  format(
                      epoch + 1,
                      num_epochs,
                      i,
                      num_batches,
                      np.round(np.mean(losses), 4),
                      np.round(auc, 4),
                      current_lr
                  )
                  )

    if not losses:
        raise ValueError('train_loader yielded no batches')

    # NOTE(review): the step is offset by the last batch index; kept as-is so
    # existing TensorBoard curves stay comparable.
    writer.add_scalar('Train/AUC_epoch', auc, epoch + i)

    train_loss_epoch = np.round(np.mean(losses), 4)
    train_auc_epoch = np.round(auc, 4)
    return train_loss_epoch, train_auc_epoch
def evaluate_model(model, val_loader, epoch, num_epochs, writer, current_lr, log_every=20):
    """Evaluate ``model`` on ``val_loader`` and log running metrics.

    The model is put in evaluation mode and run without gradient tracking.
    After each batch the batch loss and the running ROC AUC (computed over
    all samples seen so far) are written to ``writer``.

    Args:
        model: network to evaluate; it is switched to evaluation mode here.
        val_loader: iterable yielding ``(image, label, weight)`` batches;
            it must also support ``len()``.
        epoch: zero-based index of the current epoch.
        num_epochs: total number of epochs (only used in the progress line).
        writer: object exposing ``add_scalar(tag, value, step)``.
        current_lr: learning rate shown in the progress line.
        log_every: print a progress line every ``log_every`` batches.

    Returns:
        Tuple ``(mean_loss, auc)`` over the loader, both rounded to 4 decimals.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()

    # Availability cannot change during the pass, so query it once.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()

    y_trues = []
    y_preds = []
    losses = []
    # Chance-level placeholder reported while the AUC is still undefined.
    auc = 0.5
    num_batches = len(val_loader)

    # No parameter is updated here, so skip building the autograd graph.
    with torch.no_grad():
        for i, (image, label, weight) in enumerate(val_loader):
            if use_cuda:
                image = image.cuda()
                label = label.cuda()
                weight = weight.cuda()

            # Drop the leading dimension of label/weight.
            # NOTE(review): this indexing assumes one sample per batch, which
            # is how run() configures the loaders (batch_size=1).
            label = label[0]
            weight = weight[0]

            prediction = model.forward(image.float())

            loss = torch.nn.BCEWithLogitsLoss(weight=weight)(prediction, label)

            loss_value = loss.item()
            losses.append(loss_value)

            probas = torch.sigmoid(prediction)

            # Index 1 of the first row is tracked for the AUC
            # (presumably the positive class - confirm against MRDataset).
            y_trues.append(int(label[0][1]))
            y_preds.append(probas[0][1].item())

            try:
                auc = metrics.roc_auc_score(y_trues, y_preds)
            except ValueError:
                # Raised while only one class has been seen so far.
                auc = 0.5
            else:
                # NOTE(review): newer scikit-learn releases reportedly return
                # NaN with a warning instead of raising - confirm.
                if np.isnan(auc):
                    auc = 0.5

            global_step = epoch * num_batches + i
            writer.add_scalar('Val/Loss', loss_value, global_step)
            writer.add_scalar('Val/AUC', auc, global_step)

            if i > 0 and i % log_every == 0:
                print('''[Epoch: {0} / {1} |Single batch number : {2} / {3} ] | avg val loss {4} | val auc : {5} | lr : {6}'''.
                      format(
                          epoch + 1,
                          num_epochs,
                          i,
                          num_batches,
                          np.round(np.mean(losses), 4),
                          np.round(auc, 4),
                          current_lr
                      )
                      )

    if not losses:
        raise ValueError('val_loader yielded no batches')

    # NOTE(review): the step is offset by the last batch index; kept as-is so
    # existing TensorBoard curves stay comparable.
    writer.add_scalar('Val/AUC_epoch', auc, epoch + i)

    val_loss_epoch = np.round(np.mean(losses), 4)
    val_auc_epoch = np.round(auc, 4)
    return val_loss_epoch, val_auc_epoch
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Yields ``None`` when ``optimizer.param_groups`` is empty.
    """
    learning_rates = (group['lr'] for group in optimizer.param_groups)
    return next(learning_rates, None)
def _prepare_log_dir(task, plane, flush_history):
    """Create and return a fresh timestamped log directory for task/plane."""
    log_root_folder = "./logs/{0}/{1}/".format(task, plane)

    # Only flush when the folder exists; a first run has nothing to delete.
    if flush_history == 1 and os.path.isdir(log_root_folder):
        for entry in os.listdir(log_root_folder):
            entry_path = os.path.join(log_root_folder, entry)
            if os.path.isdir(entry_path):
                shutil.rmtree(entry_path)

    logdir = log_root_folder + datetime.now().strftime("%Y%m%d-%H%M%S") + "/"
    os.makedirs(logdir, exist_ok=True)
    return logdir


def _build_data_loaders(args):
    """Build the training and validation data loaders."""
    augmentor = Compose([
        transforms.Lambda(lambda x: torch.Tensor(x)),
        RandomRotate(25),
        RandomTranslate([0.11, 0.11]),
        RandomFlip(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
    ])

    train_kwargs = {'train': True}
    # Honour --augment: with 0 the transform is omitted, exactly as it is for
    # the validation dataset below. Missing attribute keeps the old behaviour.
    if getattr(args, 'augment', 1):
        train_kwargs['transform'] = augmentor

    train_dataset = MRDataset('./data/', args.task, args.plane, **train_kwargs)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True, num_workers=11, drop_last=False)

    validation_dataset = MRDataset(
        './data/', args.task, args.plane, train=False)
    # The original passed ``shuffle=-True`` (truthy); evaluation order does
    # not need to be random, so keep it deterministic.
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=1, shuffle=False, num_workers=11, drop_last=False)

    return train_loader, validation_loader


def _save_best_model(net, args, val_auc, train_auc, epoch, models_dir='./models/'):
    """Save ``net`` as the new best checkpoint and remove superseded ones."""
    os.makedirs(models_dir, exist_ok=True)

    stem = f'model_{args.prefix_name}_{args.task}_{args.plane}_'
    file_name = f'{stem}val_auc_{val_auc:0.4f}_train_auc_{train_auc:0.4f}_epoch_{epoch+1}.pth'

    # Save first so a failed write never leaves the run without a checkpoint.
    torch.save(net, os.path.join(models_dir, file_name))

    # Match on the file-name stem rather than on loose substrings so that
    # checkpoints belonging to other prefixes are not deleted by accident.
    for existing in os.listdir(models_dir):
        if existing != file_name and existing.startswith(stem):
            os.remove(os.path.join(models_dir, existing))


def run(args):
    """Train and validate an MRNet model according to ``args``.

    Sets up a timestamped log directory, the data loaders, the model, an
    Adam optimizer and the selected learning-rate scheduler, then alternates
    training and validation epochs. The model is saved whenever the
    validation AUC improves (if ``args.save_model`` is set) and training
    stops early after ``args.patience`` epochs without a lower validation
    loss.

    Args:
        args: namespace with the attributes produced by ``parse_arguments``
            (``task``, ``plane``, ``prefix_name``, ``augment``,
            ``lr_scheduler``, ``gamma``, ``epochs``, ``lr``,
            ``flush_history``, ``save_model``, ``patience``, ``log_every``).
    """
    logdir = _prepare_log_dir(args.task, args.plane, args.flush_history)
    writer = SummaryWriter(logdir)

    try:
        train_loader, validation_loader = _build_data_loaders(args)

        mrnet = model.MRNet()
        if torch.cuda.is_available():
            mrnet = mrnet.cuda()

        optimizer = optim.Adam(mrnet.parameters(), lr=args.lr, weight_decay=0.1)

        if args.lr_scheduler == "plateau":
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, patience=3, factor=.3, threshold=1e-4, verbose=True)
        elif args.lr_scheduler == "step":
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=3, gamma=args.gamma)

        best_val_loss = float('inf')
        best_val_auc = float(0)

        num_epochs = args.epochs
        # Number of consecutive epochs without a new best validation loss.
        iteration_change_loss = 0
        patience = args.patience
        log_every = args.log_every

        t_start_training = time.time()

        for epoch in range(num_epochs):
            current_lr = get_lr(optimizer)

            t_start = time.time()

            train_loss, train_auc = train_model(
                mrnet, train_loader, epoch, num_epochs, optimizer, writer, current_lr, log_every)
            val_loss, val_auc = evaluate_model(
                mrnet, validation_loader, epoch, num_epochs, writer, current_lr)

            # ReduceLROnPlateau needs the monitored metric; StepLR does not.
            if args.lr_scheduler == 'plateau':
                scheduler.step(val_loss)
            elif args.lr_scheduler == 'step':
                scheduler.step()

            t_end = time.time()
            delta = t_end - t_start

            print("train loss : {0} | train auc {1} | val loss {2} | val auc {3} | elapsed time {4} s".format(
                train_loss, train_auc, val_loss, val_auc, delta))

            iteration_change_loss += 1
            print('-' * 30)

            if val_auc > best_val_auc:
                best_val_auc = val_auc
                if bool(args.save_model):
                    _save_best_model(mrnet, args, val_auc, train_auc, epoch)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                iteration_change_loss = 0

            if iteration_change_loss == patience:
                print('Early stopping after {0} iterations without the decrease of the val loss'.
                      format(iteration_change_loss))
                break

        t_end_training = time.time()
        print(f'training took {t_end_training - t_start_training} s')
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
def parse_arguments(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            the arguments are read from ``sys.argv``, as before.

    Returns:
        ``argparse.Namespace`` holding ``task``, ``plane``, ``prefix_name``,
        ``augment``, ``lr_scheduler``, ``gamma``, ``epochs``, ``lr``,
        ``flush_history``, ``save_model``, ``patience`` and ``log_every``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--task', type=str, required=True,
                        choices=['abnormal', 'acl', 'meniscus'],
                        help='task passed to the dataset and used in log/model paths')
    parser.add_argument('-p', '--plane', type=str, required=True,
                        choices=['sagittal', 'coronal', 'axial'],
                        help='plane passed to the dataset and used in log/model paths')
    parser.add_argument('--prefix_name', type=str, required=True,
                        help='prefix embedded in the saved model file names')
    parser.add_argument('--augment', type=int, choices=[0, 1], default=1,
                        help='data augmentation flag (0 or 1)')
    parser.add_argument('--lr_scheduler', type=str,
                        default='plateau', choices=['plateau', 'step'],
                        help="learning-rate scheduler: ReduceLROnPlateau ('plateau') or StepLR ('step')")
    parser.add_argument('--gamma', type=float, default=0.5,
                        help="decay factor used by the 'step' scheduler")
    parser.add_argument('--epochs', type=int, default=50,
                        help='maximum number of training epochs')
    parser.add_argument('--lr', type=float, default=1e-5,
                        help='initial learning rate of the optimizer')
    parser.add_argument('--flush_history', type=int, choices=[0, 1], default=0,
                        help='1 to delete existing log sub-directories for this task/plane')
    parser.add_argument('--save_model', type=int, choices=[0, 1], default=1,
                        help='1 to save the model whenever the validation AUC improves')
    parser.add_argument('--patience', type=int, default=5,
                        help='epochs without a lower validation loss before early stopping')
    parser.add_argument('--log_every', type=int, default=100,
                        help='print training progress every N batches')
    return parser.parse_args(argv)
# Script entry point: parse the command line, then launch training.
if __name__ == '__main__':
    run(parse_arguments())