Switch to unified view

a b/merlin/data/dataloaders.py
1
import torch
2
import monai
3
from copy import deepcopy
4
import shutil
5
import tempfile
6
from pathlib import Path
7
from typing import List
8
from monai.utils import look_up_option
9
from monai.data.utils import SUPPORTED_PICKLE_MOD
10
11
from merlin.data.monai_transforms import ImageTransforms
12
13
14
class CTPersistentDataset(monai.data.PersistentDataset):
15
    def __init__(self, data, transform, cache_dir=None):
16
        super().__init__(data=data, transform=transform, cache_dir=cache_dir)
17
18
        print(f"Size of dataset: {self.__len__()}\n")
19
20
    def _cachecheck(self, item_transformed):
21
        hashfile = None
22
        _item_transformed = deepcopy(item_transformed)
23
        image_data = {
24
            "image": item_transformed.get("image")
25
        }  # Assuming the image data is under the 'image' key
26
27
        if self.cache_dir is not None and image_data is not None:
28
            data_item_md5 = self.hash_func(image_data).decode(
29
                "utf-8"
30
            )  # Hash based on image data
31
            hashfile = self.cache_dir / f"{data_item_md5}.pt"
32
33
        if hashfile is not None and hashfile.is_file():
34
            cached_image = torch.load(hashfile)
35
            _item_transformed["image"] = cached_image
36
            return _item_transformed
37
38
        _image_transformed = self._pre_transform(image_data)["image"]
39
        _item_transformed["image"] = _image_transformed
40
        if hashfile is None:
41
            return _item_transformed
42
        try:
43
            # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
44
            #       to make the cache more robust to manual killing of parent process
45
            #       which may leave partially written cache files in an incomplete state
46
            with tempfile.TemporaryDirectory() as tmpdirname:
47
                temp_hash_file = Path(tmpdirname) / hashfile.name
48
                torch.save(
49
                    obj=_image_transformed,
50
                    f=temp_hash_file,
51
                    pickle_module=look_up_option(
52
                        self.pickle_module, SUPPORTED_PICKLE_MOD
53
                    ),
54
                    pickle_protocol=self.pickle_protocol,
55
                )
56
                if temp_hash_file.is_file() and not hashfile.is_file():
57
                    # On Unix, if target exists and is a file, it will be replaced silently if the user has permission.
58
                    # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.
59
                    try:
60
                        shutil.move(str(temp_hash_file), hashfile)
61
                    except FileExistsError:
62
                        pass
63
        except PermissionError:  # project-monai/monai issue #3613
64
            pass
65
        return _item_transformed
66
67
    def _transform(self, index: int):
68
        pre_random_item = self._cachecheck(self.data[index])
69
        return self._post_transform(pre_random_item)
70
71
72
class DataLoader(monai.data.DataLoader):
73
    def __init__(
74
        self,
75
        datalist: List[dict],
76
        cache_dir: str,
77
        batchsize: int,
78
        shuffle: bool = True,
79
        num_workers: int = 0,
80
    ):
81
        self.datalist = datalist
82
        self.cache_dir = cache_dir
83
        self.batchsize = batchsize
84
        self.dataset = CTPersistentDataset(
85
            data=datalist,
86
            transform=ImageTransforms,
87
            cache_dir=cache_dir,
88
        )
89
        super().__init__(
90
            self.dataset,
91
            batch_size=batchsize,
92
            shuffle=shuffle,
93
            num_workers=num_workers,
94
        )