--- a/tasks/vic-train.py
+++ b/tasks/vic-train.py
@@ -0,0 +1,247 @@
+"""VIC-Reg self-supervised training."""
+
+import argparse
+import json
+import math
+import os
+import random
+import signal
+import subprocess
+import sys
+import time
+from tqdm import tqdm
+
+from torch import nn, optim
+import torch
+import torchvision
+import torchinfo
+
+sys.path.append(os.getcwd())
+import utilities.runUtils as rutl
+import utilities.logUtils as lutl
+from algorithms.vicreg import VICReg, LARS, adjust_learning_rate
+from datacode.natural_image_data import Cifar100Dataset
+from datacode.ultrasound_data import FetalUSFramesDataset
+from datacode.augmentations import BarlowTwinsTransformOrig, CustomInfoMaxTransform
+
+
+print(f"Pytorch version: {torch.__version__}")
+print(f"CUDA version: {torch.version.cuda}")
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print("Device used:", device)
+
+###============================= Configure and Setup ===========================
+
+CFG = rutl.ObjDict(
+use_amp = True,            # automatic mixed precision
+
+datapath = "/home/USR/WERK/data/a.hdf5",
+valdatapath = "/home/USR/WERK/valdata/b.hdf5",
+skip_count = 5,
+
+epochs = 1000,
+batch_size = 2048,
+workers = 24,
+image_size = 256,
+
+base_lr = 0.2,
+weight_decay = 1e-6,
+sim_coeff = 25.0,          # invariance loss weight
+std_coeff = 25.0,          # variance loss weight
+cov_coeff = 1.0,           # covariance loss weight
+
+featx_arch = "resnet50",
+featx_pretrain = None,     # or "IMGNET-1K"
+projector = [8192, 8192, 8192],
+
+print_freq_step = 1000,    # steps
+ckpt_freq_epoch = 5,       # epochs
+valid_freq_epoch = 5,      # epochs
+disable_tqdm = False,      # set True to disable the progress bar
+
+checkpoint_dir = "hypotheses/-dummy/ssl-vicreg/",
+resume_training = True,
+)
+
+## --------
+parser = argparse.ArgumentParser(description='VIC-Reg Self-Supervised Training')
+parser.add_argument('--load-json', type=str, metavar='JSON',
+                    help='Load settings from a file in JSON format. '
+                         'Command-line options override values in the file.')
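+
+# A minimal, illustrative --load-json file (any subset of the CFG keys above may
+# be given; the values below are placeholders, not recommendations):
+#   {
+#       "datapath": "/home/USR/WERK/data/a.hdf5",
+#       "epochs": 100,
+#       "batch_size": 512,
+#       "checkpoint_dir": "hypotheses/-dummy/ssl-vicreg/"
+#   }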
+
+args = parser.parse_args()
+
+if args.load_json:
+    with open(args.load_json, 'rt') as f:
+        CFG.__dict__.update(json.load(f))
+
+### ----------------------------------------------------------------------------
+CFG.gLogPath = CFG.checkpoint_dir
+CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'
+
+### ============================================================================
+
+
+def getDataLoaders():
+
+    transform_obj = BarlowTwinsTransformOrig(image_size=CFG.image_size)
+
+    traindataset = FetalUSFramesDataset(hdf5_file=CFG.datapath,
+                                transform=transform_obj,
+                                load2ram=False, frame_skip=CFG.skip_count)
+
+    trainloader = torch.utils.data.DataLoader(traindataset, shuffle=True,
+                        batch_size=CFG.batch_size, num_workers=CFG.workers,
+                        pin_memory=True)
+
+    validdataset = FetalUSFramesDataset(hdf5_file=CFG.valdatapath,
+                                transform=transform_obj,
+                                load2ram=False, frame_skip=CFG.skip_count)
+
+    validloader = torch.utils.data.DataLoader(validdataset, shuffle=False,
+                        batch_size=CFG.batch_size, num_workers=CFG.workers,
+                        pin_memory=True)
+
+    lutl.LOG2DICTXT({"TRAIN DatasetClass": traindataset.get_info(),
+                     "TransformsClass": str(transform_obj.get_composition()),
+                    }, CFG.gLogPath + '/misc.txt')
+    lutl.LOG2DICTXT({"VALID DatasetClass": validdataset.get_info(),
+                     "TransformsClass": str(transform_obj.get_composition()),
+                    }, CFG.gLogPath + '/misc.txt')
+
+    return trainloader, validloader
+
+
+def getModelnOptimizer():
+    model = VICReg(featx_arch=CFG.featx_arch,
+                   projector_sizes=CFG.projector,
+                   batch_size=CFG.batch_size,
+                   sim_coeff=CFG.sim_coeff,
+                   std_coeff=CFG.std_coeff,
+                   cov_coeff=CFG.cov_coeff,
+                   featx_pretrain=CFG.featx_pretrain,
+                   ).to(device)
+
+    # lr=0 here; the effective rate is set every step by adjust_learning_rate()
+    optimizer = LARS(model.parameters(), lr=0, weight_decay=CFG.weight_decay,
+                     weight_decay_filter=True, lars_adaptation_filter=True)
+
+    # two identical input shapes: one per augmented view
+    model_info = torchinfo.summary(model, 2*[(1, 3, CFG.image_size, CFG.image_size)],
+                                   verbose=0)
+    lutl.LOG2TXT(model_info, CFG.gLogPath + '/misc.txt', console=False)
+
+    return model.to(device), optimizer
+
+
+### ----------------------------------------------------------------------------
+
+def simple_main():
+    ### SETUP
+    rutl.START_SEED()
+    torch.cuda.device(device)
+    torch.backends.cudnn.benchmark = True
+
+    if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training):
+        raise Exception("Checkpoint folder already exists and resume_training is not enabled; something is wrong!")
+    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)
+
+    with open(CFG.gLogPath + "/exp_config.json", 'a') as f:
+        json.dump(vars(CFG), f, indent=4)
+
+    ### DATA ACCESS
+    trainloader, validloader = getDataLoaders()
+
+    ### MODEL, OPTIM
+    model, optimizer = getModelnOptimizer()
+
+    ## Automatically resume from a checkpoint if resume_training is enabled.
+    ## Checkpoints alternate between two files (see wgt_suf below); try slot 1
+    ## first, then fall back to slot 0.
+    ckpt = None
+    if CFG.resume_training:
+        try:
+            ckpt = torch.load(CFG.gWeightPath + '/checkpoint-1.pth', map_location='cpu')
+        except Exception:
+            try:
+                ckpt = torch.load(CFG.gWeightPath + '/checkpoint-0.pth', map_location='cpu')
+            except Exception:
+                print("Checkpoints are not loadable. Starting fresh...")
+    if ckpt:
+        start_epoch = ckpt['epoch']
+        model.load_state_dict(ckpt['model'])
+        optimizer.load_state_dict(ckpt['optimizer'])
+        lutl.LOG2TXT(f"Restarting training from epoch {start_epoch} of {CFG.checkpoint_dir}", CFG.gLogPath + '/misc.txt')
+    else:
+        start_epoch = 0
+
+    ### MODEL TRAINING
+    start_time = time.time()
+    best_loss = float('inf')
+    wgt_suf = 0   # alternate between two checkpoint files so a crash during save cannot corrupt both
+    if CFG.use_amp: scaler = torch.cuda.amp.GradScaler()   # for mixed precision
+
+    for epoch in range(start_epoch, CFG.epochs):
+
+        ## ---- Training Routine ----
+        t_running_loss_ = 0
+        model.train()
+        for step, (y1, y2) in tqdm(enumerate(trainloader,
+                                    start=epoch * len(trainloader)),
+                                    disable=CFG.disable_tqdm):
+            y1 = y1.to(device, non_blocking=True)
+            y2 = y2.to(device, non_blocking=True)
+            lr_ = adjust_learning_rate(CFG, optimizer, trainloader, step)
+            optimizer.zero_grad()
+
+            if CFG.use_amp:   ## with mixed precision
+                with torch.cuda.amp.autocast():
+                    loss = model.forward(y1, y2)
+                scaler.scale(loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss = model.forward(y1, y2)
+                loss.backward()
+                optimizer.step()
+            t_running_loss_ += loss.item()
+
+            if step % CFG.print_freq_step == 0:
+                stats = dict(epoch=epoch, step=step,
+                             time=int(time.time() - start_time),
+                             step_loss=loss.item(),
+                             lr=lr_)
+                lutl.LOG2DICTXT(stats, CFG.checkpoint_dir + '/train-stats.txt')
+        train_epoch_loss = t_running_loss_ / len(trainloader)
+
+        # save checkpoint; wgt_suf flips between 0 and 1 so the previous file survives a mid-save crash
+        if (epoch+1) % CFG.ckpt_freq_epoch == 0:
+            wgt_suf = (wgt_suf + 1) % 2
+            state = dict(epoch=epoch, model=model.state_dict(),
+                         optimizer=optimizer.state_dict())
+            torch.save(state, CFG.gWeightPath + f'/checkpoint-{wgt_suf}.pth')
+
+        ## ---- Validation Routine ----
+        if (epoch+1) % CFG.valid_freq_epoch == 0:
+            model.eval()
+            v_running_loss_ = 0
+            with torch.no_grad():
+                for (y1, y2) in tqdm(validloader, total=len(validloader),
+                                     disable=CFG.disable_tqdm):
+                    y1 = y1.to(device, non_blocking=True)
+                    y2 = y2.to(device, non_blocking=True)
+                    loss = model.forward(y1, y2)
+                    v_running_loss_ += loss.item()
+            valid_epoch_loss = v_running_loss_ / len(validloader)
+            best_flag = False
+            if valid_epoch_loss < best_loss:
+                best_flag = True
+                best_loss = valid_epoch_loss
+
+            v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf,
+                           train_loss=train_epoch_loss,
+                           valid_loss=valid_epoch_loss)
+            lutl.LOG2DICTXT(v_stats, CFG.gLogPath + '/valid-stats.txt')
+
+
+if __name__ == '__main__':
+    simple_main()
\ No newline at end of file
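
For reference, a minimal launch sketch: the script is run from the repository root so that sys.path.append(os.getcwd()) can resolve the utilities, algorithms, and datacode packages, optionally pointing --load-json at a config file (the file name below is a placeholder, not a file added by this diff):

    python tasks/vic-train.py --load-json my_vicreg_config.json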