Diff of /tasks/simclr-train.py [000000] .. [a18f15]

""" SimCLR self-supervision training
"""
import argparse
import json
import os
import sys
import time
from tqdm import tqdm

import torch
import torchinfo

sys.path.append(os.getcwd())
import utilities.runUtils as rutl
import utilities.logUtils as lutl
from algorithms.simclr import SimCLR, LARS
from datacode.natural_image_data import Cifar100Dataset
from datacode.ultrasound_data import FetalUSFramesDataset
from datacode.augmentations import SimCLRTransform

print(f"Pytorch version: {torch.__version__}")
22
print(f"cuda version: {torch.version.cuda}")
23
device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
print("Device Used:", device)

###============================= Configure and Setup ===========================

CFG = rutl.ObjDict(
    use_amp = True,  # automatic mixed precision

    datapath    = "/home/mothilal.asokan/Downloads/HC701/Project/US-Fetal-Video-Frames_V1-1/train-all-frames.hdf5",
    valdatapath = "/home/mothilal.asokan/Downloads/HC701/Project/US-Fetal-Video-Frames_V1-1/valid-all-frames.hdf5",
    skip_count  = 5,

    epochs      = 10,
    batch_size  = 160,
    workers     = 24,
    image_size  = 256,

    weight_decay = 1e-4,
    temp         = 0.5,
    lr           = 0.3,

    featx_arch     = "resnet50",  # "resnet34/50/101"
    featx_pretrain = None,        # "IMGNET-1K" or None
    projector      = [512, 128],

    print_freq_step  = 10,  # steps
    ckpt_freq_epoch  = 5,   # epochs
    valid_freq_epoch = 5,   # epochs
    disable_tqdm     = False,  # True to disable the progress bar

    checkpoint_dir  = "hypotheses/-dummy/ssl-simclr/",
    resume_training = True,
)

## --------
parser = argparse.ArgumentParser(description='SimCLR Training')
parser.add_argument('--load-json', type=str, metavar='JSON',
    help='Load settings from a JSON file; values in the file override the defaults defined above.')

args = parser.parse_args()

if args.load_json:
    with open(args.load_json, 'rt') as f:
        CFG.__dict__.update(json.load(f))

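# The --load-json file is a flat JSON object whose keys mirror the CFG fields
# above, e.g. (illustrative values only, not the defaults):
#   {
#       "epochs": 100,
#       "batch_size": 256,
#       "checkpoint_dir": "hypotheses/run-01/ssl-simclr/"
#   }
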
### ----------------------------------------------------------------------------
CFG.gLogPath = CFG.checkpoint_dir
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'

### ============================================================================

def getDataLoaders():

    transform_obj = SimCLRTransform(image_size=CFG.image_size)

    traindataset = FetalUSFramesDataset(hdf5_file=CFG.datapath,
                                transform=transform_obj,
                                load2ram=False, frame_skip=CFG.skip_count)

    trainloader = torch.utils.data.DataLoader(traindataset, shuffle=True,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True, drop_last=True)

    # val_transform_obj = SimCLREvalDataTransform(image_size=CFG.image_size)

    validdataset = FetalUSFramesDataset(hdf5_file=CFG.valdatapath,
                                transform=transform_obj,
                                load2ram=False, frame_skip=CFG.skip_count)

    validloader = torch.utils.data.DataLoader(validdataset, shuffle=False,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True, drop_last=True)

    lutl.LOG2DICTXT({"TRAIN DatasetClass": traindataset.get_info(),
                     "TransformsClass": str(transform_obj.get_composition()),
                     }, CFG.gLogPath + '/misc.txt')
    lutl.LOG2DICTXT({"VALID DatasetClass": validdataset.get_info(),
                     "TransformsClass": str(transform_obj.get_composition()),
                     }, CFG.gLogPath + '/misc.txt')

    return trainloader, validloader

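# NOTE: SimCLRTransform is assumed to return two independently augmented views
# of each frame, so every batch below unpacks as a (y1, y2) pair of tensors.
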
def getModelnOptimizer():
    model = SimCLR(featx_arch=CFG.featx_arch,
                   projector_sizes=CFG.projector,
                   batch_size=CFG.batch_size,
                   temp=CFG.temp,
                   pretrained=CFG.featx_pretrain).to(device)

    optimizer = LARS(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay,
                     momentum=0.9)

    model_info = torchinfo.summary(model, 2*[(CFG.batch_size, 3, CFG.image_size, CFG.image_size)],
                                   verbose=0)
    lutl.LOG2TXT(model_info, CFG.gLogPath + '/misc.txt', console=False)

    return model.to(device), optimizer

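# For orientation only: SimCLR.forward(y1, y2) is assumed to return a scalar
# NT-Xent (contrastive) loss over the 2*batch_size projected embeddings,
# roughly (the actual implementation lives in algorithms/simclr.py):
#   z = normalize(cat([projector(encoder(y1)), projector(encoder(y2))]))
#   sim = z @ z.T / temp                    # pairwise cosine similarity / temperature
#   loss = cross_entropy(sim, positive_pair_targets)   # self-similarity excluded
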
### ----------------------------------------------------------------------------

def simple_main():
    ### SETUP
    rutl.START_SEED()
    if device == 'cuda':
        torch.cuda.set_device(0)  # pin to the default GPU (index 0 assumed)
    torch.backends.cudnn.benchmark = True

    if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training):
        raise Exception("Checkpoint folder already exists and resume_training is not enabled; something is wrong!")
    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)

    with open(CFG.gLogPath + "/exp_config.json", 'a') as f:
        json.dump(vars(CFG), f, indent=4)

    ### DATA ACCESS
    trainloader, validloader = getDataLoaders()

    ### MODEL, OPTIM
    model, optimizer = getModelnOptimizer()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                    len(trainloader), eta_min=0, last_epoch=-1)

    ## Automatically resume from checkpoint if one exists and resume is enabled
    ckpt = None
    if CFG.resume_training:
        try:
            ckpt = torch.load(CFG.gWeightPath + '/checkpoint-1.pth', map_location='cpu')
        except Exception:
            try:
                ckpt = torch.load(CFG.gWeightPath + '/checkpoint-0.pth', map_location='cpu')
            except Exception:
                print("Checkpoints are not loadable. Starting fresh...")
    if ckpt:
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}", CFG.gLogPath + '/misc.txt')
    else:
        start_epoch = 0

    ### MODEL TRAINING
    start_time = time.time()
    best_loss = float('inf')
    wgt_suf   = 0  # alternating 0/1 checkpoint suffix, so a crash during save cannot corrupt the only copy
    if CFG.use_amp: scaler = torch.cuda.amp.GradScaler()  # for mixed precision

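    # Standard torch.cuda.amp pattern below: the loss is scaled before backward()
    # so fp16 gradients do not underflow; scaler.step() unscales them before the
    # optimizer update and scaler.update() adapts the scale factor each step.
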
    for epoch in range(start_epoch, CFG.epochs):

        ## ---- Training Routine ----
        t_running_loss_ = 0
        model.train()
        for step, (y1, y2) in tqdm(enumerate(trainloader,
                                    start=epoch * len(trainloader)),
                                    disable=CFG.disable_tqdm):

            y1 = y1.to(device, non_blocking=True)
            y2 = y2.to(device, non_blocking=True)
            optimizer.zero_grad()

            if CFG.use_amp:  ## with mixed precision
                with torch.cuda.amp.autocast():
                    loss = model(y1, y2)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss = model(y1, y2)
                loss.backward()
                optimizer.step()
            t_running_loss_ += loss.item()

            if step % CFG.print_freq_step == 0:
                stats = dict(epoch=epoch, step=step,
                             lr_weights=optimizer.param_groups[0]['lr'],
                             step_loss=loss.item(),
                             time=int(time.time() - start_time))
                lutl.LOG2DICTXT(stats, CFG.checkpoint_dir + '/train-stats.txt')
        train_epoch_loss = t_running_loss_ / len(trainloader)

        scheduler.step()

        # save checkpoint
        if (epoch+1) % CFG.ckpt_freq_epoch == 0:
            wgt_suf = (wgt_suf + 1) % 2
            state = dict(epoch=epoch, model=model.state_dict(),
                         optimizer=optimizer.state_dict())
            torch.save(state, CFG.gWeightPath + f'/checkpoint-{wgt_suf}.pth')

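        # Only the two rotating files checkpoint-0.pth / checkpoint-1.pth are kept;
        # the resume logic above tries checkpoint-1.pth first, then checkpoint-0.pth.
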
        ## ---- Validation Routine ----
        if (epoch+1) % CFG.valid_freq_epoch == 0:
            model.eval()
            v_running_loss_ = 0
            with torch.no_grad():
                for (y1, y2) in tqdm(validloader, total=len(validloader),
                                     disable=CFG.disable_tqdm):
                    y1 = y1.to(device, non_blocking=True)
                    y2 = y2.to(device, non_blocking=True)
                    loss = model(y1, y2)
                    v_running_loss_ += loss.item()
            valid_epoch_loss = v_running_loss_ / len(validloader)

            # track the best validation loss seen so far
            best_flag = False
            if valid_epoch_loss < best_loss:
                best_flag = True
                best_loss = valid_epoch_loss

            v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf,
                           train_loss=train_epoch_loss,
                           valid_loss=valid_epoch_loss)
            lutl.LOG2DICTXT(v_stats, CFG.gLogPath + '/valid-stats.txt')

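# Typical invocation (the config filename is illustrative):
#   python tasks/simclr-train.py --load-json configs/simclr-fetal-us.json
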
if __name__ == '__main__':
    simple_main()