# tasks/bt-train.py
""" Barlow Twins self-supervised training
"""
import argparse
import json
import os
import sys
import time

from tqdm import tqdm

import torch
import torchinfo

sys.path.append(os.getcwd())
import utilities.runUtils as rutl
import utilities.logUtils as lutl
from algorithms.barlowtwins import BarlowTwins, LARS, adjust_learning_rate
from datacode.ultrasound_data import FetalUSFramesDataset
from datacode.augmentations import BarlowTwinsTransformOrig

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device used:", device)

###============================= Configure and Setup ===========================

CFG = rutl.ObjDict(
    use_amp = True,  # automatic mixed precision

    datapath    = "/home/USR/WERK/data/a.hdf5",
    valdatapath = "/home/USR/WERK/valdata/b.hdf5",
    skip_count  = 5,

    epochs      = 1000,
    batch_size  = 2048,
    workers     = 24,
    image_size  = 256,

    learning_rate_weights = 0.2,
    learning_rate_biases  = 0.0048,
    weight_decay = 1e-6,
    lmbd         = 0.0051,

    featx_arch     = "resnet50",  # "resnet34/50/101"
    featx_pretrain = None,        # "IMGNET-1K" or None
    projector      = [8192, 8192, 8192],

    print_freq_step  = 10,  # steps
    ckpt_freq_epoch  = 5,   # epochs
    valid_freq_epoch = 5,   # epochs
    disable_tqdm     = False,  # True to disable the progress bar

    checkpoint_dir  = "hypotheses/-dummy/ssl-barlow/",
    resume_training = False,
)
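# NOTE: learning_rate_weights=0.2, learning_rate_biases=0.0048, lmbd=0.0051,
# weight_decay=1e-6 and the 3x8192 projector match the defaults of the original
# Barlow Twins reference implementation (facebookresearch/barlowtwins).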

## --------
parser = argparse.ArgumentParser(description='Barlow Twins Training')
parser.add_argument('--load-json', type=str, metavar='JSON',
    help='Load settings from a JSON file; command-line options override values in the file.')

args = parser.parse_args()

if args.load_json:
    with open(args.load_json, 'rt') as f:
        CFG.__dict__.update(json.load(f))
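
# A minimal override file (illustrative values) passed as `--load-json cfg.json`;
# any key present replaces the corresponding CFG default, e.g.:
#   {"epochs": 300, "batch_size": 512, "checkpoint_dir": "hypotheses/run1/ssl-barlow/"}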

### ----------------------------------------------------------------------------
CFG.gLogPath = CFG.checkpoint_dir
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'

### ============================================================================

def getDataLoaders():
    """ Dataloaders over unlabelled frames for SSL training and validation.
    """
    transform_obj = BarlowTwinsTransformOrig(image_size=CFG.image_size)

    traindataset = FetalUSFramesDataset(hdf5_file=CFG.datapath,
                                transform=transform_obj,
                                load2ram=False, frame_skip=CFG.skip_count)
    trainloader  = torch.utils.data.DataLoader(traindataset, shuffle=True,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True)

    validdataset = FetalUSFramesDataset(hdf5_file=CFG.valdatapath,
                                transform=transform_obj,
                                load2ram=False, frame_skip=CFG.skip_count)
    validloader  = torch.utils.data.DataLoader(validdataset, shuffle=False,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True)

    lutl.LOG2DICTXT({"TRAIN DatasetClass": traindataset.get_info(),
                     "TransformsClass": str(transform_obj.get_composition()),
                    }, CFG.gLogPath + '/misc.txt')
    lutl.LOG2DICTXT({"VALID DatasetClass": validdataset.get_info(),
                     "TransformsClass": str(transform_obj.get_composition()),
                    }, CFG.gLogPath + '/misc.txt')

    return trainloader, validloader
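
# Note: each batch from these loaders is a pair (y1, y2) of differently augmented
# views of the same frames, as produced by BarlowTwinsTransformOrig's two-view transform.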


def getModelnOptimizer():
    model = BarlowTwins(featx_arch=CFG.featx_arch,
                        projector_sizes=CFG.projector,
                        batch_size=CFG.batch_size,
                        lmbd=CFG.lmbd,
                        pretrained=CFG.featx_pretrain).to(device)

    # lr=0 is a placeholder; adjust_learning_rate() sets both parameter groups every step
    optimizer = LARS(model.parameters(), lr=0, weight_decay=CFG.weight_decay,
                     weight_decay_filter=True, lars_adaptation_filter=True)

    model_info = torchinfo.summary(model, 2*[(1, 3, CFG.image_size, CFG.image_size)],
                                   verbose=0)
    lutl.LOG2TXT(str(model_info), CFG.gLogPath + '/misc.txt', console=False)

    return model, optimizer  # model is already on `device`
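
# For reference, model(y1, y2) computes the Barlow Twins objective
#     L = sum_i (1 - C_ii)^2  +  lmbd * sum_{i != j} C_ij^2
# where C is the cross-correlation matrix between the batch-normalised projector
# outputs of the two views (Zbontar et al., 2021).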

### ----------------------------------------------------------------------------

def simple_main():
    ### SETUP
    rutl.START_SEED()
    torch.backends.cudnn.benchmark = True

    if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training):
        raise Exception("Checkpoint folder already exists and resume_training is not enabled; something's wrong!")
    if not os.path.exists(CFG.gWeightPath):
        os.makedirs(CFG.gWeightPath)

    with open(CFG.gLogPath + "/exp_config.json", 'a') as f:
        json.dump(vars(CFG), f, indent=4)

    ### DATA ACCESS
    trainloader, validloader = getDataLoaders()

    ### MODEL, OPTIM
    model, optimizer = getModelnOptimizer()

    ## Automatically resume from a checkpoint if resuming is enabled and one exists
    ckpt = None
    if CFG.resume_training:
        # two slots alternate on disk; keep whichever holds the later epoch
        for suf in (0, 1):
            try:
                cand = torch.load(CFG.gWeightPath + f'/checkpoint-{suf}.pth',
                                  map_location='cpu')
                if ckpt is None or cand['epoch'] > ckpt['epoch']:
                    ckpt = cand
            except (FileNotFoundError, RuntimeError):
                pass
        if ckpt is None:
            print("Checkpoints are not loadable. Starting fresh...")
    if ckpt:
        start_epoch = ckpt['epoch'] + 1  # the stored epoch finished training before the save
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        lutl.LOG2TXT(f"Restarting training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}",
                     CFG.gLogPath + '/misc.txt')
    else:
        start_epoch = 0

    ### MODEL TRAINING
    start_time = time.time()
    best_loss = float('inf')
    wgt_suf   = 0  # alternating checkpoint slot index, so a crash mid-save never clobbers the only copy
    if CFG.use_amp:
        scaler = torch.cuda.amp.GradScaler()  # loss scaling for mixed precision

    for epoch in range(start_epoch, CFG.epochs):

        ## ---- Training Routine ----
        t_running_loss_ = 0
        model.train()
        for step, (y1, y2) in tqdm(enumerate(trainloader,
                                    start=epoch * len(trainloader)),
                                    disable=CFG.disable_tqdm):
            y1 = y1.to(device, non_blocking=True)
            y2 = y2.to(device, non_blocking=True)

            # per-step LR schedule for the weight and bias parameter groups
            adjust_learning_rate(CFG, optimizer, trainloader, step)
            optimizer.zero_grad()

            if CFG.use_amp:  ## with mixed precision
                with torch.cuda.amp.autocast():
                    loss = model(y1, y2)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = model(y1, y2)
                loss.backward()
                optimizer.step()
            t_running_loss_ += loss.item()

            if step % CFG.print_freq_step == 0:
                stats = dict(epoch=epoch, step=step,
                             time=int(time.time() - start_time),
                             step_loss=loss.item(),
                             lr_weights=optimizer.param_groups[0]['lr'],
                             lr_biases=optimizer.param_groups[1]['lr'],)
                lutl.LOG2DICTXT(stats, CFG.checkpoint_dir + '/train-stats.txt')
        train_epoch_loss = t_running_loss_ / len(trainloader)

        # save checkpoint, alternating between the two slots
        if (epoch+1) % CFG.ckpt_freq_epoch == 0:
            wgt_suf = (wgt_suf+1) % 2
            state = dict(epoch=epoch, model=model.state_dict(),
                         optimizer=optimizer.state_dict())
            torch.save(state, CFG.gWeightPath + f'/checkpoint-{wgt_suf}.pth')


        ## ---- Validation Routine ----
        if (epoch+1) % CFG.valid_freq_epoch == 0:
            model.eval()
            v_running_loss_ = 0
            with torch.no_grad():
                for (y1, y2) in tqdm(validloader, total=len(validloader),
                                     disable=CFG.disable_tqdm):
                    y1 = y1.to(device, non_blocking=True)
                    y2 = y2.to(device, non_blocking=True)
                    loss = model(y1, y2)
                    v_running_loss_ += loss.item()
            valid_epoch_loss = v_running_loss_ / len(validloader)

            # track the best validation loss; only the flag is logged, no separate
            # best-weights file is written
            best_flag = False
            if valid_epoch_loss < best_loss:
                best_flag = True
                best_loss = valid_epoch_loss

            v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf,
                           train_loss=train_epoch_loss,
                           valid_loss=valid_epoch_loss)
            lutl.LOG2DICTXT(v_stats, CFG.gLogPath + '/valid-stats.txt')


if __name__ == '__main__':
    simple_main()
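
# Typical invocation (config path is illustrative):
#   python tasks/bt-train.py --load-json configs/bt-ssl.json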