--- /dev/null
+++ b/tasks/ae-train.py
@@ -0,0 +1,260 @@
+""" Auto Encoder self-supervision training
+"""
+import argparse
+import json
+import math
+import os
+import random
+import signal
+import subprocess
+import sys
+import time
+from tqdm import tqdm
+
+from torch import nn, optim
+import torch
+import torchvision
+import torchinfo
+
+sys.path.append(os.getcwd())
+import utilities.runUtils as rutl
+import utilities.logUtils as lutl
+from algorithms.autoencoder import AutoEncoder
+from datacode.natural_image_data import Cifar100Dataset
+from datacode.ultrasound_data import FetalUSFramesDataset
+from datacode.augmentations import AEncStandardTransform, AEncInpaintTransform
+
+
+print(f"Pytorch version: {torch.__version__}")
+print(f"cuda version: {torch.version.cuda}")
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print("Device Used:", device)
+
+###============================= Configure and Setup ===========================
+
+CFG = rutl.ObjDict(
+use_amp = True,  # automatic mixed precision
+
+datapath    = "/home/USR/WERK/data/",
+valdatapath = "/home/USR/WERK/valdata/",
+skip_count  = 5,
+
+epochs      = 1000,
+batch_size  = 2048,
+workers     = 16,
+image_size  = 256,
+
+learning_rate = 1e-3,
+weight_decay  = 1e-6,
+sched_step  = 50,      ## epochs per LR step
+sched_gamma = 0.5624,  # ~1/10 every 200 epochs (0.5624^4 ≈ 0.1)
+autoenc_map = "standard",  # standard, denoise, inpaint
+
+featx_arch = "resnet50",            # "resnet34/50/101"
+featx_pretrain = "IMAGENET-1K",     # "IMAGENET-1K" or None
+
+print_freq_step  = 1000,  # steps
+ckpt_freq_epoch  = 5,     # epochs
+valid_freq_epoch = 5,     # epochs
+disable_tqdm = False,     # True --> disable progress bar
+
+checkpoint_dir  = "hypotheses/-dummy/ssl-autoenc",
+resume_training = False,
+)
+
+## --------
+parser = argparse.ArgumentParser(description='Auto Encoder architecture training')
+parser.add_argument('--load-json', type=str, metavar='JSON',
+    help='Load settings from file in json format. Command line options override values in file.')
+
+args = parser.parse_args()
+
+if args.load_json:
+    with open(args.load_json, 'rt') as f:
+        CFG.__dict__.update(json.load(f))
+
+### ----------------------------------------------------------------------------
+CFG.gLogPath = CFG.checkpoint_dir
+CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'
+
+### ============================================================================
+
+
+def getDataLoaders():
+
+    if CFG.autoenc_map == "standard":
+        transform_obj = AEncStandardTransform(image_size=CFG.image_size)
+    elif CFG.autoenc_map == "inpaint":
+        transform_obj = AEncInpaintTransform(image_size=CFG.image_size)
+    else:
+        raise Exception("Unknown Auto Encoder augmentation")
+
+    traindataset = FetalUSFramesDataset(hdf5_file=CFG.datapath,
+                        transform=transform_obj,
+                        load2ram=False, frame_skip=CFG.skip_count)
+
+    trainloader = torch.utils.data.DataLoader(traindataset, shuffle=True,
+                        batch_size=CFG.batch_size, num_workers=CFG.workers,
+                        pin_memory=True)
+
+    validdataset = FetalUSFramesDataset(hdf5_file=CFG.valdatapath,
+                        transform=transform_obj,
+                        load2ram=False, frame_skip=CFG.skip_count)
+
+    validloader = torch.utils.data.DataLoader(validdataset, shuffle=False,
+                        batch_size=CFG.batch_size, num_workers=CFG.workers,
+                        pin_memory=True)
+
+    lutl.LOG2DICTXT({"TRAIN DatasetClass": traindataset.get_info(),
+                     "TransformsClass": str(transform_obj.get_composition()),
+                    }, CFG.gLogPath + '/misc.txt')
+    lutl.LOG2DICTXT({"VALID DatasetClass": validdataset.get_info(),
+                     "TransformsClass": str(transform_obj.get_composition()),
+                    }, CFG.gLogPath + '/misc.txt')
+
+    return trainloader, validloader
+
+
+def getModelnOptimizer():
+    model = AutoEncoder(arch=CFG.featx_arch,
+                        pretrained=CFG.featx_pretrain).to(device)
+
+    optimizer = optim.AdamW(model.parameters(), lr=CFG.learning_rate,
+                            weight_decay=CFG.weight_decay)
+
+    scheduler = optim.lr_scheduler.StepLR(optimizer,
+                        step_size=CFG.sched_step, gamma=CFG.sched_gamma)
+
+    model_info = torchinfo.summary(model, (1, 3, CFG.image_size, CFG.image_size),
+                                   verbose=0)
+    lutl.LOG2TXT(model_info, CFG.gLogPath + '/misc.txt', console=False)
+
+    return model.to(device), optimizer, scheduler
+
+
+def getLossFunc():
+    mse = nn.MSELoss()
+    # def scaledMSE(pred, tgt):
+    #     loss = mse(pred, tgt) * 256
+    #     return loss
+    return mse
+
+
+def simple_main():
+    ### SETUP
+    rutl.START_SEED()
+    torch.cuda.device(device)
+    torch.backends.cudnn.benchmark = True
+
+    if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training):
+        raise Exception("Checkpoint folder already exists and resume_training is not enabled; something's wrong!")
+    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)
+
+    with open(CFG.gLogPath + "/exp_config.json", 'a') as f:
+        json.dump(vars(CFG), f, indent=4)
+
+    ### DATA ACCESS
+    trainloader, validloader = getDataLoaders()
+
+    ### MODEL, OPTIM
+    model, optimizer, scheduler = getModelnOptimizer()
+    lossfn = getLossFunc()
+
+    ## Automatically resume from checkpoint if it exists and resume is enabled
+    ckpt = None
+    if CFG.resume_training:
+        try:
+            ckpt = torch.load(CFG.gWeightPath + '/checkpoint-1.pth', map_location='cpu')
+        except OSError:
+            try:
+                ckpt = torch.load(CFG.gWeightPath + '/checkpoint-0.pth', map_location='cpu')
+            except OSError:
+                print("Checkpoints are not loadable. Starting fresh...")
+    if ckpt:
+        start_epoch = ckpt['epoch']
+        model.load_state_dict(ckpt['model'])
+        optimizer.load_state_dict(ckpt['optimizer'])
+        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}", CFG.gLogPath + '/misc.txt')
+    else:
+        start_epoch = 0
+
+    ### MODEL TRAINING
+    start_time = time.time()
+    best_loss = float('inf')
+    wgt_suf = 0  # alternate between two checkpoint files so a crash during save cannot corrupt the only copy
+    if CFG.use_amp: scaler = torch.cuda.amp.GradScaler()  # for mixed precision
+
+    for epoch in range(start_epoch, CFG.epochs):
+
+        ## ---- Training Routine ----
+        t_running_loss_ = 0
+        model.train()
+        for step, (y1, y2) in tqdm(enumerate(trainloader,
+                                    start=epoch * len(trainloader)),
+                                   disable=CFG.disable_tqdm):
+            y1 = y1.to(device, non_blocking=True)
+            y2 = y2.to(device, non_blocking=True)
+            optimizer.zero_grad()
+
+            if CFG.use_amp:  ## with mixed precision
+                with torch.cuda.amp.autocast():
+                    y_pred = model.forward(y1)
+                    loss = lossfn(y_pred, y2)
+                scaler.scale(loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                y_pred = model.forward(y1)
+                loss = lossfn(y_pred, y2)
+                loss.backward()
+                optimizer.step()
+            t_running_loss_ += loss.item()
+
+            if step % CFG.print_freq_step == 0:
+                stats = dict(epoch=epoch, step=step,
+                             step_loss=loss.item(),
+                             time=int(time.time() - start_time))
+                lutl.LOG2DICTXT(stats, CFG.checkpoint_dir + '/train-stats.txt')
+        train_epoch_loss = t_running_loss_ / len(trainloader)
+
+        if scheduler: scheduler.step()
+
+        # save checkpoint
+        if (epoch + 1) % CFG.ckpt_freq_epoch == 0:
+            wgt_suf = (wgt_suf + 1) % 2
+            state = dict(epoch=epoch, model=model.state_dict(),
+                         optimizer=optimizer.state_dict())
+            torch.save(state, CFG.gWeightPath + f'/checkpoint-{wgt_suf}.pth')
+
+        ## ---- Validation Routine ----
+        if (epoch + 1) % CFG.valid_freq_epoch == 0:
+            model.eval()
+            v_running_loss_ = 0
+            with torch.no_grad():
+                for (y1, y2) in tqdm(validloader, total=len(validloader),
+                                     disable=CFG.disable_tqdm):
+                    y1 = y1.to(device, non_blocking=True)
+                    y2 = y2.to(device, non_blocking=True)
+                    y_pred = model.forward(y1)
+                    loss = lossfn(y_pred, y2)
+                    v_running_loss_ += loss.item()
+            valid_epoch_loss = v_running_loss_ / len(validloader)
+            best_flag = False
+            if valid_epoch_loss < best_loss:
+                best_flag = True
+                best_loss = valid_epoch_loss
+
+            v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf,
+                           train_loss=train_epoch_loss,
+                           valid_loss=valid_epoch_loss)
+            lutl.LOG2DICTXT(v_stats, CFG.gLogPath + '/valid-stats.txt')
+
+
+if __name__ == '__main__':
+    simple_main()
\ No newline at end of file
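
For reference, a minimal sketch of how the script could be invoked with a JSON settings file via --load-json; the config file name and values below are hypothetical, and the keys simply mirror the CFG fields defined in the script (any subset may be given, unspecified keys keep their defaults):

    python tasks/ae-train.py --load-json configs/ae-train.json

    # configs/ae-train.json (hypothetical example)
    {
        "datapath": "/home/USR/WERK/data/",
        "valdatapath": "/home/USR/WERK/valdata/",
        "checkpoint_dir": "hypotheses/exp01/ssl-autoenc",
        "autoenc_map": "inpaint",
        "batch_size": 512,
        "resume_training": false
    }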