[40f229]: / model / utils / data_set.py

Download this file

91 lines (79 with data), 3.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch.utils.data as tordata
import numpy as np
import os.path as osp
import os
import pickle
import cv2
import xarray as xr
class DataSet(tordata.Dataset):
    """Dataset of image sequences addressed by (label, seq_type, view).

    Each sample ``index`` is described by ``seq_dir[index]`` (an iterable of
    directories, each holding the frame images of one feature stream) and by
    the parallel entries ``label[index]``, ``seq_type[index]`` and
    ``view[index]``.

    Attributes:
        index_dict: ``xr.DataArray`` with dims ``('label', 'seq_type',
            'view')`` mapping each combination to its sample index, or ``-1``
            when no sample has that combination.
        data / frame_set: per-sample caches, filled lazily when ``cache`` is
            truthy.
    """

    def __init__(self, seq_dir, label, seq_type, view, cache, resolution):
        """Store the sample descriptions and build the lookup table.

        Args:
            seq_dir: per-sample iterable of frame directories.
            label: per-sample label; its length defines the dataset size.
            seq_type: per-sample sequence type.
            view: per-sample view.
            cache: when truthy, loaded samples are kept in memory.
            resolution: side length of the (square) frame images; anything
                accepted by both ``int()`` and ``float()``.
        """
        self.seq_dir = seq_dir
        self.view = view
        self.seq_type = seq_type
        self.label = label
        self.cache = cache
        self.resolution = int(resolution)
        # Columns trimmed from each side of a frame: 10 per 64 px of resolution.
        self.cut_padding = int(float(resolution) / 64 * 10)
        self.data_size = len(self.label)
        self.data = [None] * self.data_size
        self.frame_set = [None] * self.data_size

        self.label_set = set(self.label)
        self.seq_type_set = set(self.seq_type)
        self.view_set = set(self.view)

        labels = sorted(self.label_set)
        seq_types = sorted(self.seq_type_set)
        views = sorted(self.view_set)
        # -1 marks (label, seq_type, view) combinations that have no sample.
        index_table = np.full(
            (len(labels), len(seq_types), len(views)), -1, dtype=int)
        self.index_dict = xr.DataArray(
            index_table,
            coords={'label': labels,
                    'seq_type': seq_types,
                    'view': views},
            dims=['label', 'seq_type', 'view'])
        for i, _label in enumerate(self.label):
            self.index_dict.loc[_label, self.seq_type[i], self.view[i]] = i

    def load_all_data(self):
        """Load every sample once (fills the cache when caching is enabled)."""
        for i in range(self.data_size):
            self.load_data(i)

    def load_data(self, index):
        """Return sample ``index``; equivalent to ``self[index]``."""
        return self.__getitem__(index)

    def __loader__(self, path):
        """Load the frames under ``path``, trim the sides and scale by 1/255.

        Returns the result of ``img2xarray(path)`` with ``cut_padding``
        entries removed from both ends of the last axis, cast to ``float32``
        and divided by 255.
        """
        pad = self.cut_padding
        # A stop of "-0" equals 0 and would select nothing, so use an open
        # ended slice when there is no padding to cut.
        stop = -pad if pad else None
        return self.img2xarray(path)[:, :, pad:stop].astype('float32') / 255.0

    def _load_sequence(self, index):
        """Load all feature streams of a sample and their common frame ids."""
        data = [self.__loader__(_path) for _path in self.seq_dir[index]]
        frame_sets = [set(feature.coords['frame'].values.tolist())
                      for feature in data]
        # Only frames present in every stream are usable; sorted so the
        # result does not depend on set iteration order.
        frame_set = sorted(set.intersection(*frame_sets))
        return data, frame_set

    def __getitem__(self, index):
        """Return ``(data, frame_set, view, seq_type, label)`` for a sample.

        ``data`` is the list of loaded feature streams and ``frame_set`` the
        list of frame ids shared by all of them. With caching enabled the
        sample is loaded on first access and reused afterwards.
        """
        if self.cache and self.data[index] is not None:
            data = self.data[index]
            frame_set = self.frame_set[index]
        else:
            data, frame_set = self._load_sequence(index)
            if self.cache:
                self.data[index] = data
                self.frame_set[index] = frame_set
        return (data, frame_set, self.view[index],
                self.seq_type[index], self.label[index])

    def img2xarray(self, flie_path):
        """Read every image file in a directory into one ``xr.DataArray``.

        Files are taken in sorted name order; sub-directories are skipped.
        Each image is reshaped to ``(resolution, resolution, -1)`` and only
        its first channel is kept.

        Args:
            flie_path: directory containing the frame images.

        Returns:
            ``xr.DataArray`` with dims ``('frame', 'img_y', 'img_x')`` whose
            ``frame`` coordinate is ``0 .. n_frames - 1``.

        Raises:
            ValueError: if a file cannot be read as an image, or its size
                does not match ``resolution``.
        """
        frame_list = []
        for _img_name in sorted(os.listdir(flie_path)):
            _img_path = osp.join(flie_path, _img_name)
            if not osp.isfile(_img_path):
                continue
            img = cv2.imread(_img_path)
            if img is None:
                # imread signals failure by returning None instead of raising.
                raise ValueError(
                    'cannot read image file: {}'.format(_img_path))
            frame_list.append(np.reshape(
                img, [self.resolution, self.resolution, -1])[:, :, 0])
        num_list = list(range(len(frame_list)))
        data_dict = xr.DataArray(
            frame_list,
            coords={'frame': num_list},
            dims=['frame', 'img_y', 'img_x'],
        )
        return data_dict

    def __len__(self):
        """Return the number of samples."""
        return len(self.label)