Diff of /src/swa.py [000000] .. [95f789]

Switch to unified view

a b/src/swa.py
1
2
#!/usr/bin/env python
3
4
"""
5
Stochastic Weight Averaging (SWA)
6
Averaging Weights Leads to Wider Optima and Better Generalization
7
https://github.com/timgaripov/swa
8
"""
9
import torch
10
import models
11
from tqdm import tqdm
12
import glob
13
14
15
def moving_average(net1, net2, alpha=1.):
    """Blend the parameters of ``net2`` into ``net1`` in place.

    Each parameter of ``net1`` becomes
    ``(1 - alpha) * param1 + alpha * param2``. With the default
    ``alpha=1.`` the parameters of ``net1`` are overwritten by those of
    ``net2``. Only parameters are touched; buffers (e.g. BatchNorm running
    statistics) are left unchanged.

    :param net1: module whose parameters are updated in place.
    :param net2: module providing the parameters to blend in; not modified.
    :param alpha: weight given to the parameters of ``net2``.
    :return: None
    :raises ValueError: if the two modules expose a different number of
        parameters (``zip`` would otherwise silently ignore the extras).
    """
    params1 = list(net1.parameters())
    params2 = list(net2.parameters())
    if len(params1) != len(params2):
        raise ValueError(
            "cannot average networks with different parameter counts: "
            "%d vs %d" % (len(params1), len(params2))
        )

    # In-place updates under no_grad replace the discouraged ``.data``
    # access while still keeping autograd out of the weight arithmetic.
    with torch.no_grad():
        for param1, param2 in zip(params1, params2):
            param1.mul_(1.0 - alpha)
            param1.add_(param2 * alpha)
19
20
21
def _check_bn(module, flag):
    """Record in ``flag`` whether ``module`` is a BatchNorm layer.

    :param module: module to inspect.
    :param flag: one-element list used as an output slot; ``flag[0]`` is set
        to True when ``module`` is a BatchNorm layer and is left untouched
        otherwise.
    :return: None
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        flag[0] = True
24
25
26
def check_bn(model):
    """Tell whether ``model`` contains at least one BatchNorm layer.

    The model itself and all of its sub-modules are inspected.

    :param model: module to inspect.
    :return: True if any inspected module is a BatchNorm layer, else False.
    """
    batchnorm_base = torch.nn.modules.batchnorm._BatchNorm
    return any(isinstance(module, batchnorm_base) for module in model.modules())
30
31
32
def reset_bn(module):
    """Reset the running statistics of a BatchNorm layer.

    ``running_mean`` is set to zeros and ``running_var`` to ones. Modules
    that are not BatchNorm layers are ignored, as are BatchNorm layers that
    keep no running statistics (their buffers are ``None``).

    :param module: module to reset; intended for use with ``Module.apply``.
    :return: None
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        return
    # With track_running_stats=False the buffers are None and there is
    # nothing to reset; zeros_like(None) would raise a TypeError.
    if module.running_mean is not None:
        module.running_mean.zero_()
    if module.running_var is not None:
        module.running_var.fill_(1)
