|
a |
|
b/src/swa.py |
|
|
1 |
|
|
|
2 |
#!/usr/bin/env python |
|
|
3 |
|
|
|
4 |
""" |
|
|
5 |
Stochastic Weight Averaging (SWA) |
|
|
6 |
Averaging Weights Leads to Wider Optima and Better Generalization |
|
|
7 |
https://github.com/timgaripov/swa |
|
|
8 |
""" |
|
|
9 |
import torch |
|
|
10 |
import models |
|
|
11 |
from tqdm import tqdm |
|
|
12 |
import glob |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def moving_average(net1, net2, alpha=1.): |
|
|
16 |
for param1, param2 in zip(net1.parameters(), net2.parameters()): |
|
|
17 |
param1.data *= (1.0 - alpha) |
|
|
18 |
param1.data += param2.data * alpha |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
def _check_bn(module, flag): |
|
|
22 |
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): |
|
|
23 |
flag[0] = True |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def check_bn(model): |
|
|
27 |
flag = [False] |
|
|
28 |
model.apply(lambda module: _check_bn(module, flag)) |
|
|
29 |
return flag[0] |
|
|
30 |
|
|
|
31 |
|
|
|
32 |
def reset_bn(module): |
|
|
33 |
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): |
|
|
34 |
module.running_mean = torch.zeros_like(module.running_mean) |
|
|
35 |
module.running_var = torch.ones_like(module.running_var) |
|
|
36 |
|
|
|
37 |
|
|
|
38 |
def _get_momenta(module, momenta): |
|
|
39 |
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): |
|
|
40 |
momenta[module] = module.momentum |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
def _set_momenta(module, momenta): |
|
|
44 |
if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): |
|
|
45 |
module.momentum = momenta[module] |
|
|
46 |
|
|
|
47 |
|
|
|
48 |
def bn_update(loader, model): |
|
|
49 |
""" |
|
|
50 |
BatchNorm buffers update (if any). |
|
|
51 |
Performs 1 epochs to estimate buffers average using train dataset. |
|
|
52 |
:param loader: train dataset loader for buffers average estimation. |
|
|
53 |
:param model: model being update |
|
|
54 |
:return: None |
|
|
55 |
""" |
|
|
56 |
if not check_bn(model): |
|
|
57 |
return |
|
|
58 |
model.train() |
|
|
59 |
momenta = {} |
|
|
60 |
model.apply(reset_bn) |
|
|
61 |
model.apply(lambda module: _get_momenta(module, momenta)) |
|
|
62 |
n = 0 |
|
|
63 |
|
|
|
64 |
pbar = tqdm(loader, unit="images", unit_scale=loader.batch_size) |
|
|
65 |
for batch in pbar: |
|
|
66 |
input, targets = batch['images'], batch['targets'] |
|
|
67 |
input = input.cuda() |
|
|
68 |
b = input.size(0) |
|
|
69 |
|
|
|
70 |
momentum = b / (n + b) |
|
|
71 |
for module in momenta.keys(): |
|
|
72 |
module.momentum = momentum |
|
|
73 |
|
|
|
74 |
model(input) |
|
|
75 |
n += b |
|
|
76 |
|
|
|
77 |
model.apply(lambda module: _set_momenta(module, momenta)) |
|
|
78 |
|
|
|
79 |
|
|
|
80 |
if __name__ == '__main__': |
|
|
81 |
import argparse |
|
|
82 |
from pathlib import Path |
|
|
83 |
from torchvision.transforms import Compose |
|
|
84 |
from torch.utils.data import DataLoader |
|
|
85 |
from augmentation import valid_aug |
|
|
86 |
from dataset import SIIMDataset |
|
|
87 |
|
|
|
88 |
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
|
|
89 |
parser.add_argument("--input", type=str, help='input directory') |
|
|
90 |
parser.add_argument("--output", type=str, default='swa_model.pth', help='output model file') |
|
|
91 |
parser.add_argument("--batch-size", type=int, default=16, help='batch size') |
|
|
92 |
args = parser.parse_args() |
|
|
93 |
|
|
|
94 |
# directory = Path(args.input) |
|
|
95 |
# files = [f for f in directory.iterdir() if f.suffix == ".pth"] |
|
|
96 |
files = glob.glob(args.input + "/stage1/checkpoints/stage1.*.pth") |
|
|
97 |
files += glob.glob(args.input + "/stage2/checkpoints/stage1.*.pth") |
|
|
98 |
assert(len(files) > 1) |
|
|
99 |
|
|
|
100 |
net = models.Unet( |
|
|
101 |
encoder_name="resnet34", |
|
|
102 |
activation='sigmoid', |
|
|
103 |
classes=1, |
|
|
104 |
# center=True |
|
|
105 |
) |
|
|
106 |
checkpoint = torch.load(files[0]) |
|
|
107 |
net.load_state_dict(checkpoint['model_state_dict']) |
|
|
108 |
|
|
|
109 |
for i, f in enumerate(files[1:]): |
|
|
110 |
# net2 = model.load(f) |
|
|
111 |
net2 = models.Unet( |
|
|
112 |
encoder_name="resnet34", |
|
|
113 |
activation='sigmoid', |
|
|
114 |
classes=1, |
|
|
115 |
# center=True |
|
|
116 |
) |
|
|
117 |
checkpoint = torch.load(f) |
|
|
118 |
net2.load_state_dict(checkpoint['model_state_dict']) |
|
|
119 |
moving_average(net, net2, 1. / (i + 2)) |
|
|
120 |
|
|
|
121 |
test_csv = './csv/train_0.csv' |
|
|
122 |
root = "/raid/data/kaggle/siim/siim256/" |
|
|
123 |
# img_size = 128 |
|
|
124 |
batch_size = 16 |
|
|
125 |
train_transform = valid_aug() |
|
|
126 |
train_dataset = SIIMDataset( |
|
|
127 |
csv_file=test_csv, |
|
|
128 |
root=root, |
|
|
129 |
transform=train_transform, |
|
|
130 |
mode='train' |
|
|
131 |
) |
|
|
132 |
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, drop_last=True) |
|
|
133 |
net.cuda() |
|
|
134 |
bn_update(train_dataloader, net) |
|
|
135 |
|
|
|
136 |
# models.save(net, args.output) |
|
|
137 |
torch.save({ |
|
|
138 |
'model_state_dict': net.state_dict() |
|
|
139 |
}, args.output) |