[4fa73e]: / pytorch / datasets / dataloader.py

Download this file

143 lines (117 with data), 6.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Patch-based dataloaders for the UNET supervised method and the GAN few-shot method
"""
import numpy as np
import imageio
import torch
import torchvision.transforms as v_transforms
import torchvision.utils as v_utils
import torchvision.datasets as v_datasets
from torch.utils.data import DataLoader, TensorDataset, Dataset
from utils.preprocess import *
# For UNET based supervised method
class Supervised_Dataloader(Dataset):
    """Patch dataset for the UNET-based supervised method.

    Depending on ``phase`` the dataset holds:

    * ``'training'``   -- patches and their labels;
    * ``'validating'`` -- patches, their labels and ``whole_vol``;
    * ``'testing'``    -- patches and ``whole_vol`` (no labels are returned).

    All arrays are produced by ``preprocess_dynamic_lab`` (from
    ``utils.preprocess``); their exact shapes and dtypes are defined there.
    """

    _PHASES = ('training', 'validating', 'testing')

    def __init__(self, config, phase):
        """Load the patches for ``phase``.

        Args:
            config: experiment configuration; the attributes read here are
                ``data_directory``, ``seed``, ``num_classes``,
                ``extraction_step``, ``patch_shape`` and
                ``number_images_training``.
            phase: one of ``'training'``, ``'validating'``, ``'testing'``.

        Raises:
            ValueError: if ``phase`` is not one of the supported phases.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O`` and an unknown phase would then only surface later as
        # an AttributeError on ``self.patches``.
        if phase not in self._PHASES:
            raise ValueError(
                "phase must be one of {}, got {!r}".format(self._PHASES, phase))
        self.phase = phase

        # Positional arguments shared by every phase.
        lab_args = (config.data_directory, config.seed, config.num_classes,
                    config.extraction_step, config.patch_shape,
                    config.number_images_training)
        if phase == 'training':
            self.patches, self.label = preprocess_dynamic_lab(*lab_args)
            print("Label unique:", np.unique(self.label))
        elif phase == 'validating':
            self.patches, self.label, self.whole_vol = preprocess_dynamic_lab(
                *lab_args, validating=True)
            print("Label unique:", np.unique(self.label))
        else:  # 'testing': only patches and the whole volume are returned
            self.patches, self.whole_vol = preprocess_dynamic_lab(
                *lab_args, testing=True)
        print("Data_shape:", self.patches.shape)
        print("Data lab max and min:", np.max(self.patches), np.min(self.patches))

    def __len__(self):
        """Return the number of patches."""
        return len(self.patches)

    def __getitem__(self, index):
        """Return the sample at ``index`` for the current phase.

        * training:   ``(patch, label)``
        * validating: ``(patch, label, whole_vol)``
        * testing:    ``(patch, whole_vol)``

        NOTE(review): the same ``whole_vol`` object is returned with every
        sample, so each collated batch carries one copy per sample -- confirm
        this is intended.
        """
        if self.phase == 'training':
            return self.patches[index], self.label[index]
        if self.phase == 'validating':
            return self.patches[index], self.label[index], self.whole_vol
        # phase was validated in __init__, so this is 'testing'
        return self.patches[index], self.whole_vol
class Supervised_Dataset:
    """Pair a ``Supervised_Dataloader`` with a torch ``DataLoader``.

    Attributes set on construction:
        config: the configuration object passed in.
        dataset: the ``Supervised_Dataloader`` built for ``phase``.
        loader: ``DataLoader`` over ``dataset``; batches are shuffled only
            when ``phase == 'training'``.
        num_iterations: number of batches per pass, i.e. ``len(loader)``.
    """

    def __init__(self, config, phase):
        self.config = config
        self.dataset = Supervised_Dataloader(config, phase)
        # Shuffle training batches only; validation/testing keep file order.
        self.loader = DataLoader(
            self.dataset,
            batch_size=config.batch_size,
            shuffle=phase == 'training',
            num_workers=config.data_loader_workers,
            pin_memory=config.pin_memory,
        )
        self.num_iterations = len(self.loader)

    def finalize(self):
        """Cleanup hook; currently does nothing."""
