Diff of /tasks/ae-train.py [000000] .. [a18f15]

--- a
+++ b/tasks/ae-train.py
@@ -0,0 +1,260 @@
+""" Barlow Twin self-supervision training
+"""
+import argparse
+import json
+import math
+import os
+import random
+import signal
+import subprocess
+import sys
+import time
+from tqdm import tqdm
+
+from torch import nn, optim
+import torch
+import torchvision
+import torchinfo
+
+sys.path.append(os.getcwd())
+import utilities.runUtils as rutl
+import utilities.logUtils as lutl
+from algorithms.autoencoder import AutoEncoder
+from datacode.natural_image_data import Cifar100Dataset
+from datacode.ultrasound_data import FetalUSFramesDataset
+from datacode.augmentations import AEncStandardTransform, AEncInpaintTransform
+
+
+print(f"Pytorch version: {torch.__version__}")
+print(f"cuda version: {torch.version.cuda}")
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print("Device Used:", device)
+
+###============================= Configure and Setup ===========================
+
+CFG = rutl.ObjDict(
+use_amp = True, #automatic Mixed precision
+
+datapath    = "/home/USR/WERK/data/",
+valdatapath = "/home/USR/WERK/valdata/",
+skip_count  = 5,
+
+epochs      = 1000,
+batch_size  = 2048,
+workers     = 16,
+image_size  = 256,
+
+learning_rate = 1e-3,
+weight_decay  = 1e-6,
+sched_step    = 50,     # epochs between LR decay steps
+sched_gamma   = 0.5624, # 0.5624**4 ~= 0.1, i.e. roughly 1/10 every 200 epochs
+autoenc_map   = "standard",       # standard, denoise, inpaint
+
+
+featx_arch = "resnet50", # "resnet34/50/101"
+featx_pretrain = "IMAGENET-1K", # "IMAGENET-1K" or None
+
+print_freq_step  = 1000, #steps
+ckpt_freq_epoch  = 5,  #epochs
+valid_freq_epoch = 5,  #epochs
+disable_tqdm     = False,   # True to disable the progress bars
+
+checkpoint_dir  = "hypotheses/-dummy/ssl-autoenc",
+resume_training = False,
+)
+
+## --------
+parser = argparse.ArgumentParser(description='Auto Encoder architecture training')
+parser.add_argument('--load-json', type=str, metavar='JSON',
+    help='Load settings from file in json format. Command line options override values in file.')
+
+args = parser.parse_args()
+
+if args.load_json:
+    with open(args.load_json, 'rt') as f:
+        CFG.__dict__.update(json.load(f))
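+
+# Example JSON override file for --load-json (illustrative values only; keys
+# must match the CFG fields above):
+#   {"epochs": 100, "batch_size": 256, "learning_rate": 1e-4,
+#    "checkpoint_dir": "hypotheses/exp1/ssl-autoenc"}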
+
+### ----------------------------------------------------------------------------
+CFG.gLogPath = CFG.checkpoint_dir
+CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'
+
+### ============================================================================
+
+
+def getDataLoaders():
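+    # Each dataset item is an (input, target) image pair produced by the chosen
+    # transform; the training loop reconstructs the target (y2) from the input (y1).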
+
+    if CFG.autoenc_map == "standard":
+        transform_obj = AEncStandardTransform(image_size=CFG.image_size)
+    elif CFG.autoenc_map == "inpaint":
+        transform_obj = AEncInpaintTransform(image_size=CFG.image_size)
+    else:
+        raise Exception("Unknown Auto Encoder augmentatoion")
+
+
+    traindataset = FetalUSFramesDataset( hdf5_file= CFG.datapath,
+                                transform = transform_obj,
+                                load2ram = False, frame_skip=CFG.skip_count)
+
+
+    trainloader  = torch.utils.data.DataLoader( traindataset, shuffle=True,
+                        batch_size=CFG.batch_size, num_workers=CFG.workers,
+                        pin_memory=True)
+
+    validdataset = FetalUSFramesDataset( hdf5_file= CFG.valdatapath,
+                                transform = transform_obj,
+                                load2ram = False, frame_skip=CFG.skip_count)
+
+
+    validloader  = torch.utils.data.DataLoader( validdataset, shuffle=False,
+                        batch_size=CFG.batch_size, num_workers=CFG.workers,
+                        pin_memory=True)
+
+
+    lutl.LOG2DICTXT({"TRAIN DatasetClass":traindataset.get_info(),
+                    "TransformsClass": str(transform_obj.get_composition()),
+                    }, CFG.gLogPath +'/misc.txt')
+    lutl.LOG2DICTXT({"VALID DatasetClass":validdataset.get_info(),
+                    "TransformsClass": str(transform_obj.get_composition()),
+                    }, CFG.gLogPath +'/misc.txt')
+
+    return trainloader, validloader
+
+
+def getModelnOptimizer():
+    model = AutoEncoder(arch=CFG.featx_arch,
+                        pretrained=CFG.featx_pretrain).to(device)
+
+    optimizer = optim.AdamW(model.parameters(), lr=CFG.learning_rate,
+                        weight_decay=CFG.weight_decay)
+
+    scheduler = optim.lr_scheduler.StepLR(optimizer,
+                        step_size=CFG.sched_step, gamma=CFG.sched_gamma)
+
+    model_info = torchinfo.summary(model, (1, 3, CFG.image_size, CFG.image_size),
+                                verbose=0)
+    lutl.LOG2TXT(str(model_info), CFG.gLogPath +'/misc.txt', console=False)
+
+    return model.to(device), optimizer, scheduler
+
+
+def getLossFunc():
+    mse = nn.MSELoss()
+    # def scaledMSE(pred, tgt):
+    #     loss = mse(pred, tgt) *256
+    #     return loss
+    return mse
+
+
+
+def simple_main():
+    ### SETUP
+    rutl.START_SEED()
+    if torch.cuda.is_available(): torch.cuda.set_device(device)
+    torch.backends.cudnn.benchmark = True
+
+    if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training):
+        raise Exception("CheckPoint folder already exists and restart_training not enabled; Somethings Wrong!")
+    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)
+
+    with open(CFG.gLogPath+"/exp_config.json", 'a') as f:
+        json.dump(vars(CFG), f, indent=4)
+
+
+    ### DATA ACCESS
+    trainloader, validloader = getDataLoaders()
+
+    ### MODEL, OPTIM
+    model, optimizer, scheduler = getModelnOptimizer()
+    lossfn = getLossFunc()
+
+    ## Automatically resume from checkpoint if it exists and resuming is enabled
+    ckpt = None
+    if CFG.resume_training:
+        try:
+            ckpt = torch.load(CFG.gWeightPath+'/checkpoint-1.pth', map_location='cpu')
+        except Exception:
+            try:
+                ckpt = torch.load(CFG.gWeightPath+'/checkpoint-0.pth', map_location='cpu')
+            except Exception:
+                print("Checkpoints are not loadable. Starting fresh...")
+    if ckpt:
+        start_epoch = ckpt['epoch']
+        model.load_state_dict(ckpt['model'])
+        optimizer.load_state_dict(ckpt['optimizer'])
+        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}",  CFG.gLogPath +'/misc.txt')
+    else:
+        start_epoch = 0
+
+
+    ### MODEL TRAINING
+    start_time = time.time()
+    best_loss  = float('inf')
+    wgt_suf    = 0  # alternate checkpoint suffix (0/1) so a crash mid-save keeps the previous copy
+    if CFG.use_amp: scaler = torch.cuda.amp.GradScaler() # for mixed precision
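+    # GradScaler scales the loss before backward() so float16 gradients do not
+    # underflow; scaler.step() unscales them before the optimizer update and
+    # scaler.update() adjusts the scale factor each iteration.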
+
+    for epoch in range(start_epoch, CFG.epochs):
+
+        ## ---- Training Routine ----
+        t_running_loss_ = 0
+        model.train()
+        for step, (y1, y2) in tqdm(enumerate(trainloader,
+                                    start=epoch * len(trainloader)),
+                                    disable=CFG.disable_tqdm):
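+            # `step` is a global batch index (enumerate starts at epoch * len(trainloader)),
+            # so print_freq_step counts batches across the whole run, not per epoch.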
+            y1 = y1.to(device, non_blocking=True)
+            y2 = y2.to(device, non_blocking=True)
+            optimizer.zero_grad()
+
+            if CFG.use_amp: ## with mixed precision
+                with torch.cuda.amp.autocast():
+                    y_pred = model(y1)
+                    loss = lossfn(y_pred, y2)
+                scaler.scale(loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                y_pred = model(y1)
+                loss = lossfn(y_pred, y2)
+                loss.backward()
+                optimizer.step()
+            t_running_loss_+=loss.item()
+
+            if step % CFG.print_freq_step == 0:
+                stats = dict(epoch=epoch, step=step,
+                             step_loss=loss.item(),
+                             time=int(time.time() - start_time))
+                lutl.LOG2DICTXT(stats, CFG.checkpoint_dir +'/train-stats.txt')
+        train_epoch_loss = t_running_loss_/len(trainloader)
+
+        if scheduler: scheduler.step()
+
+        ## ---- Save Checkpoint ----
+        if (epoch+1) % CFG.ckpt_freq_epoch == 0:
+            wgt_suf = (wgt_suf+1) %2
+            state = dict(epoch=epoch, model=model.state_dict(),
+                            optimizer=optimizer.state_dict())
+            torch.save(state, CFG.gWeightPath +f'/checkpoint-{wgt_suf}.pth')
+
+
+        ## ---- Validation Routine ----
+        if (epoch+1) % CFG.valid_freq_epoch == 0:
+            model.eval()
+            v_running_loss_ = 0
+            with torch.no_grad():
+                for (y1, y2) in tqdm(validloader,  total=len(validloader),
+                                    disable=CFG.disable_tqdm):
+                    y1 = y1.to(device, non_blocking=True)
+                    y2 = y2.to(device, non_blocking=True)
+                    y_pred = model(y1)
+                    loss = lossfn(y_pred, y2)
+                    v_running_loss_ += loss.item()
+            valid_epoch_loss = v_running_loss_/len(validloader)
+            best_flag = False
+            if valid_epoch_loss < best_loss:
+                best_flag = True
+                best_loss = valid_epoch_loss
+
+            v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf,
+                            train_loss=train_epoch_loss,
+                            valid_loss=valid_epoch_loss)
+            lutl.LOG2DICTXT(v_stats, CFG.gLogPath+'/valid-stats.txt')
+
+
+if __name__ == '__main__':
+    simple_main()
\ No newline at end of file