opensrh/train/train_contrastive.py

"""Contrastive learning experiment training script.
2
3
Copyright (c) 2022 University of Michigan. All rights reserved.
4
Licensed under the MIT License. See LICENSE for license information.
5
"""
6
7
import yaml
8
import logging
9
from functools import partial
10
from typing import Dict, Any
11
12
import torch
13
14
import pytorch_lightning as pl
15
import torchmetrics
16
17
from opensrh.models import MLP, resnet_backbone, ContrastiveLearningNetwork, vit_backbone
18
from opensrh.train.common import (setup_output_dirs, parse_args, get_exp_name,
19
                                  get_contrastive_dataloaders, config_loggers,
20
                                  get_optimizer_func, get_scheduler_func)
21
from opensrh.losses.supcon import SupConLoss
22
23
24
class ContrastiveSystem(pl.LightningModule):
    """Lightning system for contrastive learning experiments."""

    def __init__(self, cf: Dict[str, Any], num_it_per_ep: int):
        super().__init__()
        self.cf_ = cf

        # instantiate the backbone specified in the config
        if cf["model"]["backbone"] == "resnet50":
            bb = partial(resnet_backbone, arch=cf["model"]["backbone"])
        elif cf["model"]["backbone"] == "vit":
            bb = partial(vit_backbone, cf["model"]["backbone_params"])
        else:
            raise NotImplementedError()

        # projection head that maps backbone features to the contrastive
        # embedding space; bb() creates a throwaway backbone instance just
        # to query its output feature width
        mlp = partial(MLP,
                      n_in=bb().num_out,
                      hidden_layers=cf["model"]["mlp_hidden"],
                      n_out=cf["model"]["num_embedding_out"])
        self.model = ContrastiveLearningNetwork(bb, mlp)
        self.criterion = SupConLoss()
        self.train_loss = torchmetrics.MeanMetric()
        self.val_loss = torchmetrics.MeanMetric()

        self.num_it_per_ep_ = num_it_per_ep

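    # Note: predict_step calls the backbone only (self.model.bb), skipping
    # the projection head. Evaluating contrastive models on raw backbone
    # features rather than projection outputs is the usual protocol.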
    def predict_step(self, batch, batch_idx):
        out = self.model.bb(batch["image"])
        return {
            "path": batch["path"],
            "label": batch["label"],
            "embeddings": out
        }

    def on_train_epoch_end(self):
        # compute the batch-size-weighted mean loss accumulated over the
        # epoch, log it, then reset the metric for the next epoch
        train_loss = self.train_loss.compute()
        self.log("train/contrastive_manualepoch",
                 train_loss,
                 on_epoch=True,
                 sync_dist=False,
                 rank_zero_only=True)
        logging.info(f"train/contrastive_manualepoch {train_loss}")
        self.train_loss.reset()

    def on_validation_epoch_end(self):
        val_loss = self.val_loss.compute()
        self.log("val/contrastive_manualepoch",
                 val_loss,
                 on_epoch=True,
                 sync_dist=False,
                 rank_zero_only=True)
        logging.info(f"val/contrastive_manualepoch {val_loss}")
        self.val_loss.reset()

    def configure_optimizers(self):
        # if not training, there is no optimizer to configure
        if "training" not in self.cf_:
            return None

        # get optimizer
        opt = get_optimizer_func(self.cf_)(self.model.parameters())

        # check whether a learning rate scheduler is configured
        sched_func = get_scheduler_func(self.cf_, self.num_it_per_ep_)
        if not sched_func:
            return opt

        # build the learning rate scheduler config
        lr_scheduler_config = {
            "scheduler": sched_func(opt),
            "interval": "step",
            "frequency": 1,
            "name": "lr"
        }

        return [opt], [lr_scheduler_config]

    def configure_ddp(self, *args, **kwargs):
        # re-initialize logging so processes spawned by DDP log at INFO
        logging.basicConfig(level=logging.INFO)
        return super().configure_ddp(*args, **kwargs)


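# Both systems below share the same cross-process batch assembly: under DDP,
# `self.all_gather` returns a tensor with a leading world-size dimension, and
# `sync_grads=True` keeps the gather in the autograd graph. Folding the world
# dimension back into the batch dimension lets the contrastive loss draw
# negatives from every device, i.e. from the full global batch.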
class SimCLRSystem(ContrastiveSystem):
    """Lightning system for SimCLR experiments."""

    def __init__(self, cf, num_it_per_ep):
        super().__init__(cf, num_it_per_ep)

    def forward(self, data):
        # data["image"] is a list of augmented views; concatenate the
        # per-view embeddings along dim 1 -> [bs, n_views, emb_dim]
        return torch.cat([self.model(x) for x in data["image"]], dim=1)

    def training_step(self, batch, batch_idx):
        pred = torch.cat([self.model(x) for x in batch["image"]], dim=1)
        # gather embeddings from all processes and fold the world-size
        # dimension into the batch: [world_size * bs, n_views, emb_dim]
        pred_gather = self.all_gather(pred, sync_grads=True)
        pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:])

        loss = self.criterion(pred_gather)
        bs = batch["image"][0].shape[0]
        self.log("train/contrastive",
                 loss,
                 on_step=True,
                 on_epoch=True,
                 batch_size=bs)
        self.train_loss.update(loss, weight=bs)
        return loss

    def validation_step(self, batch, batch_idx):
        bs = batch["image"][0].shape[0]
        pred = torch.cat([self.model(x) for x in batch["image"]], dim=1)
        pred_gather = self.all_gather(pred, sync_grads=True)
        pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:])

        loss = self.criterion(pred_gather)
        self.val_loss.update(loss, weight=bs)


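# SupCon differs from SimCLR only in the loss call: gathered labels are
# passed to SupConLoss, so every sample sharing a class label counts as a
# positive, not just the augmented views of the same image.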
class SupConSystem(ContrastiveSystem):
    """Lightning system for SupCon experiments."""

    def __init__(self, cf, num_it_per_ep):
        super().__init__(cf, num_it_per_ep)

    def forward(self, data):
        return torch.cat([self.model(x) for x in data["image"]], dim=1)

    def training_step(self, batch, batch_idx):
        pred = torch.cat([self.model(x) for x in batch["image"]], dim=1)
        pred_gather = self.all_gather(pred, sync_grads=True)
        pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:])
        # labels need no gradients, so sync_grads is omitted here
        label_gather = self.all_gather(batch["label"]).reshape(-1, 1)

        loss = self.criterion(pred_gather, label_gather)
        bs = batch["image"][0].shape[0]
        self.log("train/contrastive",
                 loss,
                 on_step=True,
                 on_epoch=True,
                 batch_size=bs)
        self.train_loss.update(loss, weight=bs)
        return loss

    def validation_step(self, batch, batch_idx):
        bs = batch["image"][0].shape[0]
        pred = torch.cat([self.model(x) for x in batch["image"]], dim=1)
        pred_gather = self.all_gather(pred, sync_grads=True)
        pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:])
        label_gather = self.all_gather(batch["label"]).reshape(-1, 1)

        loss = self.criterion(pred_gather, label_gather)
        self.val_loss.update(loss, weight=bs)


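# A minimal sketch of the YAML config keys this script reads directly.
# Values are illustrative only; dataloader, optimizer, and scheduler keys
# are consumed inside opensrh.train.common and are omitted here:
#
#   infra:
#     seed: 1000
#   model:
#     backbone: resnet50            # or vit (requires backbone_params)
#     mlp_hidden: [2048]
#     num_embedding_out: 128
#   training:
#     objective: supcon             # or simclr
#     num_epochs: 100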
def main():
    cf_fd = parse_args()
    cf = yaml.load(cf_fd, Loader=yaml.FullLoader)
    exp_root, model_dir, cp_config = setup_output_dirs(cf, get_exp_name, "")
    pl.seed_everything(cf["infra"]["seed"])

    # logging and copying config files
    cp_config(cf_fd.name)
    config_loggers(exp_root)

    # get dataloaders
    train_loader, valid_loader = get_contrastive_dataloaders(cf)
    logging.info(f"num devices: {torch.cuda.device_count()}")
    logging.info(f"num workers in dataloader: {train_loader.num_workers}")

    # under DDP each process sees only its shard of the data, so the
    # per-process iteration count per epoch shrinks by the device count
    num_it_per_ep = len(train_loader)
    if torch.cuda.device_count() > 1:
        num_it_per_ep //= torch.cuda.device_count()

    # select the training objective
    if cf["training"]["objective"] == "supcon":
        system_func = SupConSystem
    elif cf["training"]["objective"] == "simclr":
        system_func = SimCLRSystem
    else:
        raise NotImplementedError()

    ce_exp = system_func(cf, num_it_per_ep)

    # config experiment loggers (TensorBoard + CSV)
    logger = [
        pl.loggers.TensorBoardLogger(save_dir=exp_root, name="tb"),
        pl.loggers.CSVLogger(save_dir=exp_root, name="csv")
    ]

    # config callbacks: keep a checkpoint at the end of every epoch
    epoch_ckpt = pl.callbacks.ModelCheckpoint(
        dirpath=model_dir,
        save_top_k=-1,
        save_on_train_epoch_end=True,
        filename="ckpt-epoch{epoch}-loss{val/contrastive_manualepoch:.2f}",
        auto_insert_metric_name=False)
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step",
                                                  log_momentum=False)

    # create trainer
    trainer = pl.Trainer(accelerator="gpu",
                         devices=-1,
                         default_root_dir=exp_root,
                         strategy=pl.strategies.DDPStrategy(
                             find_unused_parameters=False, static_graph=True),
                         logger=logger,
                         log_every_n_steps=10,
                         callbacks=[epoch_ckpt, lr_monitor],
                         max_epochs=cf["training"]["num_epochs"])
    trainer.fit(ce_exp,
                train_dataloaders=train_loader,
                val_dataloaders=valid_loader)


if __name__ == '__main__':
    main()
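
# A hedged usage sketch (not part of the training flow above): after
# training, predict_step can export embeddings for a dataloader. The
# checkpoint filename below is a placeholder:
#
#   system = SupConSystem(cf, num_it_per_ep=0)
#   state = torch.load("ckpt-epoch99-loss0.00.ckpt")["state_dict"]
#   system.load_state_dict(state)
#   preds = pl.Trainer(accelerator="gpu", devices=1).predict(
#       system, dataloaders=valid_loader)
#   # preds is a list of dicts with "path", "label", and "embeddings" keys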