|
a |
|
b/libs/datasets/acdc_dataset.py |
|
|
1 |
import torch |
|
|
2 |
from torch.utils.data import Dataset |
|
|
3 |
from torchvision import transforms as T |
|
|
4 |
|
|
|
5 |
import os |
|
|
6 |
import json |
|
|
7 |
import numpy as np |
|
|
8 |
from PIL import Image |
|
|
9 |
import h5py |
|
|
10 |
|
|
|
11 |
import sys |
|
|
12 |
# Directory containing this module (e.g. .../libs/datasets).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Append the directory two levels above BASE_DIR to sys.path so the
# project-local imports below (`utils.direct_field...`, `libs.datasets...`)
# resolve even when this module is not imported from the repository root.
sys.path.append(os.path.join(BASE_DIR, '../../'))
|
|
14 |
from utils.direct_field.df_cardia import direct_field |
|
|
15 |
from libs.datasets import augment as standard_aug |
|
|
16 |
from utils.direct_field.utils_df import class2dist |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
class AcdcDataset(Dataset):
    """Dataset whose samples are HDF5 files listed in a JSON manifest.

    Each sample is read from one HDF5 file that provides an ``'image'`` and a
    ``'label'`` dataset. Both are loaded as ``float32`` NumPy arrays with a
    trailing channel axis appended, then passed through the optional
    augmentation callables. Optionally, a direction field (via
    ``direct_field``) and a distance map (via ``class2dist``) are derived
    from the label.

    Args:
        data_list: Path to a JSON file whose top-level value is an indexable
            collection (presumably a list) of HDF5 file paths, one per sample.
        df_used: If True, ``__getitem__`` also returns ``direct_field`` output.
        joint_augment: Optional callable ``(img, gt) -> (img, gt)`` applied to
            image and label together.
        augment: Optional callable applied to the image only.
        target_augment: Optional callable applied to the label only.
        df_norm: Forwarded as ``norm=`` to ``direct_field``.
        boundary: If True, ``__getitem__`` also returns ``class2dist`` output.
        num_classes: Forwarded as ``C=`` to ``class2dist``. Defaults to 4,
            the value that was previously hard-coded.
    """

    def __init__(self, data_list, df_used=False, joint_augment=None, augment=None,
                 target_augment=None, df_norm=True, boundary=False, num_classes=4):
        self.joint_augment = joint_augment
        self.augment = augment
        self.target_augment = target_augment
        self.data_list = data_list
        self.df_used = df_used
        self.df_norm = df_norm
        self.boundary = boundary
        self.num_classes = num_classes

        # JSON exchanged between systems is UTF-8 (RFC 8259); don't depend on
        # the platform's locale encoding when reading the manifest.
        with open(data_list, 'r', encoding='utf-8') as f:
            self.data_infos = json.load(f)

    def __len__(self):
        """Return the number of entries in the JSON manifest."""
        return len(self.data_infos)

    def __getitem__(self, index):
        """Load, augment and return sample ``index``.

        Returns:
            A tuple ``(img, gt, gt_df, dist_map)``. ``gt_df`` is None unless
            ``df_used`` is set; ``dist_map`` is None unless ``boundary`` is
            set.
            NOTE(review): torch's default collate function cannot batch
            ``None``, so callers presumably supply a custom ``collate_fn``
            when either flag is off; confirm.
        """
        # Open the file once and close it deterministically. np.array copies
        # the data out of the HDF5 datasets before the handle is released.
        with h5py.File(self.data_infos[index], 'r') as h5:
            # Assumes 'image' and 'label' are 2-D (H, W); [:, :, None]
            # appends a channel axis, giving (H, W, 1). TODO confirm.
            img = np.array(h5['image'])[:, :, None].astype(np.float32)
            gt = np.array(h5['label'])[:, :, None].astype(np.float32)

        if self.joint_augment is not None:
            img, gt = self.joint_augment(img, gt)
        if self.augment is not None:
            img = self.augment(img)
        if self.target_augment is not None:
            gt = self.target_augment(gt)

        gt_df = None
        dist_map = None
        if self.df_used or self.boundary:
            # NOTE(review): `.numpy()` requires the augmentations to have
            # turned `gt` into a torch tensor. Index 0 presumably selects the
            # single channel of a channel-first (1, H, W) tensor; confirm.
            gt_np = gt.numpy()[0]
            if self.df_used:
                gt_df = torch.from_numpy(direct_field(gt_np, norm=self.df_norm))
            if self.boundary:
                dist_map = torch.from_numpy(class2dist(gt_np, C=self.num_classes))

        return img, gt, gt_df, dist_map
|
|
62 |
|
|
|
63 |
|