--- a +++ b/opengait/data/collate_fn.py @@ -0,0 +1,115 @@ +import math +import random +import numpy as np +from utils import get_msg_mgr + + +class CollateFn(object): + def __init__(self, label_set, sample_config): + self.label_set = label_set + sample_type = sample_config['sample_type'] + sample_type = sample_type.split('_') + self.sampler = sample_type[0] + self.ordered = sample_type[1] + if self.sampler not in ['fixed', 'unfixed', 'all']: + raise ValueError + if self.ordered not in ['ordered', 'unordered']: + raise ValueError + self.ordered = sample_type[1] == 'ordered' + + # fixed cases + if self.sampler == 'fixed': + self.frames_num_fixed = sample_config['frames_num_fixed'] + + # unfixed cases + if self.sampler == 'unfixed': + self.frames_num_max = sample_config['frames_num_max'] + self.frames_num_min = sample_config['frames_num_min'] + + if self.sampler != 'all' and self.ordered: + self.frames_skip_num = sample_config['frames_skip_num'] + + self.frames_all_limit = -1 + if self.sampler == 'all' and 'frames_all_limit' in sample_config: + self.frames_all_limit = sample_config['frames_all_limit'] + + def __call__(self, batch): + batch_size = len(batch) + # currently, the functionality of feature_num is not fully supported yet, it refers to 1 now. We are supposed to make our framework support multiple source of input data, such as silhouette, or skeleton. + feature_num = len(batch[0][0]) + seqs_batch, labs_batch, typs_batch, vies_batch = [], [], [], [] + + for bt in batch: + seqs_batch.append(bt[0]) + labs_batch.append(self.label_set.index(bt[1][0])) + typs_batch.append(bt[1][1]) + vies_batch.append(bt[1][2]) + + global count + count = 0 + + def sample_frames(seqs): + global count + sampled_fras = [[] for i in range(feature_num)] + seq_len = len(seqs[0]) + indices = list(range(seq_len)) + + if self.sampler in ['fixed', 'unfixed']: + if self.sampler == 'fixed': + frames_num = self.frames_num_fixed + else: + frames_num = random.choice( + list(range(self.frames_num_min, self.frames_num_max+1))) + + if self.ordered: + fs_n = frames_num + self.frames_skip_num + if seq_len < fs_n: + it = math.ceil(fs_n / seq_len) + seq_len = seq_len * it + indices = indices * it + + start = random.choice(list(range(0, seq_len - fs_n + 1))) + end = start + fs_n + idx_lst = list(range(seq_len)) + idx_lst = idx_lst[start:end] + idx_lst = sorted(np.random.choice( + idx_lst, frames_num, replace=False)) + indices = [indices[i] for i in idx_lst] + else: + replace = seq_len < frames_num + + if seq_len == 0: + get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.' + % (str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count]))) + + count += 1 + indices = np.random.choice( + indices, frames_num, replace=replace) + + for i in range(feature_num): + for j in indices[:self.frames_all_limit] if self.frames_all_limit > -1 and len(indices) > self.frames_all_limit else indices: + sampled_fras[i].append(seqs[i][j]) + return sampled_fras + + # f: feature_num + # b: batch_size + # p: batch_size_per_gpu + # g: gpus_num + fras_batch = [sample_frames(seqs) for seqs in seqs_batch] # [b, f] + batch = [fras_batch, labs_batch, typs_batch, vies_batch, None] + + if self.sampler == "fixed": + fras_batch = [[np.asarray(fras_batch[i][j]) for i in range(batch_size)] + for j in range(feature_num)] # [f, b] + else: + seqL_batch = [[len(fras_batch[i][0]) + for i in range(batch_size)]] # [1, p] + + def my_cat(k): return np.concatenate( + [fras_batch[i][k] for i in range(batch_size)], 0) + fras_batch = [[my_cat(k)] for k in range(feature_num)] # [f, g] + + batch[-1] = np.asarray(seqL_batch) + + batch[0] = fras_batch + return batch