--- a
+++ b/train.py
@@ -0,0 +1,283 @@
+import argparse
+import os
+from glob import glob
+import pandas as pd
+import yaml
+from utils import str2bool, write_csv
+from collections import OrderedDict
+from sklearn.model_selection import train_test_split
+
+from trainer import trainer, validate
+from dataset import CustomDataset
+import torch
+from torch.utils.data import DataLoader
+import torch.optim as optim
+from torch.nn.modules.loss import CrossEntropyLoss
+from metrics import Dice, IOU, HD
+
+from networks.RotCAtt_TransUNet_plusplus.RotCAtt_TransUNet_plusplus import RotCAtt_TransUNet_plusplus
+from networks.RotCAtt_TransUNet_plusplus.config import get_config as rot_config
+
+
+
+def parse_args():
+    """Parse the command-line options that configure a training run.
+
+    Returns:
+        argparse.Namespace: one attribute per option declared below.
+    """
+    parser = argparse.ArgumentParser()
+
+    # Training pipeline
+    parser.add_argument('--name', default=None, help='model name')
+    # Converted with str2bool (like --val_mode): without a type, the string
+    # "False" given on the command line would be kept as a truthy str.
+    parser.add_argument('--pretrained', default=False, type=str2bool,
+                        help='pretrained or not (default: False)')
+    parser.add_argument('--epochs', default=600, type=int, metavar='N',
+                        help='number of epochs for training')
+    parser.add_argument('--batch_size', default=6, type=int, metavar='N',
+                        help='mini-batch size')
+    parser.add_argument('--seed', type=int, default=1234, help='random seed')
+    parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
+    parser.add_argument('--num_workers', default=3, type=int)
+    parser.add_argument('--val_mode', default=True, type=str2bool)
+
+    # Network
+    parser.add_argument('--network', default='RotCAtt_TransUNet_plusplus')
+    parser.add_argument('--input_channels', default=1, type=int, help='input channels')
+    parser.add_argument('--patch_size', default=16, type=int, help='input patch size')
+    parser.add_argument('--num_classes', default=12, type=int, help='number of classes')
+    parser.add_argument('--img_size', default=512, type=int, help='input image img_size')
+
+    # Dataset
+    parser.add_argument('--dataset', default='VHSCDD', help='dataset name')
+    parser.add_argument('--ext', default='.npy', help='file extension')
+    parser.add_argument('--range', default=None, type=int, help='dataset size')
+
+    # Criterion
+    parser.add_argument('--loss', default='Dice Iou Cross entropy')
+
+    # Optimizer
+    parser.add_argument('--optimizer', default='SGD', choices=['Adam', 'SGD'],
+                        help='optimizer: ' + ' | '.join(['Adam', 'SGD'])
+                        + ' (default: SGD)')
+    parser.add_argument('--base_lr', '--learning_rate', default=0.01, type=float,
+                        metavar='LR', help='initial learning rate')
+    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
+    parser.add_argument('--weight_decay', default=0.0001, type=float, help='weight decay')
+    parser.add_argument('--nesterov', default=False, type=str2bool, help='nesterov')
+
+    # Scheduler
+    parser.add_argument('--scheduler', default='CosineAnnealingLR',
+                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau',
+                                 'MultiStepLR', 'ConstantLR'])
+    parser.add_argument('--min_lr', default=1e-5, type=float,
+                        help='minimum learning rate')
+    parser.add_argument('--factor', default=0.1, type=float)
+    parser.add_argument('--patience', default=2, type=int)
+    parser.add_argument('--milestones', default='1,2', type=str)
+    parser.add_argument('--gamma', default=2/3, type=float)
+    parser.add_argument('--early_stopping', default=-1, type=int,
+                        metavar='N', help='early stopping (default: -1)')
+
+    return parser.parse_args()
+
+def output_config(config):
+    """Print each ``key: value`` pair of *config* between two dashed rules."""
+    rule = '-' * 20
+    print(rule, *(f'{key}: {config[key]}' for key in config), sep='\n')
+    print(rule)
+   
+
+def loading_2D_data(config):
+    """Build the training and validation ``DataLoader`` pair for ``config.dataset``.
+
+    Image and label files are paired by sorted path, optionally truncated to
+    the first ``config.range`` samples, then split 80/20 with a fixed seed.
+
+    Raises:
+        ValueError: if the image and label folders hold different file counts.
+    """
+    ext = getattr(config, 'ext', None) or '.npy'
+    # glob() order is arbitrary; sort both lists so image i and label i line up
+    # (assumes both folders use matching file names -- TODO confirm).
+    image_paths = sorted(glob(f"data/{config.dataset}/images/*{ext}"))
+    label_paths = sorted(glob(f"data/{config.dataset}/labels/*{ext}"))
+    if len(image_paths) != len(label_paths):
+        raise ValueError(f"{len(image_paths)} images vs {len(label_paths)} labels in data/{config.dataset}")
+    if config.range is not None:
+        image_paths, label_paths = image_paths[:config.range], label_paths[:config.range]
+
+    train_images, val_images, train_labels, val_labels = train_test_split(
+        image_paths, label_paths, test_size=0.2, random_state=41)
+    train_ds = CustomDataset(config.num_classes, train_images, train_labels, img_size=config.img_size)
+    val_ds = CustomDataset(config.num_classes, val_images, val_labels, img_size=config.img_size)
+
+    # Both loaders share the same settings (no shuffling, keep the last partial batch).
+    loader_kwargs = dict(batch_size=config.batch_size, shuffle=False,
+                         num_workers=config.num_workers, drop_last=False)
+    return DataLoader(train_ds, **loader_kwargs), DataLoader(val_ds, **loader_kwargs)
+
+    
+def load_pretrained_model(model_path):
+    """Return the pickled model stored at *model_path*; exit with status 1 if absent."""
+    if not os.path.exists(model_path):
+        print("No pretrained exists")
+        raise SystemExit(1)  # was exit(): status 0 hid the failure from the shell
+    # NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
+    return torch.load(model_path)
+        
+
+def load_network(config):
+    """Instantiate the network named by ``config.network`` and move it to CUDA.
+
+    ``img_size`` and ``num_classes`` are copied from *config* onto the network's
+    own config; an unknown network name ends the process with exit status 1.
+    """
+    if config.network != 'RotCAtt_TransUNet_plusplus':
+        print("Add the custom network to the training pipeline please")
+        exit(1)
+    net_cfg = rot_config()
+    net_cfg.img_size, net_cfg.num_classes = config.img_size, config.num_classes
+    return RotCAtt_TransUNet_plusplus(config=net_cfg).cuda()
+
+ 
+def rlog(value):
+    return round(value, ndigits=3)  # three decimals for the console summaries
+        
+def train(config):
+    """Run the whole training loop described by the parsed CLI *config*.
+
+    Builds (or reloads) the model, the data loaders, the optimizer and the
+    metric objects, then calls ``trainer`` -- and ``validate`` when
+    ``config.val_mode`` is set -- once per epoch.  After each epoch the
+    accumulated log is rewritten to ``outputs/<name>/epo_log.csv`` and the
+    model is saved to ``outputs/<name>/model.pth`` when every tracked IoU and
+    Dice score beats its previous best.
+
+    Args:
+        config: namespace produced by ``parse_args``.  It is modified in
+            place: ``name`` is replaced by the derived run name and, for the
+            ``VHSCDD`` dataset, ``dataset`` gets ``_<img_size>`` appended.
+
+    Raises:
+        ValueError: if ``config.optimizer`` is neither ``Adam`` nor ``SGD``.
+    """
+    config_dict = vars(config)
+    print(config.network)
+
+    # The run name also names the output directory.
+    config.name = f"{config.dataset}_{config.network}_bs{config.batch_size}_ps{config.patch_size}_epo{config.epochs}_hw{config.img_size}"
+    out_dir = f"outputs/{config.name}"
+
+    # Model
+    print(f"=> Initialize model: {config.network}")
+    if not config.pretrained:
+        model = load_network(config)
+        output_config(config_dict)
+        print(f"=> Initialize output: {config.name}")
+        if not os.path.exists(out_dir):
+            os.makedirs(out_dir)
+            with open(f"{out_dir}/config.yml", "w") as f:
+                yaml.dump(config_dict, f)
+    else:
+        model = load_pretrained_model(f'{out_dir}/model.pth')
+
+    # Data loading
+    if config.dataset == 'VHSCDD':
+        config.dataset += f'_{config.img_size}'
+    train_loader, val_loader = loading_2D_data(config)
+
+    # Logging: one CSV column per (split, metric) pair, e.g. 'Train dice score'.
+    metric_keys = ('loss', 'ce_loss', 'dice_score', 'dice_loss',
+                   'iou_score', 'iou_loss', 'hausdorff')
+
+    def column(split, key):
+        return f"{split} {key.replace('_', ' ')}"
+
+    def summary(split, metrics):
+        text = ' - '.join(f"{column(split, key)}: {rlog(metrics[key])}" for key in metric_keys)
+        # The console line has always printed 'iou Score' with a capital S.
+        return text.replace('iou score', 'iou Score')
+
+    log = OrderedDict([('epoch', []), ('lr', [])])
+    for split in ('Train', 'Val'):
+        for key in metric_keys:
+            log[column(split, key)] = []
+
+    if config.pretrained:
+        # Resume: keep appending to the epoch log of the previous run.
+        pre_log = pd.read_csv(f'{out_dir}/epo_log.csv')
+        print(pre_log)
+        log = OrderedDict((name, pre_log[name].tolist()) for name in pre_log.columns)
+
+    # Optimizer
+    params = filter(lambda p: p.requires_grad, model.parameters())
+    if config.optimizer == 'Adam':
+        optimizer = optim.Adam(params, lr=config.base_lr, weight_decay=config.weight_decay)
+    elif config.optimizer == 'SGD':
+        optimizer = optim.SGD(params, lr=config.base_lr, momentum=config.momentum,
+                              nesterov=config.nesterov, weight_decay=config.weight_decay)
+    else:
+        raise ValueError(f"Unsupported optimizer: {config.optimizer!r}")
+
+    # Criterion
+    ce = CrossEntropyLoss()
+    dice = Dice(config.num_classes)
+    iou = IOU(config.num_classes)
+    hd = HD()
+
+    # Best scores seen so far
+    best_train_iou = 0
+    best_train_dice_score = 0
+    best_val_iou = 0
+    best_val_dice_score = 0
+
+    # Per-iteration CSV, created once (write_csv presumably writes the header row -- confirm).
+    fieldnames = ['CE Loss', 'Dice Score', 'Dice Loss', 'IoU Score', 'IoU Loss', 'HausDorff Distance', 'Total Loss']
+    iter_log_file = f'{out_dir}/iter_log.csv'
+    if not os.path.exists(iter_log_file):
+        write_csv(iter_log_file, fieldnames)
+
+    # Training loop
+    for epoch in range(config.epochs):
+        print(f"Epoch: {epoch+1}/{config.epochs}")
+        train_log = trainer(config, train_loader, optimizer, model, ce, dice, iou, hd)
+        # val_log stays None when validation is off so later code can test for it.
+        val_log = validate(config, val_loader, model, ce, dice, iou, hd) if config.val_mode else None
+
+        print(summary('Train', train_log))
+        if val_log is not None:
+            print(summary('Val', val_log))
+
+        log['epoch'].append(epoch)
+        # NOTE: this records the configured base LR, not the optimizer's current LR.
+        log['lr'].append(config.base_lr)
+        for key in metric_keys:
+            log[column('Train', key)].append(train_log[key])
+            log[column('Val', key)].append(None if val_log is None else val_log[key])
+        pd.DataFrame(log).to_csv(f'{out_dir}/epo_log.csv', index=False)
+
+        # Save best model: every tracked score must improve.  Validation scores
+        # only take part when validation ran; reading val_log unconditionally
+        # here used to crash with val_mode=False.
+        improved = (train_log['iou_score'] > best_train_iou
+                    and train_log['dice_score'] > best_train_dice_score)
+        if val_log is not None:
+            improved = (improved and val_log['iou_score'] > best_val_iou
+                        and val_log['dice_score'] > best_val_dice_score)
+
+        if improved:
+            best_train_iou = train_log['iou_score']
+            best_train_dice_score = train_log['dice_score']
+            if val_log is not None:
+                best_val_iou = val_log['iou_score']
+                best_val_dice_score = val_log['dice_score']
+            torch.save(model, f"{out_dir}/model.pth")
+
+        print(f'BEST TRAIN DICE: {best_train_dice_score} - BEST TRAIN IOU: {best_train_iou} - BEST VAL DICE SCORE: {best_val_dice_score} - BEST VAL IOU: {best_val_iou}')
+            
+if __name__ == "__main__":
+    # Script entry point: parse the CLI options, then start training.
+    train(parse_args())
\ No newline at end of file