## tasks/ae-train.py
""" Auto-Encoder self-supervision training
"""
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time
from tqdm import tqdm

from torch import nn, optim
import torch
import torchvision
import torchinfo

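# make the repo-local packages (utilities, algorithms, datacode) importable
# when the script is launched from the repository root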
sys.path.append(os.getcwd())
import utilities.runUtils as rutl
import utilities.logUtils as lutl
from algorithms.autoencoder import AutoEncoder
from datacode.natural_image_data import Cifar100Dataset
from datacode.ultrasound_data import FetalUSFramesDataset
from datacode.augmentations import AEncStandardTransform, AEncInpaintTransform


print(f"Pytorch version: {torch.__version__}")
print(f"cuda version: {torch.version.cuda}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device Used:", device)

###============================= Configure and Setup ===========================

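## Default experiment configuration; values can be overridden via a JSON file
## passed with --load-json (see the argparse block below).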
CFG = rutl.ObjDict(
    use_amp = True,           # automatic mixed precision

    datapath    = "/home/USR/WERK/data/",
    valdatapath = "/home/USR/WERK/valdata/",
    skip_count  = 5,

    epochs      = 1000,
    batch_size  = 2048,
    workers     = 16,
    image_size  = 256,

    learning_rate = 1e-3,
    weight_decay  = 1e-6,
    sched_step    = 50,       # in epochs
    sched_gamma   = 0.5624,   # 0.5624^4 ~ 0.1, i.e. roughly 1/10 every 200 epochs
    autoenc_map   = "standard",       # standard, inpaint

    featx_arch = "resnet50",          # resnet34 / resnet50 / resnet101
    featx_pretrain = "IMAGENET-1K",   # "IMAGENET-1K" or None

    print_freq_step  = 1000,  # steps
    ckpt_freq_epoch  = 5,     # epochs
    valid_freq_epoch = 5,     # epochs
    disable_tqdm     = False, # set True to disable progress bars

    checkpoint_dir  = "hypotheses/-dummy/ssl-autoenc",
    resume_training = False,
)

## --------
parser = argparse.ArgumentParser(description='Auto Encoder architecture training')
parser.add_argument('--load-json', type=str, metavar='JSON',
    help='Load experiment settings from a JSON file; values in the file override the defaults above.')

args = parser.parse_args()

if args.load_json:
    with open(args.load_json, 'rt') as f:
        CFG.__dict__.update(json.load(f))

### ----------------------------------------------------------------------------
CFG.gLogPath = CFG.checkpoint_dir
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'

### ============================================================================


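## Build the train/valid DataLoaders: pick the configured auto-encoder transform,
## wrap the fetal-US frame HDF5 datasets with it, and log dataset/transform info.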
def getDataLoaders():
    if CFG.autoenc_map == "standard":
        transform_obj = AEncStandardTransform(image_size=CFG.image_size)
    elif CFG.autoenc_map == "inpaint":
        transform_obj = AEncInpaintTransform(image_size=CFG.image_size)
    else:
        raise ValueError(f"Unknown auto-encoder augmentation: {CFG.autoenc_map}")


    traindataset = FetalUSFramesDataset( hdf5_file= CFG.datapath,
                                transform = transform_obj,
                                load2ram = False, frame_skip=CFG.skip_count)


    trainloader  = torch.utils.data.DataLoader( traindataset, shuffle=True,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True)

    validdataset = FetalUSFramesDataset( hdf5_file= CFG.valdatapath,
                                transform = transform_obj,
                                load2ram = False, frame_skip=CFG.skip_count)


    validloader  = torch.utils.data.DataLoader( validdataset, shuffle=False,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True)


    lutl.LOG2DICTXT({"TRAIN DatasetClass":traindataset.get_info(),
                    "TransformsClass": str(transform_obj.get_composition()),
                    }, CFG.gLogPath +'/misc.txt')
    lutl.LOG2DICTXT({"VALID DatasetClass":validdataset.get_info(),
                    "TransformsClass": str(transform_obj.get_composition()),
                    }, CFG.gLogPath +'/misc.txt')

    return trainloader, validloader


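## Build the auto-encoder model, AdamW optimizer and StepLR scheduler,
## and log a torchinfo summary of the network to misc.txt.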
def getModelnOptimizer():
    model = AutoEncoder(arch=CFG.featx_arch,
                        pretrained=CFG.featx_pretrain).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=CFG.learning_rate,
                        weight_decay=CFG.weight_decay)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                        step_size=CFG.sched_step, gamma=CFG.sched_gamma)

    model_info = torchinfo.summary(model, (1, 3, CFG.image_size, CFG.image_size),
                                verbose=0)
    lutl.LOG2TXT(str(model_info), CFG.gLogPath +'/misc.txt', console=False)

    return model, optimizer, scheduler


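## Reconstruction objective: plain pixel-wise MSE between prediction and target.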
def getLossFunc():
    mse = nn.MSELoss()
    # def scaledMSE(pred, tgt):
    #     loss = mse(pred, tgt) *256
    #     return loss
    return mse



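## End-to-end training routine: setup and config dump, data loaders, model/optim,
## optional checkpoint resume, then the epoch loop with periodic checkpointing
## and validation.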
def simple_main():
    ### SETUP
    rutl.START_SEED()
    torch.backends.cudnn.benchmark = True

    if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training):
        raise Exception("Checkpoint folder already exists and resume_training is not enabled; something's wrong!")
    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)

    with open(CFG.gLogPath+"/exp_config.json", 'a') as f:
        json.dump(vars(CFG), f, indent=4)


    ### DATA ACCESS
    trainloader, validloader = getDataLoaders()

    ### MODEL, OPTIM
    model, optimizer, scheduler = getModelnOptimizer()
    lossfn = getLossFunc()

    ## Automatically resume from checkpoint if it exists and resume is enabled
    ckpt = None
    if CFG.resume_training:
        try:    ckpt = torch.load(CFG.gWeightPath+'/checkpoint-1.pth', map_location='cpu')
        except Exception:
            try:    ckpt = torch.load(CFG.gWeightPath+'/checkpoint-0.pth', map_location='cpu')
            except Exception: print("Checkpoints are not loadable. Starting fresh...")
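    # Note: checkpoints hold only epoch, model and optimizer states; the StepLR
    # scheduler is rebuilt from CFG, so its decay schedule restarts on resume.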
    if ckpt:
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}",  CFG.gLogPath +'/misc.txt')
    else:
        start_epoch = 0


    ### MODEL TRAINING
    start_time = time.time()
    best_loss  = float('inf')
    wgt_suf    = 0  # checkpoint filename suffix, alternated so an interrupted save cannot clobber the last good checkpoint
    if CFG.use_amp: scaler = torch.cuda.amp.GradScaler() # for mixed precision

    for epoch in range(start_epoch, CFG.epochs):

        ## ---- Training Routine ----
        t_running_loss_ = 0
        model.train()
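        # y1 is the model input and y2 the reconstruction target from the AE transform;
        # the step counter continues across epochs via enumerate(start=...).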
        for step, (y1, y2) in tqdm(enumerate(trainloader,
                                    start=epoch * len(trainloader)),
                                    disable=CFG.disable_tqdm):
            y1 = y1.to(device, non_blocking=True)
            y2 = y2.to(device, non_blocking=True)
            optimizer.zero_grad()

            if CFG.use_amp: ## with mixed precision
                with torch.cuda.amp.autocast():
                    y_pred = model.forward(y1)
                    loss = lossfn(y_pred, y2)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                y_pred = model.forward(y1)
                loss = lossfn(y_pred, y2)
                loss.backward()
                optimizer.step()
            t_running_loss_ += loss.item()

            if step % CFG.print_freq_step == 0:
                stats = dict(epoch=epoch, step=step,
                             step_loss=loss.item(),
                             time=int(time.time() - start_time))
                lutl.LOG2DICTXT(stats, CFG.checkpoint_dir +'/train-stats.txt')
        train_epoch_loss = t_running_loss_/len(trainloader)

        if scheduler: scheduler.step()

        # save checkpoint, alternating the filename suffix (0/1) between saves
        if (epoch+1) % CFG.ckpt_freq_epoch == 0:
            wgt_suf = (wgt_suf+1) % 2
            state = dict(epoch=epoch, model=model.state_dict(),
                            optimizer=optimizer.state_dict())
            torch.save(state, CFG.gWeightPath +f'/checkpoint-{wgt_suf}.pth')


        ## ---- Validation Routine ----
        if (epoch+1) % CFG.valid_freq_epoch == 0:
            model.eval()
            v_running_loss_ = 0
            with torch.no_grad():
                for (y1, y2) in tqdm(validloader,  total=len(validloader),
                                    disable=CFG.disable_tqdm):
                    y1 = y1.to(device, non_blocking=True)
                    y2 = y2.to(device, non_blocking=True)
                    y_pred = model.forward(y1)
                    loss = lossfn(y_pred, y2)
                    v_running_loss_ += loss.item()
            valid_epoch_loss = v_running_loss_/len(validloader)
            best_flag = False
            if valid_epoch_loss < best_loss:
                best_flag = True
                best_loss = valid_epoch_loss

            v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf,
                            train_loss=train_epoch_loss,
                            valid_loss=valid_epoch_loss)
            lutl.LOG2DICTXT(v_stats, CFG.gLogPath+'/valid-stats.txt')


if __name__ == '__main__':
    simple_main()