#!/usr/bin/env python

"""
Stochastic Weight Averaging (SWA)
Averaging Weights Leads to Wider Optima and Better Generalization
https://github.com/timgaripov/swa
"""
import torch
import models
from tqdm import tqdm
import glob


def _is_bn(module):
    """Return True when ``module`` is any kind of BatchNorm layer."""
    return isinstance(module, torch.nn.modules.batchnorm._BatchNorm)


def moving_average(net1, net2, alpha=1.):
    """
    Blend the parameters of ``net2`` into ``net1`` in place.
    Every parameter of ``net1`` becomes ``(1 - alpha) * net1 + alpha * net2``.
    Only parameters are touched; buffers (e.g. BatchNorm running statistics)
    are left as they are.
    :param net1: model that accumulates the average (modified in place).
    :param net2: model whose parameters are mixed in.
    :param alpha: weight given to ``net2``.
    :return: None
    """
    for param1, param2 in zip(net1.parameters(), net2.parameters()):
        param1.data *= (1.0 - alpha)
        param1.data += param2.data * alpha


def _check_bn(module, flag):
    """Set ``flag[0]`` to True when ``module`` is a BatchNorm layer."""
    if _is_bn(module):
        flag[0] = True


def check_bn(model):
    """
    Tell whether ``model`` contains at least one BatchNorm layer.
    :param model: model to inspect.
    :return: True if any sub-module (or the model itself) is a BatchNorm layer.
    """
    return any(_is_bn(module) for module in model.modules())


def reset_bn(module):
    """
    Reset the running statistics of a BatchNorm layer (mean 0, variance 1).
    Modules that are not BatchNorm layers are ignored, and so are BatchNorm
    layers that keep no running statistics (``track_running_stats=False``),
    whose buffers are None.
    :param module: module to reset; meant to be used with ``model.apply``.
    :return: None
    """
    if not _is_bn(module):
        return
    if module.running_mean is None or module.running_var is None:
        return
    module.running_mean = torch.zeros_like(module.running_mean)
    module.running_var = torch.ones_like(module.running_var)


def _get_momenta(module, momenta):
    """Record the momentum of a BatchNorm ``module`` in the ``momenta`` dict."""
    if _is_bn(module):
        momenta[module] = module.momentum


def _set_momenta(module, momenta):
    """Restore the momentum of a BatchNorm ``module`` from the ``momenta`` dict."""
    if _is_bn(module):
        module.momentum = momenta[module]


def _model_device(model):
    """Return the device of the model's first parameter or buffer (CPU if none)."""
    tensor = next(model.parameters(), None)
    if tensor is None:
        tensor = next(model.buffers(), None)
    return tensor.device if tensor is not None else torch.device("cpu")


def bn_update(loader, model):
    """
    BatchNorm buffers update (if any).
    Performs one epoch over the train dataset to estimate the buffers' averages.
    The model is switched to train mode and left in it. Batches are moved to
    the device the model lives on.
    :param loader: train dataset loader for buffers average estimation;
        every batch must be a mapping with an 'images' entry.
    :param model: model being updated
    :return: None
    """
    if not check_bn(model):
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))
    device = _model_device(model)
    n = 0

    pbar = tqdm(loader, unit="images", unit_scale=loader.batch_size)
    try:
        # Only the BatchNorm buffers are of interest, so autograd is disabled
        # to avoid building graphs that would never be used.
        with torch.no_grad():
            for batch in pbar:
                images = batch['images'].to(device)
                b = images.size(0)

                # With momentum b / (n + b) the running statistics become the
                # sample-weighted mean of the statistics of all batches so far.
                momentum = b / (n + b)
                for module in momenta:
                    module.momentum = momentum

                model(images)
                n += b
    finally:
        # Put the original momenta back even if the pass is interrupted.
        model.apply(lambda module: _set_momenta(module, momenta))


def _load_unet(checkpoint_path):
    """Build the resnet34 Unet and load the weights stored at ``checkpoint_path``."""
    net = models.Unet(
        encoder_name="resnet34",
        activation='sigmoid',
        classes=1,
    )
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location keeps the loaded tensors on the CPU so that averaging many
    # checkpoints does not occupy GPU memory.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    return net


def main():
    """Average the stage checkpoints, refresh BatchNorm statistics and save."""
    import argparse
    from torch.utils.data import DataLoader
    from augmentation import valid_aug
    from dataset import SIIMDataset

    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input", type=str, required=True, help='input directory')
    parser.add_argument("--output", type=str, default='swa_model.pth', help='output model file')
    parser.add_argument("--batch-size", type=int, default=16, help='batch size')
    parser.add_argument("--csv", type=str, default='./csv/train_0.csv', help='train csv file')
    parser.add_argument("--root", type=str, default="/raid/data/kaggle/siim/siim256/", help='dataset root directory')
    args = parser.parse_args()

    files = glob.glob(args.input + "/stage1/checkpoints/stage1.*.pth")
    # NOTE(review): the stage2 directory is searched with a "stage1.*" file
    # name pattern - confirm this is intended and not a typo for "stage2.*".
    files += glob.glob(args.input + "/stage2/checkpoints/stage1.*.pth")
    files.sort()  # glob order is arbitrary; sort for reproducible runs
    if len(files) < 2:
        parser.error("need at least 2 checkpoints to average, found %d under %s" % (len(files), args.input))

    net = _load_unet(files[0])
    for i, f in enumerate(files[1:]):
        # Running mean: after this step net holds the average of i + 2 models.
        moving_average(net, _load_unet(f), 1. / (i + 2))

    train_dataset = SIIMDataset(
        csv_file=args.csv,
        root=args.root,
        transform=valid_aug(),
        mode='train'
    )
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, drop_last=True)
    net.cuda()
    bn_update(train_dataloader, net)

    torch.save({
        'model_state_dict': net.state_dict()
    }, args.output)


if __name__ == '__main__':
    main()