36
37
38
def _get_momenta(module, momenta):
    """Store the momentum of a BatchNorm layer in ``momenta``.

    :param module: module to inspect; non-BatchNorm modules are ignored.
    :param momenta: dict filled in place, mapping the BatchNorm module object
        to its current ``momentum`` value.
    :return: None
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum
41
42
43
def _set_momenta(module, momenta):
    """Restore the momentum of a BatchNorm layer from ``momenta``.

    :param module: module to update; non-BatchNorm modules are ignored.
    :param momenta: dict mapping BatchNorm module objects to the momentum
        value to assign. A BatchNorm module missing from the dict raises
        ``KeyError``.
    :return: None
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.momentum = momenta[module]
46
47
48
def bn_update(loader, model):
    """
        BatchNorm buffers update (if any).
        Performs 1 epoch over the loader to re-estimate the running
        statistics of every BatchNorm layer in the model.

        The running statistics are reset first; then, for each batch, the
        momentum of every BatchNorm layer is set to ``b / (n + b)`` (``b`` =
        size of the current batch, ``n`` = samples seen so far), so that the
        buffers end up holding a sample-weighted average over all batches.
        The original momenta are restored afterwards, even if iteration
        fails. The model is switched to train mode and left in it.

        :param loader: train dataset loader for buffers average estimation.
            Each batch must be a mapping with an ``'images'`` entry, and the
            loader must expose ``batch_size`` (used for the progress bar).
        :param model: model being updated.
        :return: None
    """
    if not check_bn(model):
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))

    # Feed batches on whatever device the model lives on instead of
    # unconditionally assuming CUDA.
    reference = next(model.parameters(), None)
    if reference is None:
        reference = next(model.buffers(), None)
    device = reference.device if reference is not None else None

    seen = 0
    pbar = tqdm(loader, unit="images", unit_scale=loader.batch_size)
    try:
        # Only the BatchNorm buffers matter here; gradients are never used,
        # so skip building autograd graphs.
        with torch.no_grad():
            for batch in pbar:
                images = batch['images']
                if device is not None:
                    images = images.to(device)
                b = images.size(0)

                # Weight of the current batch in the running average.
                momentum = b / (seen + b)
                for module in momenta:
                    module.momentum = momentum

                model(images)
                seen += b
    finally:
        model.apply(lambda module: _set_momenta(module, momenta))
78
79
80
if __name__ == '__main__':
    # Script entry point: average the weights of several checkpoints,
    # re-estimate the BatchNorm statistics of the averaged model on the
    # training data, and save the resulting state dict.
    import argparse
    from torch.utils.data import DataLoader
    from augmentation import valid_aug
    from dataset import SIIMDataset

    def _build_unet():
        # Single place defining the architecture shared by all checkpoints.
        return models.Unet(
            encoder_name="resnet34",
            activation='sigmoid',
            classes=1,
        )

    def _load_unet(path):
        # Checkpoints are read on the CPU; the averaged model is moved to
        # the GPU only once, right before the BatchNorm update.
        checkpoint = torch.load(path, map_location='cpu')
        network = _build_unet()
        network.load_state_dict(checkpoint['model_state_dict'])
        return network

    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # --input is concatenated into the glob patterns below, so it is required.
    parser.add_argument("--input", type=str, required=True, help='input directory')
    parser.add_argument("--output", type=str, default='swa_model.pth', help='output model file')
    parser.add_argument("--batch-size", type=int, default=16, help='batch size')
    args = parser.parse_args()

    # Sorted so that the averaging order is reproducible between runs.
    files = sorted(glob.glob(args.input + "/stage1/checkpoints/stage1.*.pth"))
    # NOTE(review): the stage2 directory is searched with the same
    # "stage1.*" file pattern as stage1 -- confirm this is intended and not
    # a typo for "stage2.*".
    files += sorted(glob.glob(args.input + "/stage2/checkpoints/stage1.*.pth"))
    if len(files) < 2:
        # Explicit error instead of an assert, which is stripped under -O.
        parser.error(
            "need at least two checkpoints to average, found %d under %r"
            % (len(files), args.input)
        )

    net = _load_unet(files[0])
    for i, f in enumerate(files[1:]):
        # Weight 1 / (i + 2) keeps ``net`` equal to the plain mean of all
        # checkpoints folded in so far.
        moving_average(net, _load_unet(f), 1. / (i + 2))

    test_csv = './csv/train_0.csv'
    root = "/raid/data/kaggle/siim/siim256/"
    train_transform = valid_aug()
    train_dataset = SIIMDataset(
        csv_file=test_csv,
        root=root,
        transform=train_transform,
        mode='train'
    )
    # Honour --batch-size (previously parsed but overridden by a constant 16).
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=args.batch_size,
        drop_last=True,
    )
    net.cuda()
    bn_update(train_dataloader, net)

    torch.save({
        'model_state_dict': net.state_dict()
    }, args.output)