a b/data_generator_3D.py
1
import torch
2
import numpy as np
3
import time
4
import math
5
import os
6
import random
7
import nibabel as nib
8
9
# Root directory of the dataset split; the loaders below read image/ and
# label/ NIfTI files from here. Change this to your own path.
base_path = '.../data/train_valid_test/'

# Crop extent (in voxels) along the three spatial axes used by the loaders.
size_x, size_y, size_z = 240, 160, 48
13
class Covid19TrainSet():
    """Iterable over the training cases, yielding one random 3D crop per case.

    The case list is read from ``config/image_train_names.txt`` under
    ``base_path`` (whitespace-separated integer ids, any number per line).
    For every id, ``image/<id>.nii.gz`` and ``label/<id>.nii.gz`` are loaded
    and the same window of at most ``size_x`` x ``size_y`` x ``size_z``
    voxels is cut from both.

    Yields:
        ``(name, image_crop, label_crop)`` where ``name`` is ``'<id>.nii.gz'``
        and both crops carry two leading singleton axes (added with
        ``np.newaxis``) in front of the three cropped axes.
    """

    @staticmethod
    def _random_start(length, crop):
        """Return a uniformly random start index for a crop of ``crop`` voxels.

        Every valid start ``0 .. length - crop`` (inclusive) can be drawn.
        When the axis is not longer than the crop, 0 is returned, so the
        slice simply takes the whole axis instead of raising ``ValueError``.
        """
        return random.randint(0, max(0, length - crop))

    def __iter__(self):
        # Resolve the id list relative to base_path (like the image/label
        # files) rather than a hard-coded home directory.
        list_file = base_path + 'config/image_train_names.txt'
        with open(list_file) as f:
            case_ids = [int(token) for line in f for token in line.split()]

        for case_id in case_ids:
            name = str(case_id) + '.nii.gz'
            # NOTE(review): the [newaxis, newaxis, :, :, :] indexing and the
            # shape[2..4] reads below assume each NIfTI volume is 3D — confirm.
            image = np.asarray(nib.load(base_path + 'image/' + name).dataobj)
            image = image[np.newaxis, np.newaxis, :, :, :]
            label = np.asarray(nib.load(base_path + 'label/' + name).dataobj)
            label = label[np.newaxis, np.newaxis, :, :, :]

            x, y, z = image.shape[2], image.shape[3], image.shape[4]
            x_start = self._random_start(x, size_x)
            y_start = self._random_start(y, size_y)
            # NOTE(review): volumes with 64 or fewer slices are always cropped
            # from slice 0; why the threshold is 64 rather than size_z is not
            # visible here — confirm.
            z_start = self._random_start(z, size_z) if z > 64 else 0

            # One shared window so image and label stay voxel-aligned.
            window = (
                slice(None),
                slice(None),
                slice(x_start, x_start + size_x),
                slice(y_start, y_start + size_y),
                slice(z_start, z_start + size_z),
            )
            yield name, image[window], label[window]

    def __len__(self):
        # Hard-coded case count (no I/O here); presumably the number of ids in
        # image_train_names.txt — TODO confirm and keep in sync.
        return 80
42
43
44
class Covid19EvalSet():
    """Iterable over the validation cases listed in image_valid_names.txt.

    The case ids are whitespace-separated integers. For each id the image
    and label volumes are loaded from ``base_path``; the first spatial axes
    are kept whole and only the last axis is cut to ``size_z`` slices. The
    cut starts at a random slice when the volume has more than 64 slices,
    and at slice 0 otherwise.

    Yields:
        ``(name, image, label)`` with ``name == '<id>.nii.gz'`` and both
        arrays carrying two leading singleton axes.
    """

    def __iter__(self):
        list_path = ".../data/train_valid_test/config/image_valid_names.txt"
        with open(list_path) as handle:
            case_ids = [int(token) for row in handle for token in row.split()]

        for case_id in case_ids:
            stem = str(case_id)
            volumes = []
            # Image first, then label — same loading order as before.
            for folder in ('image/', 'label/'):
                nii = nib.load(base_path + folder + stem + '.nii.gz')
                volumes.append(
                    np.asarray(nii.dataobj)[np.newaxis, np.newaxis, :, :, :]
                )
            image, label = volumes

            depth = image.shape[4]
            if depth > 64:
                start = random.randrange(0, depth - size_z)
            else:
                start = 0
            depth_window = slice(start, start + size_z)

            yield (
                stem + '.nii.gz',
                image[:, :, :, :, depth_window],
                label[:, :, :, :, depth_window],
            )

    def __len__(self):
        # Fixed case count; not derived from the list file.
        return 13
67
68
69
class Convid19TestSet:
    """Iterable over the test cases listed in image_test_names.txt.

    The full volumes are yielded with no cropping. The class name's spelling
    ("Convid") is kept as-is because other code may import it by that name.

    Yields:
        ``(name, image, label)`` with ``name == '<id>.nii.gz'`` and both
        arrays carrying two leading singleton axes.
    """

    def __iter__(self):
        def load(folder, case_name):
            proxy = nib.load(base_path + folder + case_name)
            return np.asarray(proxy.dataobj)[np.newaxis, np.newaxis, :, :, :]

        list_path = ".../data/train_valid_test/config/image_test_names.txt"
        case_ids = []
        with open(list_path) as handle:
            for row in handle:
                case_ids.extend(int(token) for token in row.split())

        for case_id in case_ids:
            case_name = str(case_id) + '.nii.gz'
            image = load('image/', case_name)
            label = load('label/', case_name)
            yield case_name, image, label
87
88
if __name__ == "__main__":
    # Manual smoke test: walk the training set once and report each crop.
    # Guarded so that importing this module has no side effects.
    train_loader = Covid19TrainSet()
    for step, (name, X, y) in enumerate(train_loader):
        print(step, name, X.shape, y.shape)