|
a |
|
b/merlin/data/dataloaders.py |
|
|
1 |
import torch |
|
|
2 |
import monai |
|
|
3 |
from copy import deepcopy |
|
|
4 |
import shutil |
|
|
5 |
import tempfile |
|
|
6 |
from pathlib import Path |
|
|
7 |
from typing import List |
|
|
8 |
from monai.utils import look_up_option |
|
|
9 |
from monai.data.utils import SUPPORTED_PICKLE_MOD |
|
|
10 |
|
|
|
11 |
from merlin.data.monai_transforms import ImageTransforms |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
class CTPersistentDataset(monai.data.PersistentDataset):
    """PersistentDataset variant that hashes and caches only the image tensor.

    Unlike the stock ``monai.data.PersistentDataset`` (which caches the result
    of the deterministic transforms applied to the whole data item), this class
    keys the cache on the ``"image"`` entry alone, so non-image metadata in the
    item dict can vary without invalidating cache entries.
    """

    def __init__(self, data, transform, cache_dir=None):
        """
        Args:
            data: sequence of data items (dicts) expected to carry an ``"image"`` key.
            transform: deterministic (pre-random) transform pipeline to apply.
            cache_dir: directory for the persistent cache; ``None`` disables caching.
        """
        super().__init__(data=data, transform=transform, cache_dir=cache_dir)
        print(f"Size of dataset: {self.__len__()}\n")

    def _cachecheck(self, item_transformed):
        """Return ``item_transformed`` with its ``"image"`` replaced by the
        (possibly cached) pre-transformed result, writing a cache entry on miss.

        Args:
            item_transformed: the raw data item (dict) for one dataset index.

        Returns:
            A deep copy of the item whose ``"image"`` entry holds the
            deterministic-transform output (loaded from cache when available).
        """
        hashfile = None
        _item_transformed = deepcopy(item_transformed)
        # Hash/cache only the image portion of the item.
        image_data = {
            "image": item_transformed.get("image")
        }  # Assuming the image data is under the 'image' key

        # BUGFIX: the original guard was ``image_data is not None``, which is
        # always true for the dict literal built above; check the image value
        # itself so items without an image are never hashed/cached.
        if self.cache_dir is not None and image_data["image"] is not None:
            data_item_md5 = self.hash_func(image_data).decode(
                "utf-8"
            )  # Hash based on image data
            hashfile = self.cache_dir / f"{data_item_md5}.pt"

        # Cache hit: reuse the stored pre-transformed image.
        if hashfile is not None and hashfile.is_file():
            cached_image = torch.load(hashfile)
            _item_transformed["image"] = cached_image
            return _item_transformed

        # Cache miss (or caching disabled): run the deterministic transforms.
        _image_transformed = self._pre_transform(image_data)["image"]
        _item_transformed["image"] = _image_transformed
        if hashfile is None:
            return _item_transformed
        try:
            # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
            # to make the cache more robust to manual killing of parent process
            # which may leave partially written cache files in an incomplete state
            with tempfile.TemporaryDirectory() as tmpdirname:
                temp_hash_file = Path(tmpdirname) / hashfile.name
                torch.save(
                    obj=_image_transformed,
                    f=temp_hash_file,
                    pickle_module=look_up_option(
                        self.pickle_module, SUPPORTED_PICKLE_MOD
                    ),
                    pickle_protocol=self.pickle_protocol,
                )
                if temp_hash_file.is_file() and not hashfile.is_file():
                    # On Unix, if target exists and is a file, it will be replaced silently if the user has permission.
                    # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.
                    try:
                        shutil.move(str(temp_hash_file), hashfile)
                    except FileExistsError:
                        pass
        except PermissionError:  # project-monai/monai issue #3613
            pass
        return _item_transformed

    def _transform(self, index: int):
        """Fetch item ``index``: cached deterministic transforms, then the
        post-cache (random) transforms."""
        pre_random_item = self._cachecheck(self.data[index])
        return self._post_transform(pre_random_item)
|
|
70 |
|
|
|
71 |
|
|
|
72 |
class DataLoader(monai.data.DataLoader):
    """A ``monai.data.DataLoader`` that builds its disk-cached CT dataset itself.

    Callers provide only the raw datalist and a cache directory; construction
    of the underlying ``CTPersistentDataset`` (with the module's standard
    ``ImageTransforms``) happens here.
    """

    def __init__(
        self,
        datalist: List[dict],
        cache_dir: str,
        batchsize: int,
        shuffle: bool = True,
        num_workers: int = 0,
    ):
        """
        Args:
            datalist: list of data dicts consumed by ``CTPersistentDataset``.
            cache_dir: directory used for the persistent image cache.
            batchsize: number of items per batch.
            shuffle: whether to reshuffle the data every epoch.
            num_workers: number of worker processes for loading.
        """
        # Keep the construction arguments around for later introspection.
        self.datalist = datalist
        self.cache_dir = cache_dir
        self.batchsize = batchsize

        dataset = CTPersistentDataset(
            data=datalist,
            transform=ImageTransforms,
            cache_dir=cache_dir,
        )
        self.dataset = dataset

        super().__init__(
            dataset,
            batch_size=batchsize,
            shuffle=shuffle,
            num_workers=num_workers,
        )