# For GAN based few-shot method
class FewShot_Dataloader(Dataset):
    """Patch dataset for the GAN-based few-shot method.

    Depending on ``phase`` the dataset holds:

    * ``'training'``   -- labelled patches, their labels and unlabelled
      patches. The labelled set is over-sampled to exactly as many rows as
      the unlabelled set, so every index yields a
      (labelled patch, unlabelled patch, label) triple;
    * ``'validating'`` -- labelled patches, their labels and ``whole_vol``;
    * ``'testing'``    -- patches and ``whole_vol`` (no labels are returned).
    """

    _PHASES = ('training', 'validating', 'testing')

    def __init__(self, config, phase):
        """Load the patches for ``phase``.

        Args:
            config: experiment configuration; the attributes read here are
                ``data_directory``, ``seed``, ``num_classes``,
                ``extraction_step``, ``patch_shape``,
                ``number_images_training`` and (training only)
                ``number_unlab_images_training``.
            phase: one of ``'training'``, ``'validating'``, ``'testing'``.

        Raises:
            ValueError: if ``phase`` is unsupported, if there are no labelled
                training patches, or if the over-sampled labelled patches do
                not match the unlabelled patches in shape.
        """
        # Explicit checks rather than ``assert`` so they still run under
        # ``python -O``.
        if phase not in self._PHASES:
            raise ValueError(
                "phase must be one of {}, got {!r}".format(self._PHASES, phase))
        self.phase = phase

        # Positional arguments shared by every labelled-data phase.
        lab_args = (config.data_directory, config.seed, config.num_classes,
                    config.extraction_step, config.patch_shape,
                    config.number_images_training)
        if phase == 'training':
            self.patches, self.label = preprocess_dynamic_lab(*lab_args)
            print("Label unique:", np.unique(self.label))
            self.patches_unlab = preprocess_dynamic_unlab(
                config.data_directory, config.extraction_step,
                config.patch_shape, config.number_unlab_images_training)
            # NOTE(review): ``shuffle`` comes from the star import of
            # utils.preprocess (presumably sklearn.utils.shuffle) -- confirm.
            self.patches_unlab = shuffle(self.patches_unlab, random_state=0)
            self.patches, self.label = self._oversample(
                self.patches, self.label, len(self.patches_unlab))
            if self.patches.shape != self.patches_unlab.shape:
                raise ValueError(
                    "labelled patches {} and unlabelled patches {} differ in "
                    "shape".format(self.patches.shape, self.patches_unlab.shape))
            print("Data_shape:", self.patches.shape, self.patches_unlab.shape)
            print("Data lab max and min:", np.max(self.patches), np.min(self.patches))
            print("Data unlab max and min:", np.max(self.patches_unlab),
                  np.min(self.patches_unlab))
            # Re-printed because truncation (factor 0) can drop classes.
            print("Label unique:", np.unique(self.label))
        elif phase == 'validating':
            self.patches, self.label, self.whole_vol = preprocess_dynamic_lab(
                *lab_args, validating=True)
            print("Label unique:", np.unique(self.label))
        else:  # 'testing': only patches and the whole volume are returned
            self.patches, self.whole_vol = preprocess_dynamic_lab(
                *lab_args, testing=True)

    @staticmethod
    def _oversample(patches, label, target_len):
        """Repeat ``patches``/``label`` row-wise to exactly ``target_len`` rows.

        Each row is repeated ``target_len // len(patches)`` times consecutively
        (``np.repeat`` order: a, a, b, b, ...), then the first
        ``target_len % len(patches)`` original rows are appended. Patches and
        labels get the same treatment, so they stay aligned.

        Raises:
            ValueError: if ``patches`` is empty (nothing to over-sample).
        """
        n_lab = len(patches)
        if n_lab == 0:
            raise ValueError(
                "no labelled patches available to pair with unlabelled patches")
        factor, rem = divmod(target_len, n_lab)
        print("Factor for labeled images:", factor)
        patches = np.concatenate(
            (np.repeat(patches, factor, axis=0), patches[:rem]), axis=0)
        label = np.concatenate(
            (np.repeat(label, factor, axis=0), label[:rem]), axis=0)
        return patches, label

    def __len__(self):
        """Return the number of (labelled) patches."""
        return len(self.patches)

    def __getitem__(self, index):
        """Return the sample at ``index`` for the current phase.

        * training:   ``(patch, unlabelled_patch, label)``
        * validating: ``(patch, label, whole_vol)``
        * testing:    ``(patch, whole_vol)``
        """
        if self.phase == 'training':
            return self.patches[index], self.patches_unlab[index], self.label[index]
        if self.phase == 'validating':
            return self.patches[index], self.label[index], self.whole_vol
        # phase was validated in __init__, so this is 'testing'
        return self.patches[index], self.whole_vol
class FewShot_Dataset:
    """Pair a ``FewShot_Dataloader`` with a torch ``DataLoader``.

    Attributes set on construction:
        config: the configuration object passed in.
        dataset: the ``FewShot_Dataloader`` built for ``phase``.
        loader: ``DataLoader`` over ``dataset``; batches are shuffled only
            when ``phase == 'training'``.
        num_iterations: number of batches per pass, i.e. ``len(loader)``.
    """

    def __init__(self, config, phase):
        self.config = config
        self.dataset = FewShot_Dataloader(config, phase)
        # Shuffle training batches only; validation/testing keep file order.
        self.loader = DataLoader(
            self.dataset,
            batch_size=config.batch_size,
            shuffle=phase == 'training',
            num_workers=config.data_loader_workers,
            pin_memory=config.pin_memory,
        )
        self.num_iterations = len(self.loader)

    def finalize(self):
        """Cleanup hook; currently does nothing."""