Diff of /tasks/cls-train.py [000000] .. [a18f15]

--- a
+++ b/tasks/cls-train.py
@@ -0,0 +1,373 @@
+""" Classifier Network trainig
+"""
+
+import argparse
+import json
+import os
+import sys
+import time
+from tqdm.autonotebook import tqdm
+
+import torch
+from torch import nn, optim
+import torchinfo
+
+import numpy as np
+from sklearn.model_selection import train_test_split as sk_train_test_split
+
+sys.path.append(os.getcwd())
+import utilities.runUtils as rutl
+import utilities.logUtils as lutl
+from utilities.metricUtils import MultiClassMetrics
+from algorithms.classifiers import ClassifierNet
+from datacode.ultrasound_data import ClassifyDataFromCSV, get_class_weights
+from datacode.augmentations import ClassifierTransform
+
+print(f"Pytorch version: {torch.__version__}")
+print(f"cuda version: {torch.version.cuda}")
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print("Device Used:", device)
+
+###============================= Configure and Setup ===========================
+
+CFG = rutl.ObjDict(
+data_folder  = "/home/joseph.benjamin/WERK/fetal-ultrasound/data/Fetal-UltraSound/US-Planes-Heart-Views-V3",
+balance_data = False, # class balancing while loading in the dataloader; removed
+seed = 1792,  #previously 73
+
+epochs        = 100,
+image_size    = 256,
+batch_size    = 128,
+workers       = 16,
+learning_rate = 1e-3,
+weight_decay  = 1e-6,
+
+featx_arch     = "resnet50",
+featx_pretrain = "IMAGENET-1K",  # "IMAGENET-1K" or None
+featx_freeze   = False,
+featx_bnorm    = False,
+featx_dropout  = 0.5,
+clsfy_layers   = [5], # first MLP input size will be set w.r.t. the FeatureExtractor
+clsfy_dropout  = 0.0,
+
+checkpoint_dir   = "hypotheses/#dummy/Classify/trail-002",
+disable_tqdm     = False, # set True to disable progress bars
+restart_training = True
+)
+
+### ----------------------------------------------------------------------------
+# CLI TAKES PRECEDENCE OVER JSON CONFIG
+# e.g. the CLI overwrites the value set for featx-pretrain in the JSON file;
+# without CLI flags, the default values from the dict above are used
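+#
+# Example invocation (hypothetical config path, shown for illustration only):
+#   python tasks/cls-train.py --load-json configs/cls-exp.json \
+#       --featx-pretrain IMAGENET-1K --checkpoint-dir hypotheses/run-001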
+
+def str2bool(s):
+    # argparse's type=bool treats any non-empty string (even "False") as True,
+    # so boolean flags are parsed explicitly
+    return str(s).lower() in ('true', '1', 'yes')
+
+parser = argparse.ArgumentParser(description='Classification task')
+parser.add_argument('--load-json', type=str, metavar='JSON',
+    help='Load settings from file in json format. Command line options override values in file.')
+
+parser.add_argument('--seed', type=int, metavar='INT',
+    help='random seed for reproducibility')
+
+parser.add_argument('--featx-freeze', type=str2bool, metavar='BOOL',
+    help='freeze the pretrained feature extractor weights')
+
+parser.add_argument('--featx-bnorm', type=str2bool, metavar='BOOL',
+    help='add batchnorm between feature extractor and classifier')
+
+parser.add_argument('--featx-pretrain', type=str, metavar='PATH',
+    help='where to load the pretrained weights from (file path or "IMAGENET-1K")')
+
+parser.add_argument('--checkpoint-dir', type=str, metavar='PATH',
+    help='directory where checkpoints and logs are written')
+
+
+args = parser.parse_args()
+
+if args.load_json:
+    with open(args.load_json, 'rt') as f:
+        CFG.__dict__.update(json.load(f))
+
+for arg in vars(args):
+    att = getattr(args, arg)
+    # None means the option was not passed on the CLI; explicit falsy values
+    # (e.g. --featx-freeze False) must still override the config
+    if att is not None: CFG.__dict__[arg] = att
+
+### ----------------------------------------------------------------------------
+CFG.gLogPath = CFG.checkpoint_dir
+CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'
+
+### ============================================================================
+
+def getDataLoaders(data_percent=None):
+    ## Augmentations
+    train_transforms = ClassifierTransform(image_size=CFG.image_size, mode="train")
+    valid_transforms = ClassifierTransform(image_size=CFG.image_size, mode="infer")
+
+    ## Dataset Class
+    traindataset = ClassifyDataFromCSV(CFG.data_folder,
+                                       CFG.data_folder+"/trainV3.csv",
+                                       transform = train_transforms,)
+    validdataset = ClassifyDataFromCSV(CFG.data_folder,
+                                       CFG.data_folder+"/validV3.csv",
+                                       transform = valid_transforms,)
+    class_weights, _ = get_class_weights(traindataset.targets, nclasses=5)
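+    # class weights counter label imbalance; they are passed to CrossEntropyLoss in getLossFunc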
+
+    ### Choose P% of data from train data
+    if data_percent and (data_percent < 100):
+        _idx, used_idx = sk_train_test_split( np.arange(len(traindataset)),
+                                test_size=data_percent/100, random_state=CFG.seed,
+                                stratify=traindataset.targets)
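+        # sk_train_test_split returns (rest, sample); the stratified "test" share
+        # of size data_percent% is the subset actually kept for training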
+        traindataset = torch.utils.data.Subset(traindataset, sorted(used_idx))
+        lutl.LOG2CSV(sorted(used_idx), CFG.gLogPath +'/train_indices_used.csv')
+
+    torch.manual_seed(CFG.seed)
+    ## Loaders Class
+    trainloader  = torch.utils.data.DataLoader( traindataset, shuffle=True,
+                        batch_size=CFG.batch_size, num_workers=CFG.workers,
+                        pin_memory=True)
+
+    validloader  = torch.utils.data.DataLoader( validdataset, shuffle=False,
+                        batch_size=CFG.batch_size, num_workers=CFG.workers,
+                        pin_memory=True)
+
+    lutl.LOG2DICTXT({"Train->":len(traindataset),
+                    "class-weights":str(class_weights),
+                    "TransformsClass": str(train_transforms.get_composition()),
+                    },CFG.gLogPath +'/misc.txt')
+    lutl.LOG2DICTXT({"Valid->":len(validdataset),
+                    "TransformsClass": str(valid_transforms.get_composition()),
+                    },CFG.gLogPath +'/misc.txt')
+
+    return trainloader, validloader, class_weights
+
+
+def getModelnOptimizer():
+
+    ## pretrain setting: a filesystem path means a manual checkpoint to load,
+    ## anything else (e.g. "IMAGENET-1K") is passed through as a torch pretrain flag
+    m_state = None; torch_pretrain_flag = None
+    if CFG.featx_pretrain and os.path.isfile(CFG.featx_pretrain):
+        m_state = torch.load(CFG.featx_pretrain, map_location='cpu')
+    else:
+        torch_pretrain_flag = CFG.featx_pretrain
+
+    model = ClassifierNet(arch=CFG.featx_arch,
+                    fc_layer_sizes=CFG.clsfy_layers,
+                    feature_freeze=CFG.featx_freeze,
+                    feature_dropout=CFG.featx_dropout,
+                    feature_bnorm=CFG.featx_bnorm,
+                    classifier_dropout=CFG.clsfy_dropout,
+                    torch_pretrain=torch_pretrain_flag )
+
+    ## load from checkpoints
+    if m_state:
+        m_state = m_state["model"]
+        ret_msg = model.load_state_dict(m_state, strict=False)
+        lutl.LOG2TXT(f"Manual Pretrain Loaded...{CFG.featx_pretrain},{str(ret_msg)}",
+                     CFG.gLogPath +'/misc.txt')
+
+    model_info = torchinfo.summary(model, (1, 3, CFG.image_size, CFG.image_size),
+                                verbose=0)
+    lutl.LOG2TXT(model_info, CFG.gLogPath +'/misc.txt', console= False)
+
+    ##--------------
+
+    optimizer = optim.AdamW(model.parameters(), lr=CFG.learning_rate,
+                        weight_decay=CFG.weight_decay)
+    scheduler = None  # no LR scheduler used
+
+    return model.to(device), optimizer, scheduler
+
+
+def getLossFunc(class_weights):
+    lossfn = nn.CrossEntropyLoss(weight=torch.tensor(class_weights,
+                                        dtype=torch.float32).to(device) )
+    return lossfn
+
+
+def simple_main(data_percent=None):
+
+    ### SETUP
+    rutl.START_SEED(CFG.seed)
+    gpu = 0
+    if torch.cuda.is_available():
+        torch.cuda.set_device(gpu)
+    torch.backends.cudnn.benchmark = True
+
+    ## paths and logs setup
+    if data_percent: CFG.gLogPath = CFG.checkpoint_dir+f"/{data_percent}_percent/"
+    CFG.gWeightPath = CFG.gLogPath+"/weights/"
+
+    if os.path.exists(CFG.gLogPath) and (not CFG.restart_training):
+        raise Exception("Checkpoint folder already exists and restart_training is not enabled; something's wrong!",
+                        CFG.checkpoint_dir)
+    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)
+
+    with open(CFG.gLogPath+"/exp_cfg.json", 'w') as f:  # 'w' keeps the file valid JSON across restarts
+        json.dump(vars(CFG), f, indent=4)
+
+
+    ### DATA ACCESS
+    trainloader, validloader, class_weights  = getDataLoaders(data_percent)
+
+    ### MODEL, OPTIM
+    model, optimizer, scheduler = getModelnOptimizer()
+    lossfn = getLossFunc(class_weights)
+
+
+    ## Automatically resume from checkpoint if it exists and enabled
+    if os.path.exists(CFG.gWeightPath +'/checkpoint.pth') and CFG.restart_training:
+        ckpt = torch.load(CFG.gWeightPath  +'/checkpoint.pth',
+                            map_location='cpu')
+        start_epoch = ckpt['epoch']
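+        # 'epoch' was saved as epoch+1 below, so training resumes at the first unfinished epoch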
+        model.load_state_dict(ckpt['model'])
+        optimizer.load_state_dict(ckpt['optimizer'])
+        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.gLogPath}",  CFG.gLogPath +'/misc.txt')
+    else:
+        start_epoch = 0
+
+    ### MODEL TRAINING
+    start_time = time.time()
+    best_acc = 0 ; best_loss = float('inf')
+    trainMetric = MultiClassMetrics(CFG.gLogPath)
+    validMetric = MultiClassMetrics(CFG.gLogPath)
+
+    for epoch in range(start_epoch, CFG.epochs):
+
+        ## ---- Training Routine ----
+        model.train()
+        for img, tgt in tqdm(trainloader, disable=CFG.disable_tqdm):
+            img = img.to(device, non_blocking=True)
+            tgt = tgt.to(device, non_blocking=True)
+            optimizer.zero_grad()
+            pred = model(img)
+            loss = lossfn(pred, tgt)
+            loss.backward()
+            # nn.utils.clip_grad_norm_(model.parameters(),
+            #                          max_norm=2.0, norm_type=2)
+            optimizer.step()
+            trainMetric.add_entry(torch.argmax(pred, dim=1), tgt, loss)
+        if scheduler: scheduler.step()
+
+        ## save checkpoint states
+        state = dict(epoch=epoch + 1, model=model.state_dict(),
+                        optimizer=optimizer.state_dict())
+        torch.save(state, CFG.gWeightPath +'/checkpoint.pth')
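+        # checkpoint.pth is overwritten every epoch; bestmodel.pth below keeps the best-scoring weights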
+
+
+        ## ---- Validation Routine ----
+        model.eval()
+        with torch.no_grad():
+            for img, tgt in tqdm(validloader, disable=CFG.disable_tqdm):
+                img = img.to(device, non_blocking=True)
+                tgt = tgt.to(device, non_blocking=True)
+                pred = model(img)
+                loss = lossfn(pred, tgt)
+                validMetric.add_entry(torch.argmax(pred, dim=1), tgt, loss)
+
+        ## Log Metrics (balanced accuracy and F1 included)
+        stats = dict(
+                epoch=epoch, time=int(time.time() - start_time),
+                trainloss = trainMetric.get_loss(),
+                trainacc  = trainMetric.get_balanced_accuracy(),
+                trainF1   = trainMetric.get_f1score(),
+                validloss = validMetric.get_loss(),
+                validacc  = validMetric.get_balanced_accuracy(),
+                validF1   = validMetric.get_f1score(),
+                )
+        lutl.LOG2DICTXT(stats, CFG.gLogPath+'/train-stats.txt')
+
+
+        ## save best model
+        best_flag = False
+        if stats['validacc'] > best_acc:
+            torch.save(model.state_dict(), CFG.gWeightPath +'/bestmodel.pth')
+            best_acc = stats['validacc']
+            best_loss = stats['validloss']
+            best_flag = True
+
+        ## Log detailed validation
+        detail_stat = dict(
+                epoch=epoch, time=int(time.time() - start_time),
+                best = best_flag,
+                validf1scr  = validMetric.get_f1score(),
+                validbalacc = validMetric.get_balanced_accuracy(),
+                validacc    = validMetric.get_accuracy(),
+                validreport = validMetric.get_class_report(),
+                validconfus = validMetric.get_confusion_matrix().tolist(),
+            )
+        lutl.LOG2DICTXT(detail_stat, CFG.gLogPath+'/validation-details.txt', console=False)
+
+        trainMetric.reset()
+        validMetric.reset(best_flag)
+
+    return CFG.gLogPath
+
+
+
+def simple_test(saved_logpath):
+
+    ### SETUP
+    rutl.START_SEED(CFG.seed)
+    gpu = 0
+    if torch.cuda.is_available():
+        torch.cuda.set_device(gpu)
+    torch.backends.cudnn.benchmark = True
+
+    ### DATA ACCESS
+    test_transforms = ClassifierTransform(image_size=CFG.image_size,
+                                          mode="infer")
+    testdataset = ClassifyDataFromCSV(  CFG.data_folder,
+                                        CFG.data_folder+"/testV3.csv",
+                                        transform = test_transforms,)
+    testloader  = torch.utils.data.DataLoader( testdataset,
+                                        shuffle=False,
+                                        batch_size=CFG.batch_size,
+                                        num_workers=CFG.workers,
+                                        pin_memory=True)
+    lutl.LOG2DICTXT({"TEST->":len(testdataset),
+                     "TransformsClass": str(test_transforms.get_composition()),
+                    },saved_logpath +'/test-results.txt')
+
+    ### MODEL
+    model = ClassifierNet(arch=CFG.featx_arch,
+                    fc_layer_sizes=CFG.clsfy_layers,
+                    feature_freeze=CFG.featx_freeze,
+                    feature_dropout=CFG.featx_dropout,
+                    feature_bnorm=CFG.featx_bnorm,
+                    classifier_dropout=CFG.clsfy_dropout)
+    model = model.to(device)
+    model.load_state_dict(torch.load(saved_logpath+"/weights/bestmodel.pth",
+                                     map_location=device))
+
+
+    ### MODEL TESTING
+    testMetric = MultiClassMetrics(saved_logpath)
+    model.eval()
+
+    start_time = time.time()
+    with torch.no_grad():
+        for img, tgt in tqdm(testloader, disable=CFG.disable_tqdm):
+            img = img.to(device, non_blocking=True)
+            tgt = tgt.to(device, non_blocking=True)
+            pred = model(img)
+            testMetric.add_entry(torch.argmax(pred, dim=1), tgt)
+
+        ## Log detailed test results
+        detail_stat = dict(
+                timetaken   = int(time.time() - start_time),
+                testf1scr  = testMetric.get_f1score(),
+                testbalacc = testMetric.get_balanced_accuracy(),
+                testacc    = testMetric.get_accuracy(),
+                testreport = testMetric.get_class_report(),
+                testconfus = testMetric.get_confusion_matrix(
+                                        save_png= True, title="test").tolist(),
+            )
+        lutl.LOG2DICTXT(detail_stat, saved_logpath+'/test-results.txt',
+                        console=True)
+
+        testMetric._write_predictions(title="test")
+
+
+
+if __name__ == '__main__':
+
+    # logpth = simple_main()
+    # simple_test(logpth)
+
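+    # train and then test with decreasing percentages of the training data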
+    for p in [100, 50, 25, 10, 5, 1]:
+        logpth = simple_main(data_percent=p)
+        simple_test(logpth)
\ No newline at end of file