"""Training utilities for semantic segmentation.

Provides device helpers (:func:`cuda`, :func:`variable`), JSON-lines event
logging (:func:`write_event`), and two training loops: :func:`train` (plain)
and :func:`train_callbacks` (with EarlyStopping on the validation loss).
"""
import json
import random
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import tqdm
from torch.autograd import Variable

from callbacks import EarlyStopping


def variable(x, volatile=False):
    """Recursively wrap *x* (a tensor, or a list/tuple of tensors) and move
    it to the GPU when one is available.

    ``volatile`` is kept only for backward compatibility with old call
    sites; PyTorch removed the flag, so ``torch.no_grad()`` is used instead.
    """
    if isinstance(x, (list, tuple)):
        return [variable(y, volatile=volatile) for y in x]
    # `Variable(x, volatile=...)` was removed from PyTorch; wrapping under
    # no_grad() reproduces the old "don't track history" behavior.
    with torch.no_grad():
        return cuda(Variable(x))


def cuda(x):
    """Move *x* to the GPU when available, otherwise return it unchanged."""
    # `async=True` became a SyntaxError in Python 3.7+ (reserved word);
    # `non_blocking` is the replacement keyword.
    return x.cuda(non_blocking=True) if torch.cuda.is_available() else x


def write_event(log, step: int, **data):
    """Append one JSON line to *log* with `step`, an ISO timestamp, and any
    extra keyword fields, flushing immediately so the log survives crashes."""
    data['step'] = step
    data['dt'] = datetime.now().isoformat()
    log.write(json.dumps(data, sort_keys=True))
    log.write('\n')
    log.flush()


def train(args, model, criterion, train_loader, valid_loader, validation,
          init_optimizer, n_epochs=None, fold=None):
    """Basic training loop.

    Resumes from ``{args.root}/model_{fold}.pt`` when it exists, logs the
    running mean loss to ``train_{fold}.log``, saves a checkpoint and runs
    *validation* once per epoch. *validation* must return a dict containing
    the key ``'valid_loss'``. Ctrl+C saves a snapshot and returns cleanly.
    """
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    optimizer = init_optimizer(lr)

    root = Path(args.root)
    model_path = root / 'model_{fold}.pt'.format(fold=fold)
    # FIX: a checkpoint saved on GPU raises RuntimeError when loaded on a
    # CPU-only machine unless map_location targets an available device.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if model_path.exists():
        state = torch.load(str(model_path), map_location=device)
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {:,}'.format(epoch, step))
    else:
        epoch = 1
        step = 0

    def save(ep):
        # Checkpoint: weights plus resume bookkeeping.
        torch.save({
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
        }, str(model_path))

    report_each = 10
    log = root.joinpath('train_{fold}.log'.format(fold=fold)).open(
        'at', encoding='utf8')
    valid_losses = []
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        random.seed()
        # Progress bar sized in samples, not batches.
        tq = tqdm.tqdm(total=(len(train_loader) * args.batch_size))
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        losses = []
        try:
            mean_loss = 0
            for i, (inputs, targets) in enumerate(train_loader):
                inputs, targets = variable(inputs), variable(targets)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # Gradients accumulate by default; clear them every step.
                optimizer.zero_grad()
                batch_size = inputs.size(0)
                loss.backward()
                optimizer.step()
                step += 1
                tq.update(batch_size)
                # 0-dim tensors are not indexable; .item() extracts the float.
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss='{:.5f}'.format(mean_loss))
                if i and i % report_each == 0:
                    write_event(log, step, loss=mean_loss)
            write_event(log, step, loss=mean_loss)
            tq.close()
            save(epoch + 1)
            valid_metrics = validation(model, criterion, valid_loader)
            write_event(log, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)
        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save(epoch)
            print('done.')
            return


def train_callbacks(args, model, criterion, train_loader, valid_loader,
                    validation, init_optimizer, n_epochs=None, fold=None):
    """Training loop with early stopping.

    Same structure as :func:`train`, but per-epoch checkpointing is
    delegated to :class:`EarlyStopping` (patience=15), which is driven by
    the validation loss; training stops once the loss has not improved for
    `patience` consecutive epochs.
    """
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    optimizer = init_optimizer(lr)
    root = Path(args.root)
    model_path = root / 'model_{fold}.pt'.format(fold=fold)
    # FIX: map_location must match an available device (see train()).
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if model_path.exists():
        state = torch.load(str(model_path), map_location=device)
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {:,}'.format(epoch, step))
    else:
        epoch = 1
        step = 0

    def save(ep):
        # Snapshot used only on KeyboardInterrupt; EarlyStopping owns the
        # regular per-improvement checkpoints.
        torch.save({
            'model': model.state_dict(),
            'epoch': ep,
            'step': step,
        }, str(model_path))

    report_each = 100
    log = root.joinpath('train_{fold}.log'.format(fold=fold)).open(
        'at', encoding='utf8')
    valid_losses = []
    early_stopping = EarlyStopping(patience=15)
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        tq = tqdm.tqdm(total=(len(train_loader) * args.batch_size))
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        losses = []
        try:
            mean_loss = 0
            for i, (inputs, targets) in enumerate(train_loader):
                inputs, targets = variable(inputs), variable(targets)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                # Gradients accumulate by default; clear them every step.
                optimizer.zero_grad()
                batch_size = inputs.size(0)
                loss.backward()
                optimizer.step()
                step += 1
                tq.update(batch_size)
                # loss.item() is the mean loss for this batch as a float.
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss='{:.5f}'.format(mean_loss))
                if i and i % report_each == 0:
                    write_event(log, step, loss=mean_loss)
            write_event(log, step, loss=mean_loss)
            tq.close()
            valid_metrics = validation(model, criterion, valid_loader)
            write_event(log, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            # FIX: previously written with a `#)`-comment hack splitting the
            # call across lines; plain multi-line call, identical arguments.
            early_stopping(
                valid_loss, model, epoch + 1, step,
                root / 'model_{fold}_{ep}epoch.pt'.format(fold=fold, ep=epoch))
            if early_stopping.early_stop:
                # val_loss has not improved for `patience` epochs: stop.
                break
            valid_losses.append(valid_loss)
        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save(epoch)
            print('done.')
            return