tasks/moco-train.py
""" Momentum Contrastive (MoCo) Learning
2
"""
3
import argparse
4
import json
5
import os
6
import sys
7
import time
8
import numpy as np
9
from tqdm import tqdm
10
11
import torch
12
import torch.nn as nn
13
import torchinfo
14
15
sys.path.append(os.getcwd())
16
import utilities.runUtils as rutl
17
import utilities.logUtils as lutl
18
from algorithms.moco import MoCo
19
from algorithms.loss.ssl_losses import NTXentLoss
20
from datacode.natural_image_data import Cifar100Dataset
21
from datacode.ultrasound_data import FetalUSFramesDataset
22
from datacode.augmentations import SimCLRTransform
23
24
print(f"Pytorch version: {torch.__version__}")
25
print(f"cuda version: {torch.version.cuda}")
26
device = 'cuda' if torch.cuda.is_available() else 'cpu'
27
print("Device Used:", device)
28
29
###============================= Configure and Setup ===========================

CFG = rutl.ObjDict(
    use_amp = True,  # automatic mixed precision

    datapath    = "/home/mothilal.asokan/Downloads/HC701/Project/US-Fetal-Video-Frames_V1-1/train-all-frames.hdf5",
    valdatapath = "/home/mothilal.asokan/Downloads/HC701/Project/US-Fetal-Video-Frames_V1-1/valid-all-frames.hdf5",
    skip_count  = 5,

    epochs      = 20,
    batch_size  = 288,
    workers     = 24,
    image_size  = 256,

    weight_decay = 1e-4,
    lr           = 0.03,

    featx_arch     = "resnet50",   # "resnet34/50/101"
    featx_pretrain = "IMGNET-1K",  # "IMGNET-1K" or None

    print_freq_step   = 10,     # steps
    ckpt_freq_epoch   = 5,      # epochs
    valid_freq_epoch  = 5,      # epochs
    disable_tqdm      = False,  # True --> disable the progress bar

    checkpoint_dir  = "hypotheses/-dummy/ssl-moco",
    resume_training = True,
)

## --------
parser = argparse.ArgumentParser(description='MoCo Training')
parser.add_argument('--load-json', type=str, metavar='JSON',
    help='Load settings from a JSON file; values in the JSON override the defaults defined in this script.')

args = parser.parse_args()

if args.load_json:
    with open(args.load_json, 'rt') as f:
        CFG.__dict__.update(json.load(f))
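# A minimal override file for --load-json might look like this (hypothetical values;
# any subset of the CFG keys above is valid):
#   {
#       "epochs": 50,
#       "batch_size": 128,
#       "checkpoint_dir": "hypotheses/exp1/ssl-moco"
#   }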

### ----------------------------------------------------------------------------
CFG.gLogPath = CFG.checkpoint_dir
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'

### ============================================================================

def getDataLoaders():

    transform_obj = SimCLRTransform(image_size=CFG.image_size)

    traindataset = FetalUSFramesDataset(hdf5_file=CFG.datapath,
                                transform=transform_obj,
                                load2ram=False, frame_skip=CFG.skip_count)

    trainloader = torch.utils.data.DataLoader(traindataset, shuffle=True,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True, drop_last=True)

    validdataset = FetalUSFramesDataset(hdf5_file=CFG.valdatapath,
                                transform=transform_obj,
                                load2ram=False, frame_skip=CFG.skip_count)

    validloader = torch.utils.data.DataLoader(validdataset, shuffle=False,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True, drop_last=True)

    lutl.LOG2DICTXT({"TRAIN DatasetClass": traindataset.get_info(),
                    "TransformsClass": str(transform_obj.get_composition()),
                    }, CFG.gLogPath + '/misc.txt')
    lutl.LOG2DICTXT({"VALID DatasetClass": validdataset.get_info(),
                    "TransformsClass": str(transform_obj.get_composition()),
                    }, CFG.gLogPath + '/misc.txt')

    return trainloader, validloader
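# Usage sketch (assuming the HDF5 paths in CFG exist): each batch is a pair of
# SimCLR-style augmented views of the same frames, unpacked in the training loop as
# (x_query, x_key), each of shape (batch_size, 3, image_size, image_size).
#   trainloader, validloader = getDataLoaders()
#   x_query, x_key = next(iter(trainloader))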

def getModelnOptimizer():
    model = MoCo(featx_arch=CFG.featx_arch,
                 pretrained=CFG.featx_pretrain).to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=CFG.lr,
                    weight_decay=CFG.weight_decay, momentum=0.9)

    model_info = torchinfo.summary(model, [(CFG.batch_size, 3, CFG.image_size, CFG.image_size)],
                                verbose=0)
    lutl.LOG2TXT(str(model_info), CFG.gLogPath + '/misc.txt', console=False)

    return model, optimizer


def update_momentum(model: nn.Module, model_ema: nn.Module, m: float):
    """Updates the parameters of `model_ema` with an exponential moving average of `model`.

    Momentum encoders are a crucial component of models such as MoCo or BYOL.

    Examples:
        >>> backbone = resnet18()
        >>> projection_head = MoCoProjectionHead()
        >>> backbone_momentum = copy.deepcopy(backbone)
        >>> projection_head_momentum = copy.deepcopy(projection_head)
        >>>
        >>> # update momentum encoders
        >>> update_momentum(backbone, backbone_momentum, m=0.999)
        >>> update_momentum(projection_head, projection_head_momentum, m=0.999)
    """
    # EMA update: param_ema <- m * param_ema + (1 - m) * param
    for param_ema, param in zip(model_ema.parameters(), model.parameters()):
        param_ema.data = param_ema.data * m + param.data * (1.0 - m)
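# Worked instance of the EMA rule above (toy numbers): with m = 0.9, a momentum-encoder
# parameter at 1.0 and an online parameter at 0.0 move to 0.9 * 1.0 + 0.1 * 0.0 = 0.9,
# so the momentum encoder trails the online encoder smoothly instead of copying it.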


def cosine_schedule(
    step: int, max_steps: int, start_value: float, end_value: float
) -> float:
    """
    Cosine decay schedule that moves from `start_value` to `end_value` over `max_steps` steps.

    Args:
        step:
            Current step number.
        max_steps:
            Total number of steps.
        start_value:
            Starting value.
        end_value:
            Target value.

    Returns:
        Cosine decay value for the current step.
    """
    if step < 0:
        raise ValueError("Current step number can't be negative")
    if max_steps < 1:
        raise ValueError("Total step number must be >= 1")
    if step > max_steps:
        # Note: we allow step == max_steps even though step starts at 0 and should end
        # at max_steps - 1. This is because PyTorch Lightning updates the LR scheduler
        # always for the next epoch, even after the last training epoch. This results in
        # PyTorch Lightning calling the scheduler with step == max_steps.
        raise ValueError(
            f"The current step cannot be larger than max_steps but found step {step} and max_steps {max_steps}."
        )

    if max_steps == 1:
        # Avoid division by zero
        decay = end_value
    elif step == max_steps:
        # Special case for PyTorch Lightning, which updates the LR scheduler even for the
        # epoch after the last training epoch.
        decay = end_value
    else:
        decay = (
            end_value
            - (end_value - start_value)
            * (np.cos(np.pi * step / (max_steps - 1)) + 1)
            / 2
        )
    return decay
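# Example values for the momentum schedule used in simple_main() below (exact by the
# formula above):
#   cosine_schedule(0,  20, 0.996, 1.0) -> 0.996   (first epoch)
#   cosine_schedule(19, 20, 0.996, 1.0) -> 1.0     (last epoch)
# i.e. the EMA momentum rises smoothly from 0.996 towards 1.0 over training.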


### ----------------------------------------------------------------------------

def simple_main():
    ### SETUP
    rutl.START_SEED()
    if device == 'cuda':
        torch.cuda.set_device(0)  # pin training to the default GPU
    torch.backends.cudnn.benchmark = True

    if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training):
        raise Exception("Checkpoint folder already exists and resume_training is not enabled; something's wrong!")
    if not os.path.exists(CFG.gWeightPath):
        os.makedirs(CFG.gWeightPath)

    with open(CFG.gLogPath + "/exp_config.json", 'a') as f:
        json.dump(vars(CFG), f, indent=4)

    ### DATA ACCESS
    trainloader, validloader = getDataLoaders()

    ### MODEL, OPTIM
    model, optimizer = getModelnOptimizer()

    criterion = NTXentLoss(memory_bank_size=4096)
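    # The NT-Xent / InfoNCE objective computed by this criterion (sketch; temperature and
    # memory-bank handling are whatever NTXentLoss implements):
    #   L = -log( exp(q·k+ / t) / ( exp(q·k+ / t) + sum_i exp(q·k_i / t) ) )
    # where q is the online-encoder query, k+ its momentum-encoder key, and k_i are
    # negatives drawn from the 4096-entry memory bank.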

    # cosine LR decay; the scheduler is stepped once per epoch, so T_max is in epochs
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                    T_max=CFG.epochs, eta_min=0, last_epoch=-1)

    ## Automatically resume from checkpoint if it exists and resuming is enabled
    ckpt = None
    if CFG.resume_training:
        try:
            ckpt = torch.load(CFG.gWeightPath + '/checkpoint-1.pth', map_location='cpu')
        except Exception:
            try:
                ckpt = torch.load(CFG.gWeightPath + '/checkpoint-0.pth', map_location='cpu')
            except Exception:
                print("Checkpoints are not loadable. Starting fresh...")
    if ckpt:
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}", CFG.gLogPath + '/misc.txt')
    else:
        start_epoch = 0

    ### MODEL TRAINING
    start_time = time.time()
    best_loss = float('inf')
    wgt_suf   = 0  # alternate between two checkpoint files so a crash mid-save cannot destroy the only copy
    if CFG.use_amp:
        scaler = torch.cuda.amp.GradScaler()  # gradient scaler for mixed precision

    for epoch in range(start_epoch, CFG.epochs):

        ## ---- Training Routine ----
        t_running_loss_ = 0
        # EMA momentum for the key encoder, annealed from 0.996 towards 1.0 over the epochs
        momentum_val = cosine_schedule(epoch, CFG.epochs, 0.996, 1)

        model.train()
        for step, (x_query, x_key) in tqdm(enumerate(trainloader,
                                    start=epoch * len(trainloader)),
                                    total=len(trainloader),
                                    disable=CFG.disable_tqdm):

            # keep the momentum (key) encoder an EMA copy of the online (query) encoder
            update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
            update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val)
            x_query = x_query.to(device, non_blocking=True)
            x_key = x_key.to(device, non_blocking=True)
            optimizer.zero_grad()

            if CFG.use_amp:  ## with mixed precision
                with torch.cuda.amp.autocast():
                    query = model(x_query)               # query from the online encoder
                    key = model.forward_momentum(x_key)  # key from the momentum encoder
                    loss = criterion(query, key)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                query = model(x_query)
                key = model.forward_momentum(x_key)
                loss = criterion(query, key)
                loss.backward()
                optimizer.step()
            t_running_loss_ += loss.item()

            if step % CFG.print_freq_step == 0:
                stats = dict(epoch=epoch, step=step,
                             lr_weights=optimizer.param_groups[0]['lr'],
                             step_loss=loss.item(),
                             time=int(time.time() - start_time))
                lutl.LOG2DICTXT(stats, CFG.checkpoint_dir + '/train-stats.txt')
        train_epoch_loss = t_running_loss_ / len(trainloader)

        scheduler.step()

        # save checkpoint, alternating between checkpoint-0.pth and checkpoint-1.pth
        if (epoch+1) % CFG.ckpt_freq_epoch == 0:
            wgt_suf = (wgt_suf + 1) % 2
            state = dict(epoch=epoch, model=model.state_dict(),
                            optimizer=optimizer.state_dict())
            torch.save(state, CFG.gWeightPath + f'/checkpoint-{wgt_suf}.pth')

        ## ---- Validation Routine ----
        if (epoch+1) % CFG.valid_freq_epoch == 0:
            model.eval()
            v_running_loss_ = 0
            with torch.no_grad():
                for (x_query, x_key) in tqdm(validloader, total=len(validloader),
                                    disable=CFG.disable_tqdm):
                    # the momentum encoder is kept in sync during validation as well
                    update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
                    update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val)
                    x_query = x_query.to(device, non_blocking=True)
                    x_key = x_key.to(device, non_blocking=True)
                    query = model(x_query)
                    key = model.forward_momentum(x_key)
                    loss = criterion(query, key)
                    v_running_loss_ += loss.item()
            valid_epoch_loss = v_running_loss_ / len(validloader)

            # track the best validation loss seen so far
            best_flag = False
            if valid_epoch_loss < best_loss:
                best_flag = True
                best_loss = valid_epoch_loss

            v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf,
                            train_loss=train_epoch_loss,
                            valid_loss=valid_epoch_loss)
            lutl.LOG2DICTXT(v_stats, CFG.gLogPath + '/valid-stats.txt')

if __name__ == '__main__':
    simple_main()