opensrh/train/train_contrastive.py

"""Contrastive learning experiment training script.

Copyright (c) 2022 University of Michigan. All rights reserved.
Licensed under the MIT License. See LICENSE for license information.
"""

import yaml
import logging
from functools import partial
from typing import Dict, Any

import torch

import pytorch_lightning as pl
import torchmetrics

from opensrh.models import (MLP, resnet_backbone,
                            ContrastiveLearningNetwork, vit_backbone)
from opensrh.train.common import (setup_output_dirs, parse_args, get_exp_name,
                                  get_contrastive_dataloaders, config_loggers,
                                  get_optimizer_func, get_scheduler_func)
from opensrh.losses.supcon import SupConLoss


class ContrastiveSystem(pl.LightningModule):
    """Lightning system for contrastive learning experiments."""

    def __init__(self, cf: Dict[str, Any], num_it_per_ep: int):
        super().__init__()
        self.cf_ = cf

        # select the backbone architecture specified in the config
        if cf["model"]["backbone"] == "resnet50":
            bb = partial(resnet_backbone, arch=cf["model"]["backbone"])
        elif cf["model"]["backbone"] == "vit":
            bb = partial(vit_backbone, cf["model"]["backbone_params"])
        else:
            raise NotImplementedError()

        # MLP projection head mapping backbone features to the embedding space
        mlp = partial(MLP,
                      n_in=bb().num_out,
                      hidden_layers=cf["model"]["mlp_hidden"],
                      n_out=cf["model"]["num_embedding_out"])
        self.model = ContrastiveLearningNetwork(bb, mlp)
        self.criterion = SupConLoss()
        self.train_loss = torchmetrics.MeanMetric()
        self.val_loss = torchmetrics.MeanMetric()

        self.num_it_per_ep_ = num_it_per_ep

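    # Shape convention assumed by the subclasses below (inferred from the
    # torch.cat(..., dim=1) calls and the SupConLoss interface, not stated
    # in this file): self.model(x) returns one embedding per image with an
    # explicit view axis, so concatenating the augmented views along dim 1
    # yields a [batch_size, n_views, n_out] tensor for the loss.
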
    def predict_step(self, batch, batch_idx):
        # embed images with the backbone alone, skipping the projection head
        out = self.model.bb(batch["image"])
        return {
            "path": batch["path"],
            "label": batch["label"],
            "embeddings": out
        }

    def on_train_epoch_end(self):
        # compute, log, and reset the epoch-level training loss
        train_loss = self.train_loss.compute()
        self.log("train/contrastive_manualepoch",
                 train_loss,
                 on_epoch=True,
                 sync_dist=False,
                 rank_zero_only=True)
        logging.info(f"train/contrastive_manualepoch {train_loss}")
        self.train_loss.reset()

    def on_validation_epoch_end(self):
        # compute, log, and reset the epoch-level validation loss
        val_loss = self.val_loss.compute()
        self.log("val/contrastive_manualepoch",
                 val_loss,
                 on_epoch=True,
                 sync_dist=False,
                 rank_zero_only=True)
        logging.info(f"val/contrastive_manualepoch {val_loss}")
        self.val_loss.reset()

    def configure_optimizers(self):
        # if not training, no optimizer is needed
        if "training" not in self.cf_:
            return None

        # get optimizer
        opt = get_optimizer_func(self.cf_)(self.model.parameters())

        # check whether a learning rate scheduler is configured
        sched_func = get_scheduler_func(self.cf_, self.num_it_per_ep_)
        if not sched_func:
            return opt

        # assemble the learning rate scheduler config, stepped per iteration
        lr_scheduler_config = {
            "scheduler": sched_func(opt),
            "interval": "step",
            "frequency": 1,
            "name": "lr"
        }

        return [opt], lr_scheduler_config

    def configure_ddp(self, *args, **kwargs):
        # make sure INFO-level logging is configured on every DDP worker
        logging.basicConfig(level=logging.INFO)
        return super().configure_ddp(*args, **kwargs)


class SimCLRSystem(ContrastiveSystem):
    """Lightning system for SimCLR experiment."""

    def __init__(self, cf, num_it_per_ep):
        super().__init__(cf, num_it_per_ep)

    def forward(self, data):
        # embed each augmented view and concatenate along the view axis
        return torch.cat([self.model(x) for x in data["image"]], dim=1)

    def training_step(self, batch, batch_idx):
        # embed both augmented views of each image
        pred = torch.cat([self.model(x) for x in batch["image"]], dim=1)
        # gather predictions from every process, keeping gradients, and
        # flatten the device axis so the loss sees the global batch
        pred_gather = self.all_gather(pred, sync_grads=True)
        pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:])

        # without labels, SupConLoss reduces to the SimCLR objective
        loss = self.criterion(pred_gather)
        bs = batch["image"][0].shape[0]
        self.log("train/contrastive",
                 loss,
                 on_step=True,
                 on_epoch=True,
                 batch_size=bs)
        self.train_loss.update(loss, weight=bs)
        return loss

    def validation_step(self, batch, batch_idx):
        bs = batch["image"][0].shape[0]
        pred = torch.cat([self.model(x) for x in batch["image"]], dim=1)
        pred_gather = self.all_gather(pred, sync_grads=True)
        pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:])

        loss = self.criterion(pred_gather)
        self.val_loss.update(loss, weight=bs)


class SupConSystem(ContrastiveSystem):
    """Lightning system for SupCon experiment."""

    def __init__(self, cf, num_it_per_ep):
        super().__init__(cf, num_it_per_ep)

    def forward(self, data):
        return torch.cat([self.model(x) for x in data["image"]], dim=1)

    def training_step(self, batch, batch_idx):
        pred = torch.cat([self.model(x) for x in batch["image"]], dim=1)
        pred_gather = self.all_gather(pred, sync_grads=True)
        pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:])
        # labels need no gradients, so a plain all_gather suffices
        label_gather = self.all_gather(batch["label"]).reshape(-1, 1)

        # SupCon treats all views that share a label as positives
        loss = self.criterion(pred_gather, label_gather)
        bs = batch["image"][0].shape[0]
        self.log("train/contrastive",
                 loss,
                 on_step=True,
                 on_epoch=True,
                 batch_size=bs)
        self.train_loss.update(loss, weight=bs)
        return loss

    def validation_step(self, batch, batch_idx):
        bs = batch["image"][0].shape[0]
        pred = torch.cat([self.model(x) for x in batch["image"]], dim=1)
        pred_gather = self.all_gather(pred, sync_grads=True)
        pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:])
        label_gather = self.all_gather(batch["label"]).reshape(-1, 1)

        loss = self.criterion(pred_gather, label_gather)
        self.val_loss.update(loss, weight=bs)

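
# A minimal config sketch covering only the keys this script reads; the
# values are illustrative assumptions, not project defaults:
#
#   infra:
#     seed: 1000
#   model:
#     backbone: resnet50      # or "vit", which also needs backbone_params
#     mlp_hidden: [128]
#     num_embedding_out: 128
#   training:
#     objective: supcon       # or "simclr"
#     num_epochs: 100
#     # plus whatever optimizer/scheduler settings get_optimizer_func and
#     # get_scheduler_func expect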
def main():
    cf_fd = parse_args()
    cf = yaml.load(cf_fd, Loader=yaml.FullLoader)
    exp_root, model_dir, cp_config = setup_output_dirs(cf, get_exp_name, "")
    pl.seed_everything(cf["infra"]["seed"])

    # logging and copying config files
    cp_config(cf_fd.name)
    config_loggers(exp_root)

    # get dataloaders
    train_loader, valid_loader = get_contrastive_dataloaders(cf)
    logging.info(f"num devices: {torch.cuda.device_count()}")
    logging.info(f"num workers in dataloader: {train_loader.num_workers}")

    # under DDP, each process sees 1 / world_size of the batches per epoch
    num_it_per_ep = len(train_loader)
    if torch.cuda.device_count() > 1:
        num_it_per_ep //= torch.cuda.device_count()

    # select the training objective
    if cf["training"]["objective"] == "supcon":
        system_func = SupConSystem
    elif cf["training"]["objective"] == "simclr":
        system_func = SimCLRSystem
    else:
        raise NotImplementedError()

    ce_exp = system_func(cf, num_it_per_ep)

    # config loggers
    logger = [
        pl.loggers.TensorBoardLogger(save_dir=exp_root, name="tb"),
        pl.loggers.CSVLogger(save_dir=exp_root, name="csv")
    ]

    # config callbacks: checkpoint at the end of every epoch, keep all
    epoch_ckpt = pl.callbacks.ModelCheckpoint(
        dirpath=model_dir,
        save_top_k=-1,
        save_on_train_epoch_end=True,
        filename="ckpt-epoch{epoch}-loss{val/contrastive_manualepoch:.2f}",
        auto_insert_metric_name=False)
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step",
                                                  log_momentum=False)

    # create trainer: DDP across all available GPUs
    trainer = pl.Trainer(accelerator="gpu",
                         devices=-1,
                         default_root_dir=exp_root,
                         strategy=pl.strategies.DDPStrategy(
                             find_unused_parameters=False, static_graph=True),
                         logger=logger,
                         log_every_n_steps=10,
                         callbacks=[epoch_ckpt, lr_monitor],
                         max_epochs=cf["training"]["num_epochs"])
    trainer.fit(ce_exp,
                train_dataloaders=train_loader,
                val_dataloaders=valid_loader)


if __name__ == '__main__':
    main()
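

# Example launch (the exact CLI flag is defined by parse_args in
# opensrh.train.common; "-c" here is an assumption):
#   python -m opensrh.train.train_contrastive -c config/train_contrastive.yaml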