a b/libs/datasets/acdc_dataset.py
1
import torch
2
from torch.utils.data import Dataset
3
from torchvision import transforms as T
4
5
import os
6
import json
7
import numpy as np
8
from PIL import Image
9
import h5py
10
11
import sys
12
# Absolute directory containing this module, so the value does not depend
# on the current working directory.
BASE_DIR = os.path.dirname(
    os.path.abspath(__file__)
)
# Make the directory two levels above this file importable; the
# project-local `utils` / `libs` imports that follow presumably rely on it.
sys.path.append(
    os.path.join(BASE_DIR, "../../")
)
14
from utils.direct_field.df_cardia import direct_field
15
from libs.datasets import augment as standard_aug
16
from utils.direct_field.utils_df import class2dist
17
18
19
class AcdcDataset(Dataset):
    """Segmentation dataset backed by per-case HDF5 files.

    The JSON file at ``data_list`` holds a sequence of HDF5 file paths,
    indexed positionally by ``__getitem__``. Each HDF5 file must contain an
    ``'image'`` and a ``'label'`` dataset. Items are returned as
    ``(img, gt, gt_df, dist_map)``. ``gt_df`` and ``dist_map`` are ``None``
    unless ``df_used`` / ``boundary`` are enabled.

    Args:
        data_list: Path to the JSON file listing HDF5 case paths.
        df_used: If True, also compute ``direct_field`` from the label.
        joint_augment: Optional callable ``(img, gt) -> (img, gt)`` applied
            to image and label together.
        augment: Optional callable applied to the image only.
        target_augment: Optional callable applied to the label only.
        df_norm: Forwarded as ``norm`` to ``direct_field``.
        boundary: If True, also compute a distance map via ``class2dist``.
        num_classes: Passed to ``class2dist`` as ``C``. It defaults to 4,
            the value that was previously hard-coded.
    """

    def __init__(self, data_list, df_used=False, joint_augment=None, augment=None,
                 target_augment=None, df_norm=True, boundary=False, num_classes=4):
        self.joint_augment = joint_augment
        self.augment = augment
        self.target_augment = target_augment
        self.data_list = data_list
        self.df_used = df_used
        self.df_norm = df_norm
        self.boundary = boundary
        self.num_classes = num_classes

        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(data_list, 'r', encoding='utf-8') as f:
            self.data_infos = json.load(f)

    def __len__(self):
        """Return the number of cases listed in ``data_list``."""
        return len(self.data_infos)

    def __getitem__(self, index):
        """Load, augment and return the sample at ``index``.

        Returns:
            ``(img, gt, gt_df, dist_map)``. ``gt_df`` and ``dist_map`` are
            tensors when ``df_used`` / ``boundary`` are set, else ``None``.
        """
        # Open the case file once and close it deterministically. Both
        # datasets are read fully into memory before the file is closed.
        with h5py.File(self.data_infos[index], 'r') as h5:
            img = np.array(h5['image'])
            gt = np.array(h5['label'])

        # Insert a channel axis at position 2, i.e. (H, W) -> (H, W, 1)
        # for 2-D slices (assumes 2-D data; confirm for other inputs).
        img = img[:, :, None].astype(np.float32)
        gt = gt[:, :, None].astype(np.float32)

        if self.joint_augment is not None:
            img, gt = self.joint_augment(img, gt)
        if self.augment is not None:
            img = self.augment(img)
        if self.target_augment is not None:
            gt = self.target_augment(gt)

        gt_df = None
        dist_map = None
        if self.df_used or self.boundary:
            # Both extras work on the first channel of the label. `.numpy()`
            # requires `gt` to be a CPU tensor here (presumably produced by
            # target_augment; a raw ndarray has no `.numpy()`).
            label = gt.numpy()[0]
            if self.df_used:
                gt_df = torch.from_numpy(direct_field(label, norm=self.df_norm))
            if self.boundary:
                dist_map = torch.from_numpy(class2dist(label, C=self.num_classes))

        return img, gt, gt_df, dist_map
62
63