--- a
+++ b/src/experiment.py
@@ -0,0 +1,132 @@
+from collections import OrderedDict
+import random
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import ConcatDataset
+from catalyst.dl.experiment import ConfigExperiment
+from dataset import *
+from augmentation import train_aug, valid_aug
+
+# Fixed seed for the Python, NumPy and torch RNGs (see Experiment below).
+_SEED = 2411
+
+
+class Experiment(ConfigExperiment):
+    """Catalyst config experiment with a frozen-backbone "warmup" stage."""
+
+    def _postprocess_model_for_stage(self, stage: str, model: nn.Module):
+        """Seed the RNGs and choose which parameters train in ``stage``.
+
+        For ``stage == "warmup"`` the model's own ``freeze`` method is used
+        if it has one; otherwise every parameter is frozen except those of
+        ``model.get_classifier()``. Any other stage uses ``unfreeze`` if
+        available, else makes every parameter trainable.
+
+        Returns the model with its ``torch.nn.DataParallel`` wrapper, if
+        any, removed.
+        """
+        # NOTE: this silences *all* warnings process-wide, not only here.
+        warnings.filterwarnings("ignore")
+
+        random.seed(_SEED)
+        np.random.seed(_SEED)
+        torch.manual_seed(_SEED)
+
+        model_ = model
+        if isinstance(model, torch.nn.DataParallel):
+            model_ = model_.module
+
+        if stage == "warmup":
+            if hasattr(model_, 'freeze'):
+                # NOTE(review): the model is passed to its own method;
+                # presumably ``freeze`` takes the model explicitly — confirm.
+                model_.freeze(model_)
+                print("Freeze backbone model using freeze method !!!")
+            else:
+                for param in model_.parameters():
+                    param.requires_grad = False
+                # Only the classifier's parameters stay trainable.
+                for param in model_.get_classifier().parameters():
+                    param.requires_grad = True
+                print("Freeze backbone model !!!")
+        else:
+            if hasattr(model_, 'unfreeze'):
+                model_.unfreeze(model_)
+                print("Unfreeze backbone model using unfreeze method !!!")
+            else:
+                for param in model_.parameters():
+                    param.requires_grad = True
+                print("Unfreeze backbone model !!!")
+
+        # Optional SyncBN conversion (requires apex):
+        #   model_ = apex.parallel.convert_syncbn_model(model_)
+        return model_
+
+    @staticmethod
+    def _build_dataset(dataset_type, csv_file, mode, root, with_any,
+                       transform, image_type):
+        """Instantiate the ``dataset_type`` dataset for one split.
+
+        ``mode`` is 'train' or 'valid'. Each dataset class receives a
+        different set of keyword arguments, as spelled out below.
+
+        Raises ValueError for an unknown ``dataset_type``.
+        """
+        common = dict(csv_file=csv_file, root=root, with_any=with_any,
+                      transform=transform)
+        if dataset_type == 'RSNADataset':
+            return RSNADataset(**common, mode=mode, image_type=image_type)
+        if dataset_type == 'RSNAMultiWindowsDataset':
+            # NOTE(review): never receives ``mode``, even for validation,
+            # unlike the other classes — confirm this is intended.
+            return RSNAMultiWindowsDataset(**common)
+        if dataset_type == 'RSNADicomDataset':
+            dataset_cls = RSNADicomDataset
+        elif dataset_type == 'RSNARandomWindowDataset':
+            dataset_cls = RSNARandomWindowDataset
+        else:
+            raise ValueError("No Dataset: {}".format(dataset_type))
+        # These two get an explicit ``mode`` only for validation; for
+        # training ``mode`` is not passed at all.
+        if mode == 'valid':
+            return dataset_cls(**common, mode=mode)
+        return dataset_cls(**common)
+
+    def get_datasets(self, stage: str, **kwargs):
+        """Build the datasets for ``stage`` from the config kwargs.
+
+        Keys read from ``kwargs`` (default in parentheses): ``image_size``
+        ([224, 224]), ``train_csv`` (None), ``valid_csv`` (None),
+        ``with_any`` (True), ``dataset_type`` ('RSNADataset'),
+        ``image_type`` ('jpg'), ``root`` (None). ``normalization`` is
+        not used.
+
+        Returns an OrderedDict with a "train" entry (transform from
+        ``train_aug``) when ``train_csv`` is set and a "valid" entry
+        (transform from ``valid_aug``) when ``valid_csv`` is set.
+
+        Raises ValueError for an unknown ``dataset_type``.
+        """
+        image_size = kwargs.get("image_size", [224, 224])
+        train_csv = kwargs.get('train_csv', None)
+        valid_csv = kwargs.get('valid_csv', None)
+        with_any = kwargs.get('with_any', True)
+        dataset_type = kwargs.get('dataset_type', 'RSNADataset')
+        image_type = kwargs.get('image_type', 'jpg')
+        root = kwargs.get('root', None)
+
+        print(f"Image Size: {image_size}")
+
+        datasets = OrderedDict()
+        if train_csv:
+            datasets["train"] = self._build_dataset(
+                dataset_type, train_csv, 'train', root, with_any,
+                train_aug(image_size), image_type)
+        if valid_csv:
+            datasets["valid"] = self._build_dataset(
+                dataset_type, valid_csv, 'valid', root, with_any,
+                valid_aug(image_size), image_type)
+        return datasets