# tasks/vic-train.py
"""VICReg self-supervised training."""

import argparse
import json
import os
import sys
import time

import torch
import torchinfo
from tqdm import tqdm

sys.path.append(os.getcwd())
import utilities.runUtils as rutl
import utilities.logUtils as lutl
from algorithms.vicreg import VICReg, LARS, adjust_learning_rate
from datacode.ultrasound_data import FetalUSFramesDataset
from datacode.augmentations import BarlowTwinsTransformOrig

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device used:", device)


###============================= Configure and Setup ===========================

CFG = rutl.ObjDict(
    use_amp = True,  # automatic mixed precision

    datapath    = "/home/USR/WERK/data/a.hdf5",
    valdatapath = "/home/USR/WERK/valdata/b.hdf5",
    skip_count  = 5,

    epochs      = 1000,
    batch_size  = 2048,
    workers     = 24,
    image_size  = 256,

    base_lr      = 0.2,
    weight_decay = 1e-6,
    sim_coeff    = 25.0,  # invariance term weight
    std_coeff    = 25.0,  # variance term weight
    cov_coeff    = 1.0,   # covariance term weight

    featx_arch     = "resnet50",
    featx_pretrain = None,  # or "IMGNET-1K"
    projector      = [8192, 8192, 8192],

    print_freq_step  = 1000,  # steps
    ckpt_freq_epoch  = 5,     # epochs
    valid_freq_epoch = 5,     # epochs
    disable_tqdm     = False, # True to disable progress bars

    checkpoint_dir  = "hypotheses/-dummy/ssl-vicreg/",
    resume_training = True,
)
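
# For reference, a minimal sketch of the objective that the three coefficients above
# weight, assuming the upstream VICReg module follows Bardes et al. (2022); the actual
# loss used in training lives in algorithms/vicreg.py and may differ in detail.
def _vicreg_loss_sketch(z1, z2, sim_c=25.0, std_c=25.0, cov_c=1.0, eps=1e-4):
    import torch.nn.functional as F
    n, d = z1.shape
    inv = F.mse_loss(z1, z2)                                   # invariance: views agree
    z1c, z2c = z1 - z1.mean(0), z2 - z2.mean(0)
    std = sum(torch.relu(1 - torch.sqrt(z.var(0) + eps)).mean()          # variance: keep each
              for z in (z1c, z2c))                                       # dim's std near 1
    cov = sum(((z.T @ z / (n - 1)).fill_diagonal_(0) ** 2).sum() / d     # covariance: decorrelate
              for z in (z1c, z2c))                                       # embedding dims
    return sim_c * inv + std_c * std + cov_c * cov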

## --------
parser = argparse.ArgumentParser(description='VICReg self-supervised training')
parser.add_argument('--load-json', type=str, metavar='JSON',
    help='Load settings from a JSON file. Command-line options override values in the file.')

args = parser.parse_args()

if args.load_json:
    with open(args.load_json, 'rt') as f:
        CFG.__dict__.update(json.load(f))

### ----------------------------------------------------------------------------
CFG.gLogPath = CFG.checkpoint_dir
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'

### ============================================================================


def getDataLoaders():

    transform_obj = BarlowTwinsTransformOrig(image_size=CFG.image_size)

    traindataset = FetalUSFramesDataset(hdf5_file=CFG.datapath,
                                        transform=transform_obj,
                                        load2ram=False, frame_skip=CFG.skip_count)

    trainloader = torch.utils.data.DataLoader(traindataset, shuffle=True,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True)

    validdataset = FetalUSFramesDataset(hdf5_file=CFG.valdatapath,
                                        transform=transform_obj,
                                        load2ram=False, frame_skip=CFG.skip_count)

    validloader = torch.utils.data.DataLoader(validdataset, shuffle=False,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True)

    lutl.LOG2DICTXT({"TRAIN DatasetClass": traindataset.get_info(),
                     "TransformsClass": str(transform_obj.get_composition()),
                     }, CFG.gLogPath + '/misc.txt')
    lutl.LOG2DICTXT({"VALID DatasetClass": validdataset.get_info(),
                     "TransformsClass": str(transform_obj.get_composition()),
                     }, CFG.gLogPath + '/misc.txt')

    return trainloader, validloader
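
# Each batch from these loaders is a pair (y1, y2) of differently augmented views of
# the same frame, as produced by the two-view transform. A hypothetical smoke test,
# assuming 3-channel views at the configured resolution:
#
#     y1, y2 = next(iter(trainloader))
#     assert y1.shape == y2.shape == (CFG.batch_size, 3, CFG.image_size, CFG.image_size)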


def getModelnOptimizer():
    model = VICReg(featx_arch=CFG.featx_arch,
                   projector_sizes=CFG.projector,
                   batch_size=CFG.batch_size,
                   sim_coeff=CFG.sim_coeff,
                   std_coeff=CFG.std_coeff,
                   cov_coeff=CFG.cov_coeff,
                   featx_pretrain=CFG.featx_pretrain,
                   ).to(device)

    optimizer = LARS(model.parameters(), lr=0, weight_decay=CFG.weight_decay,
                     weight_decay_filter=True, lars_adaptation_filter=True)

    model_info = torchinfo.summary(model, 2*[(1, 3, CFG.image_size, CFG.image_size)],
                                   verbose=0)
    lutl.LOG2TXT(str(model_info), CFG.gLogPath + '/misc.txt', console=False)

    return model, optimizer
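
# Note: the optimizer is constructed with lr=0 because adjust_learning_rate() assigns
# the actual rate every step. In the reference VICReg recipe that schedule is a linear
# warmup followed by cosine decay, scaled roughly as base_lr * batch_size / 256; the
# exact behaviour here is whatever algorithms/vicreg.py implements.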


### ----------------------------------------------------------------------------

def simple_main():
    ### SETUP
    rutl.START_SEED()
    if device == 'cuda': torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = True

    if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training):
        raise Exception("Checkpoint folder already exists and resume_training is not enabled; something is wrong!")
    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)

    with open(CFG.gLogPath + "/exp_config.json", 'a') as f:
        json.dump(vars(CFG), f, indent=4)
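
    # Note: exp_config.json is opened in append mode, so every (re)start logs its full
    # config; after a resume the file is a config history rather than a single JSON object.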

    ### DATA ACCESS
    trainloader, validloader = getDataLoaders()

    ### MODEL, OPTIM
    model, optimizer = getModelnOptimizer()

    ## Automatically resume from the newest checkpoint, if enabled and available
    ckpt = None
    if CFG.resume_training:
        # the two checkpoint files alternate (see wgt_suf below), so try the most
        # recently written one
        paths = [p for p in (CFG.gWeightPath + f'/checkpoint-{i}.pth' for i in (0, 1))
                 if os.path.exists(p)]
        if paths:
            try: ckpt = torch.load(max(paths, key=os.path.getmtime), map_location='cpu')
            except Exception: print("Checkpoints are not loadable. Starting fresh...")
    if ckpt:
        start_epoch = ckpt['epoch'] + 1  # the stored epoch has already been completed
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}", CFG.gLogPath + '/misc.txt')
    else:
        start_epoch = 0

    ### MODEL TRAINING
    start_time = time.time()
    best_loss = float('inf')
    wgt_suf   = 0  # alternates 0/1 so a crash mid-save cannot corrupt the only checkpoint
    if CFG.use_amp: scaler = torch.cuda.amp.GradScaler()  # loss scaling for mixed precision
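
    # Under AMP, forward passes run in float16 where it is safe, while GradScaler
    # multiplies the loss before backward() and unscales gradients inside step(),
    # preventing small gradients from flushing to zero in half precision.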

    for epoch in range(start_epoch, CFG.epochs):

        ## ---- Training Routine ----
        t_running_loss_ = 0
        model.train()
        for step, (y1, y2) in tqdm(enumerate(trainloader,
                                             start=epoch * len(trainloader)),
                                   total=len(trainloader),
                                   disable=CFG.disable_tqdm):
            y1 = y1.to(device, non_blocking=True)
            y2 = y2.to(device, non_blocking=True)
            lr_ = adjust_learning_rate(CFG, optimizer, trainloader, step)
            optimizer.zero_grad()

            if CFG.use_amp:  ## with mixed precision
                with torch.cuda.amp.autocast():
                    loss = model(y1, y2)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = model(y1, y2)
                loss.backward()
                optimizer.step()
            t_running_loss_ += loss.item()

            if step % CFG.print_freq_step == 0:
                stats = dict(epoch=epoch, step=step,
                             time=int(time.time() - start_time),
                             step_loss=loss.item(),
                             lr=lr_)
                lutl.LOG2DICTXT(stats, CFG.checkpoint_dir + '/train-stats.txt')
        train_epoch_loss = t_running_loss_ / len(trainloader)
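        # `step` is a global index (epoch * len(trainloader) + batch), so both the LR
        # schedule and the print frequency run on a run-wide clock rather than per epoch.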

        ## ---- Save Checkpoint ----
        if (epoch+1) % CFG.ckpt_freq_epoch == 0:
            wgt_suf = (wgt_suf + 1) % 2  # flip between checkpoint-0 and checkpoint-1
            state = dict(epoch=epoch, model=model.state_dict(),
                         optimizer=optimizer.state_dict())
            torch.save(state, CFG.gWeightPath + f'/checkpoint-{wgt_suf}.pth')

        ## ---- Validation Routine ----
        if (epoch+1) % CFG.valid_freq_epoch == 0:
            model.eval()
            v_running_loss_ = 0
            with torch.no_grad():
                for (y1, y2) in tqdm(validloader, total=len(validloader),
                                     disable=CFG.disable_tqdm):
                    y1 = y1.to(device, non_blocking=True)
                    y2 = y2.to(device, non_blocking=True)
                    loss = model(y1, y2)
                    v_running_loss_ += loss.item()
            valid_epoch_loss = v_running_loss_ / len(validloader)
            best_flag = False
            if valid_epoch_loss < best_loss:
                best_flag = True
                best_loss = valid_epoch_loss

            v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf,
                           train_loss=train_epoch_loss,
                           valid_loss=valid_epoch_loss)
            lutl.LOG2DICTXT(v_stats, CFG.gLogPath + '/valid-stats.txt')


if __name__ == '__main__':
    simple_main()
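
# Example launch (hypothetical paths; any key in CFG can be overridden from the file):
#
#     python tasks/vic-train.py --load-json configs/vicreg-ultrasound.json
#
# where configs/vicreg-ultrasound.json might contain, e.g.:
#     {"datapath": "/path/to/train.hdf5", "epochs": 100, "batch_size": 512}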