b/HTNet/multi-modality/train.py

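# Multi-modality training script for HTNet: fine-tunes a modified ResNet-152
# that takes an image together with a six-value antibody-expression vector and
# predicts one of two classes. The training loop closely follows the
# torchvision classification reference script.
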
import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn
from torchvision import transforms

from resnet import resnet152
import utils


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)
    for image, antibody, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, antibody, target = image.to(device), antibody.to(device), target.to(device)
        output = model(image, antibody)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Top-1 and top-2 accuracy (with only two classes, top-2 is trivially
        # 100%; it is kept to mirror the reference script's logging format).
        acc1, acc2 = utils.accuracy(output, target, topk=(1, 2))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc2'].update(acc2.item(), n=batch_size)
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, device, print_freq=100):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'
    with torch.no_grad():
        for image, antibody, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            antibody = antibody.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image, antibody)
            loss = criterion(output, target)

            acc1, acc2 = utils.accuracy(output, target, topk=(1, 2))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc2'].update(acc2.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(' * Acc@1 {top1.global_avg:.3f} Acc@2 {top2.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top2=metric_logger.acc2))
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
    # Kept from the torchvision reference script; note that --cache-dataset is
    # parsed below, but dataset caching is never actually invoked in load_data.
    import hashlib
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def load_data(traindir, valdir, antibody_train, antibody_val, cache_dataset, distributed):
    # Data loading code
    print("Loading data")
    # Dataset-specific normalization statistics (these differ from the
    # usual ImageNet defaults).
    normalize = transforms.Normalize(mean=[0.168, 0.174, 0.182],
                                     std=[0.159, 0.160, 0.162])

    # Training-time augmentation for the antibody vector: randomly zero out
    # entries with probability 0.3 (nn.Dropout also rescales the survivors).
    expression_tfs = transforms.Compose([nn.Dropout(0.3)])

    st = time.time()

    dataset = utils.HTDataset(
        traindir, antibody_train,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=180),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.0),
            transforms.ToTensor(),
            normalize,
        ]), expression_tfs)

    # No expression transform at evaluation time: the antibody vector is used as-is.
    dataset_test = utils.HTDataset(
        valdir, antibody_val,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]), None)

    print("Took", time.time() - st)

    print("Creating data loaders")
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler

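# For reference, the dataset interface this script relies on, as a sketch only
# (utils.HTDataset is defined elsewhere in the repository and may differ):
#
#   class HTDataset(torch.utils.data.Dataset):
#       def __init__(self, image_file, antibody_file, image_tfs, expression_tfs):
#           ...
#       def __getitem__(self, idx):
#           # Returns an (image, antibody, target) triple; image_tfs is applied
#           # to the image and, when not None, expression_tfs to the antibody
#           # vector.
#           ...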
|
|
|
def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    train_dir = args.train_file
    val_dir = args.val_file
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
                                                                   args.antibodytrn, args.antibodyval,
                                                                   args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model")
    # NOTE: --num-classes is parsed but not used here; the two-class head is
    # hard-coded.
    model = resnet152(num_classes=2, antibody_nums=6)  # 6 antibodies

    # If an image-only checkpoint exists, load it non-strictly (the antibody
    # branch is new) and freeze everything except the classifier and the
    # newly introduced parameters, i.e. partial fine-tuning.
    image_checkpoint = "../hashimoto_thyroiditis/model_79.pth"
    flag = os.path.exists(image_checkpoint)

    if flag:
        checkpoint = torch.load(image_checkpoint, map_location='cpu')
        msg = model.load_state_dict(checkpoint['model'], strict=False)
        print(msg)

        print("Parameters to be updated:")
        parameters_to_be_updated = ['fc.weight', 'fc.bias'] + msg.missing_keys
        print(parameters_to_be_updated)

        for name, param in model.named_parameters():
            if name not in parameters_to_be_updated:
                param.requires_grad = False

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if flag:
        parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        assert len(parameters) == len(parameters_to_be_updated)
    else:
        parameters = model.parameters()

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(
        parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args}
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--train-file', help='training-set image file')
    parser.add_argument('--val-file', help='validation-set image file')
    parser.add_argument('--antibodytrn', help='training-set antibody file')
    parser.add_argument('--antibodyval', help='validation-set antibody file')
    parser.add_argument('--num-classes', help='number of classes for the objective task', type=int)

    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)
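
# Example single-GPU launch; the file paths below are illustrative, not taken
# from the repository:
#
#   python train.py --train-file train_images.txt --val-file val_images.txt \
#       --antibodytrn antibody_train.txt --antibodyval antibody_val.txt \
#       --output-dir checkpoints -b 32 --lr 0.01
#
# For the distributed code path, launch with torchrun (assuming
# utils.init_distributed_mode reads the usual env:// variables, as in the
# torchvision reference):
#
#   torchrun --nproc_per_node=4 train.py ... --sync-bn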