--- /dev/null
+++ b/sybil/utils/loading.py
@@ -0,0 +1,234 @@
+from argparse import Namespace
+import hashlib
+import collections.abc as container_abcs
+import re
+from typing import Literal
+import torch
+from torch.utils import data
+
+from sybil.utils.sampler import DistributedWeightedSampler
+from sybil.augmentations import get_augmentations
+from sybil.loaders.image_loaders import OpenCVLoader, DicomLoader
+
+string_classes = (str, bytes)
+int_classes = int
+np_str_obj_array_pattern = re.compile(r"[SaUO]")
+
+default_collate_err_msg_format = (
+    "default_collate: batch must contain tensors, numpy arrays, numbers, "
+    "dicts or lists; found {}"
+)
+
+
+def default_collate(batch):
+    r"""Puts each data field into a tensor with outer dimension batch size"""
+
+    elem = batch[0]
+    elem_type = type(elem)
+    if isinstance(elem, torch.Tensor):
+        out = None
+        if torch.utils.data.get_worker_info() is not None:
+            # If we're in a background process, concatenate directly into a
+            # shared memory tensor to avoid an extra copy
+            numel = sum([x.numel() for x in batch])
+            storage = elem.storage()._new_shared(numel)
+            out = elem.new(storage)
+        return torch.stack(batch, 0, out=out)
+    elif (
+        elem_type.__module__ == "numpy"
+        and elem_type.__name__ != "str_"
+        and elem_type.__name__ != "string_"
+    ):
+        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
+            # array of string classes and object
+            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+            return default_collate([torch.as_tensor(b) for b in batch])
+        elif elem.shape == ():  # scalars
+            return torch.as_tensor(batch)
+    elif isinstance(elem, float):
+        return torch.tensor(batch, dtype=torch.float64)
+    elif isinstance(elem, int_classes):
+        return torch.tensor(batch)
+    elif isinstance(elem, string_classes):
+        return batch
+    elif isinstance(elem, container_abcs.Mapping):
+        return {key: default_collate([d[key] for d in batch]) for key in elem}
+    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
+        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
+    elif isinstance(elem, container_abcs.Sequence):
+        # check to make sure that the elements in batch have consistent size
+        it = iter(batch)
+        elem_size = len(next(it))
+        if not all(len(elem) == elem_size for elem in it):
+            raise RuntimeError("each element in list of batch should be of equal size")
+        transposed = zip(*batch)
+        return [default_collate(samples) for samples in transposed]
+
+    raise TypeError(default_collate_err_msg_format.format(elem_type))
+
+
+def ignore_None_collate(batch):
+    """
+    Wrapper around default_collate that batches only the non-None values.
+    Useful when dataset.__getitem__ can return None (e.g., after a loading
+    exception) and such samples should be excluded from the batch.
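+
+    Example (illustrative)::
+
+        batch = [torch.tensor([1.0]), None, torch.tensor([2.0])]
+        ignore_None_collate(batch)  # -> tensor([[1.], [2.]]); None is dropped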
+    """
+    batch = [x for x in batch if x is not None]
+    if len(batch) == 0:
+        return None
+    return default_collate(batch)
+
+
+def get_train_dataset_loader(args, train_data):
+    """
+    Given the arg configuration, return the appropriate torch DataLoader
+    for train_data.
+
+    Returns
+    -------
+    train_data_loader
+        iterator that returns batches
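+
+    Example (sketch; assumes ``args`` carries the fields used below)::
+
+        loader = get_train_dataset_loader(args, train_data)
+        for batch in loader:
+            if batch is None:  # an all-None batch collates to None
+                continue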
+    """
+    if args.accelerator == "ddp":
+        sampler = DistributedWeightedSampler(
+            train_data,
+            weights=train_data.weights,
+            replacement=True,
+            rank=args.global_rank,
+            num_replicas=args.world_size,
+        )
+    else:
+        sampler = data.sampler.WeightedRandomSampler(
+            weights=train_data.weights, num_samples=len(train_data), replacement=True
+        )
+
+    train_data_loader = data.DataLoader(
+        train_data,
+        num_workers=args.num_workers,
+        sampler=sampler,
+        pin_memory=True,
+        batch_size=args.batch_size,
+        collate_fn=ignore_None_collate,
+    )
+
+    return train_data_loader
+
+
+def get_eval_dataset_loader(args, eval_data, shuffle):
+    """
+    Given the arg configuration, return the appropriate torch DataLoader for
+    eval_data: distributed under DDP, otherwise random (if ``shuffle``) or
+    sequential.
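+
+    Example (sketch; assumes ``args`` carries the fields used below)::
+
+        dev_loader = get_eval_dataset_loader(args, dev_data, shuffle=False)
+    """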
+    if args.accelerator == "ddp":
+        sampler = torch.utils.data.distributed.DistributedSampler(
+            eval_data,
+            shuffle=shuffle,
+            rank=args.global_rank,
+            num_replicas=args.world_size,
+        )
+    else:
+        sampler = (
+            torch.utils.data.sampler.RandomSampler(eval_data)
+            if shuffle
+            else torch.utils.data.sampler.SequentialSampler(eval_data)
+        )
+    data_loader = torch.utils.data.DataLoader(
+        eval_data,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        collate_fn=ignore_None_collate,
+        pin_memory=True,
+        drop_last=False,
+        sampler=sampler,
+    )
+
+    return data_loader
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs an all_gather operation on the provided tensor across all ranks.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
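+
+    Example (sketch; assumes torch.distributed is initialized and every rank
+    passes a same-shaped tensor)::
+
+        gathered = concat_all_gather(local_features)
+        # local_features: (B, ...) per rank -> gathered: (world_size * B, ...)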
+    """
+
+    tensors_gather = [
+        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+    ]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+    output = torch.cat(tensors_gather, dim=0)
+    return output
+
+
+def get_sample_loader(
+    split_group: Literal["train", "dev", "test"],
+    args: Namespace,
+    apply_augmentations=True,
+):
+    """[summary]
+
+    Parameters
+    ----------
+    ``split_group`` : str
+        dataset split according to which the augmentation is selected (choices are ['train', 'dev', 'test'])
+    ``args`` : Namespace
+        global args
+    ``apply_augmentations`` : bool, optional (default=True)
+        whether to apply the augmentations returned by get_augmentations
+
+    Returns
+    -------
+    abstract_loader
+        sample loader (DicomLoader for dicoms, OpenCVLoader for pngs); see sybil.loaders.image_loaders
+
+    Raises
+    ------
+    NotImplementedError
+        img_file_type must be one of "dicom" or "png"
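+
+    Example (sketch; assumes ``args`` carries the fields used in this module)::
+
+        loader = get_sample_loader("train", args)
+        # -> DicomLoader when args.img_file_type == "dicom"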
+    """
+    augmentations = get_augmentations(split_group, args)
+    if args.img_file_type == "dicom":
+        return DicomLoader(args.cache_path, augmentations, args, apply_augmentations)
+    elif args.img_file_type == "png":
+        return OpenCVLoader(args.cache_path, augmentations, args, apply_augmentations)
+    else:
+        raise NotImplementedError(
+            f"img_file_type must be 'dicom' or 'png', got '{args.img_file_type}'"
+        )