[d5c425]: / datasets.py

Download this file

118 lines (97 with data), 4.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import glob
import os
from math import ceil, floor
from medpy.io import load, header
from models import Model
import utils
import pandas as pd
import matplotlib.pyplot as plt
class RadDataset(Dataset):
    """Dataset yielding CT crops around the tumor and lymph regions of a patient.

    Each item is built from one row of ``df``. The row supplies the folder that
    holds ``CT_img.nii.gz`` (column ``radiology_folder_name``), the targets
    (``DFS_3years``, ``DFS``, ``DFS_censor``) and, for each of the locations
    ``tumor`` and ``lymph``, the bounding-box columns
    ``{X,Y,Z}_{min,max}_<location>``.

    Rows are always addressed by *position* (0 .. len-1), so ``df`` may carry
    any index (for example the one left over from a train/test split).
    """

    # Regions cropped for every sample, in the order they are returned.
    _LOCATIONS = ('tumor', 'lymph')

    def __init__(self, df, root_data, train_flag=True, dim=(48, 48, 3), ring=15):
        """Store the sample table and set up the transform pipelines.

        Args:
            df: pandas DataFrame with one row per sample (columns listed in
                the class docstring).
            root_data: directory containing one sub-folder per sample.
            train_flag: when true, crops go through the random augmentation
                pipeline; otherwise they are only converted to tensors.
            dim: crop size forwarded to ``utils.random_crop``.
            ring: margin added around each bounding box before the crop
                centres are computed.
        """
        self.df = df
        self.train_flag = train_flag
        # Training pipeline: tensor conversion followed by random augmentation.
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomAffine(3, scale=(0.95, 1.05)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5)
        ])
        # Evaluation pipeline: tensor conversion only.
        self.test_transforms = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.y = np.array(df["DFS_3years"]).astype(np.float32)
        self.time = np.array(df["DFS"]).astype(np.float32)
        self.event = np.array(df["DFS_censor"]).astype(np.float32)
        self.ID = np.array(df["radiology_folder_name"])
        # Private list copy: the default is immutable and a caller's sequence
        # is never shared, while ``utils.random_crop`` still receives a list.
        self.dim = list(dim)
        self.ring = ring
        self.root_data = root_data

    def __len__(self):
        """Return the number of samples."""
        return len(self.y)

    def _crop_centers(self, index, location, depth):
        """Return the four ``[y, x, z]`` crop centres for one region.

        Args:
            index: positional row number of the sample in ``self.df``.
            location: region name, ``'tumor'`` or ``'lymph'``.
            depth: size of the last axis of the CT volume; used to clamp the
                upper Z bound.
        """
        def bound(prefix):
            # Positional lookup, consistent with how __getitem__ and the
            # label arrays address rows. A label-based ``series[index]`` hits
            # the wrong row (or raises KeyError) when df has a non-default
            # index.
            return self.df[prefix + location].iloc[index]

        x_min = bound("X_min_") - self.ring
        x_max = bound("X_max_") + self.ring
        y_min = bound("Y_min_") - self.ring
        y_max = bound("Y_max_") + self.ring
        # Only the Z range is clamped: to at least 3 and at most the last slice.
        z_min = max(3, bound("Z_min_") - self.ring)
        z_max = min(depth - 1, bound("Z_max_") + self.ring)

        center_y = int(ceil(int(y_min + y_max) / 2))
        center_x = int(ceil(int(x_min + x_max) / 2))

        # Quarter, half and three-quarter points of the padded Z range.
        z_span = z_max - z_min
        z_q1 = z_min + int(z_span / 4)
        z_q2 = z_min + int(z_span / 2)
        z_q3 = z_min + int(3 * z_span / 4)

        # NOTE(review): the third centre (z_q1) lies below the second one, so
        # the four slices are not in ascending Z order - kept as in the
        # original; confirm this is intended.
        z_centers = [
            int((z_min + z_q1) / 2),
            int((z_q1 + z_q2) / 2),
            z_q1,
            z_q3,
        ]
        return [[center_y, center_x, z] for z in z_centers]

    def get_radiology(self, ct_image, index, train_flag):
        """Crop four sub-volumes around each region and stack them.

        Args:
            ct_image: CT volume; its last axis is treated as Z.
            index: positional row number of the sample in ``self.df``.
            train_flag: selects the augmenting pipeline (true) or the plain
                tensor conversion (false).

        Returns:
            A list with one stacked tensor per location, in the order
            ``['tumor', 'lymph']``; each stacks the four transformed crops.
        """
        # NOTE(review): the global torch / numpy RNGs are re-seeded with the
        # same value on every call, so anything random downstream restarts
        # from the same state each time (presumably making the "random"
        # crops/augmentations repeat) - confirm this is intended.
        torch.cuda.manual_seed_all(42)
        torch.manual_seed(42)
        np.random.seed(42)

        pipeline = self.transforms if train_flag else self.test_transforms
        depth = ct_image.shape[-1]

        concat_vols = []
        for location in self._LOCATIONS:
            centers = self._crop_centers(index, location, depth)
            # Crop and transform one centre at a time, in order, so the RNG
            # is consumed in the same sequence as before.
            sub_vols = [
                pipeline(utils.random_crop(ct_image, self.dim, center))
                for center in centers
            ]
            concat_vols.append(torch.stack(sub_vols))
        return concat_vols

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(ct_tumor, ct_lymphnodes, y, time, event, ID)`` where the first
            two entries are the stacked crops from ``get_radiology`` and the
            rest are the row's ``DFS_3years``, ``DFS``, ``DFS_censor`` and
            ``radiology_folder_name`` values.
        """
        folder = self.df["radiology_folder_name"].iloc[index]
        ct_image, _ = load(os.path.join(self.root_data, folder, "CT_img.nii.gz"))
        ct_image = utils.soft_tissue_window(ct_image)
        ct_tumor, ct_lymphnodes = self.get_radiology(ct_image, index, self.train_flag)[:2]
        return ct_tumor, ct_lymphnodes, self.y[index], self.time[index], self.event[index], self.ID[index]