b/CRCNet/train.py
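# NOTE: this script follows the structure of torchvision's reference
# classification training loop, adapted to read samples from CSV file lists
# (utils.CSVDataset) and to fine-tune a DenseNet-169 backbone.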
import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms

import utils

try:
    from apex import amp
except ImportError:
    amp = None

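# One training epoch: forward pass, loss, backward pass, optimizer step.
# When apex is enabled the loss is scaled through amp.scale_loss before
# backward; loss, lr, top-k accuracy, and throughput (img/s) are tracked
# with utils.MetricLogger.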
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)

        optimizer.zero_grad()
        if apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # topk=(1, 2) yields top-1 and top-2 accuracy, so name both accordingly
        acc1, acc2 = utils.accuracy(output, target, topk=(1, 2))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc2'].update(acc2.item(), n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

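# Evaluation runs the same loop under torch.no_grad(); in the distributed
# case per-process meters are merged via synchronize_between_processes()
# before the global averages are reported.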
def evaluate(model, criterion, data_loader, device, print_freq=100):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc2 = utils.accuracy(output, target, topk=(1, 2))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc2'].update(acc2.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(' * Acc@1 {top1.global_avg:.3f} Acc@2 {top2.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top2=metric_logger.acc2))
    return metric_logger.acc1.global_avg

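# Cache files are keyed by a SHA-1 hash of the dataset path, so each
# distinct train/val location gets its own serialized copy under
# ~/.torch/vision/datasets/imagefolder/.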
def _get_cache_path(filepath):
    import hashlib
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path

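# Builds the train/val datasets (optionally restoring them from the cache)
# plus the matching samplers. Caching serializes the transform pipeline
# together with the dataset, so stale cache files must be removed after any
# transform change.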
def load_data(traindir, valdir, cache_dataset, distributed):
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
    else:
        #dataset = torchvision.datasets.ImageFolder(
        #    traindir,
        #    transforms.Compose([
        #        transforms.RandomResizedCrop(224),
        #        transforms.RandomHorizontalFlip(),
        #        transforms.ToTensor(),
        #        normalize,
        #    ]))
        dataset = utils.CSVDataset(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomPerspective(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(degrees=180),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.0),
                transforms.ToTensor(),
                normalize,
            ]))
        if cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
    else:
        #dataset_test = torchvision.datasets.ImageFolder(
        #    valdir,
        #    transforms.Compose([
        #        transforms.Resize(256),
        #        transforms.CenterCrop(224),
        #        transforms.ToTensor(),
        #        normalize,
        #    ]))
        dataset_test = utils.CSVDataset(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        if cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler

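# End-to-end driver: initialize (optionally distributed) training, build the
# data loaders, set up model/criterion/optimizer/scheduler, then alternate
# train_one_epoch and evaluate, checkpointing after every epoch.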
def main(args):
    if args.apex and amp is None:
        raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                           "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    #train_dir = os.path.join(args.data_path, 'train')
    #val_dir = os.path.join(args.data_path, 'val')
    train_dir = args.train_file
    val_dir = args.val_file
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
                                                                   args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model")
    #model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
    model = torchvision.models.densenet169(pretrained=args.pretrained)
    # replace the classifier head to match the number of target classes
    num_ftrs = model.features.norm5.num_features
    model.classifier = nn.Linear(num_ftrs, args.num_classes)

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.focal_loss:
        import pytorch_toolbelt.losses
        criterion = pytorch_toolbelt.losses.FocalLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.apex_opt_level)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args}
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

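# Note: args.distributed and args.gpu are not defined by the parser below;
# they are expected to be filled in by utils.init_distributed_mode() from
# the environment (dist-url defaults to 'env://').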
def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    #parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
    parser.add_argument('--train-file', help='file listing the training set')
    parser.add_argument('--val-file', help='file listing the validation set')
    parser.add_argument('--num-classes', help='number of classes for the target task', type=int)
    parser.add_argument(
        "--focal-loss", help="Use focal loss instead of cross-entropy", action="store_true")

    #parser.add_argument('--model', default='resnet18', help='model')
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    # Mixed precision training parameters
    parser.add_argument('--apex', action='store_true',
                        help='Use apex for mixed precision training')
    parser.add_argument('--apex-opt-level', default='O1', type=str,
                        help='For apex mixed precision training. '
                             'O0 for FP32 training, O1 for mixed precision training. '
                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet')

    # distributed training parameters
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

    args = parser.parse_args()

    return args

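# Example invocation (hypothetical file names), single GPU:
#   python train.py --train-file train.csv --val-file val.csv \
#       --num-classes 2 --pretrained --output-dir ./checkpoints
# For multi-GPU training, launching one process per GPU (e.g. via
# torch.distributed.launch with --use_env) is assumed to populate the
# env:// variables that utils.init_distributed_mode() reads.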
if __name__ == "__main__":
    args = parse_args()
    main(args)