--- a
+++ b/model/utils/data_set.py
@@ -0,0 +1,90 @@
+import torch.utils.data as tordata
+import numpy as np
+import os.path as osp
+import os
+import pickle
+import cv2
+import xarray as xr
+
+
+class DataSet(tordata.Dataset):
+    def __init__(self, seq_dir, label, seq_type, view, cache, resolution):
+        self.seq_dir = seq_dir
+        self.view = view
+        self.seq_type = seq_type
+        self.label = label
+        self.cache = cache
+        self.resolution = int(resolution)
+        self.cut_padding = int(float(resolution)/64*10)
+        self.data_size = len(self.label)
+        self.data = [None] * self.data_size
+        self.frame_set = [None] * self.data_size
+
+        self.label_set = set(self.label)
+        self.seq_type_set = set(self.seq_type)
+        self.view_set = set(self.view)
+        _ = np.zeros((len(self.label_set),
+                      len(self.seq_type_set),
+                      len(self.view_set))).astype('int')
+        _ -= 1
+        self.index_dict = xr.DataArray(
+            _,
+            coords={'label': sorted(list(self.label_set)),
+                    'seq_type': sorted(list(self.seq_type_set)),
+                    'view': sorted(list(self.view_set))},
+            dims=['label', 'seq_type', 'view'])
+
+        for i in range(self.data_size):
+            _label = self.label[i]
+            _seq_type = self.seq_type[i]
+            _view = self.view[i]
+            self.index_dict.loc[_label, _seq_type, _view] = i
+
+    def load_all_data(self):
+        for i in range(self.data_size):
+            self.load_data(i)
+
+    def load_data(self, index):
+        return self.__getitem__(index)
+
+    def __loader__(self, path):
+        return self.img2xarray(
+            path)[:, :, self.cut_padding:-self.cut_padding].astype(
+            'float32') / 255.0
+
+    def __getitem__(self, index):
+        # pose sequence sampling
+        if not self.cache:
+            data = [self.__loader__(_path) for _path in self.seq_dir[index]]
+            frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]
+            frame_set = list(set.intersection(*frame_set))
+        elif self.data[index] is None:
+            data = [self.__loader__(_path) for _path in self.seq_dir[index]]
+            frame_set = [set(feature.coords['frame'].values.tolist()) for feature in data]
+            frame_set = list(set.intersection(*frame_set))
+            self.data[index] = data
+            self.frame_set[index] = frame_set
+        else:
+            data = self.data[index]
+            frame_set = self.frame_set[index]
+
+        return data, frame_set, self.view[
+            index], self.seq_type[index], self.label[index],
+
+    def img2xarray(self, flie_path):
+        imgs = sorted(list(os.listdir(flie_path)))
+        frame_list = [np.reshape(
+            cv2.imread(osp.join(flie_path, _img_path)),
+            [self.resolution, self.resolution, -1])[:, :, 0]
+                      for _img_path in imgs
+                      if osp.isfile(osp.join(flie_path, _img_path))]
+        num_list = list(range(len(frame_list)))
+        data_dict = xr.DataArray(
+            frame_list,
+            coords={'frame': num_list},
+            dims=['frame', 'img_y', 'img_x'],
+        )
+        return data_dict
+
+    def __len__(self):
+        return len(self.label)