--- /dev/null
+++ b/model/model.py
@@ -0,0 +1,272 @@
+import math
+import os
+import os.path as osp
+import random
+import sys
+from datetime import datetime
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data as tordata
+
+from .network import TripletLoss, SetNet
+from .utils import TripletSampler
+
+
+class Model:
+    """Trains a SetNet gait encoder with a triplet loss and extracts features."""
+
+    def __init__(self,
+                 hidden_dim,
+                 lr,
+                 hard_or_full_trip,
+                 margin,
+                 num_workers,
+                 batch_size,
+                 restore_iter,
+                 total_iter,
+                 save_name,
+                 train_pid_num,
+                 frame_num,
+                 model_name,
+                 train_source,
+                 test_source,
+                 img_size=64):
+
+        self.save_name = save_name
+        self.train_pid_num = train_pid_num
+        self.train_source = train_source
+        self.test_source = test_source
+
+        self.hidden_dim = hidden_dim
+        self.lr = lr
+        self.hard_or_full_trip = hard_or_full_trip
+        self.margin = margin
+        self.frame_num = frame_num
+        self.num_workers = num_workers
+        self.batch_size = batch_size
+        self.model_name = model_name
+        # batch_size is a (P, M) pair: P subjects per batch, M sequences each.
+        self.P, self.M = batch_size
+
+        self.restore_iter = restore_iter
+        self.total_iter = total_iter
+
+        self.img_size = img_size
+
+        self.encoder = SetNet(self.hidden_dim).float()
+        self.encoder = nn.DataParallel(self.encoder)
+        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
+        self.triplet_loss = nn.DataParallel(self.triplet_loss)
+        self.encoder.cuda()
+        self.triplet_loss.cuda()
+
+        self.optimizer = optim.Adam([
+            {'params': self.encoder.parameters()},
+        ], lr=self.lr)
+
+        # Running statistics, reset each time they are logged in fit().
+        self.hard_loss_metric = []
+        self.full_loss_metric = []
+        self.full_loss_num = []
+        self.dist_list = []
+        self.mean_dist = 0.01
+
+        # 'random' samples frame_num frames per sequence (training);
+        # 'all' keeps every frame (evaluation).
+        self.sample_type = 'all'
+
+    def collate_fn(self, batch):
+        batch_size = len(batch)
+        feature_num = len(batch[0][0])
+        seqs = [batch[i][0] for i in range(batch_size)]
+        frame_sets = [batch[i][1] for i in range(batch_size)]
+        view = [batch[i][2] for i in range(batch_size)]
+        seq_type = [batch[i][3] for i in range(batch_size)]
+        label = [batch[i][4] for i in range(batch_size)]
+        batch = [seqs, view, seq_type, label, None]
+
+        def select_frame(index):
+            sample = seqs[index]
+            frame_set = frame_sets[index]
+            if self.sample_type == 'random':
+                # Sample frame_num frames with replacement for training.
+                frame_id_list = random.choices(frame_set, k=self.frame_num)
+                return [feature.loc[frame_id_list].values for feature in sample]
+            # Keep every frame for evaluation.
+            return [feature.values for feature in sample]
+
+        seqs = list(map(select_frame, range(len(seqs))))
+
+        if self.sample_type == 'random':
+            # Every sequence now has frame_num frames, so they stack directly.
+            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
+        else:
+            # Variable-length sequences: pack them per GPU and zero-pad so
+            # DataParallel can scatter equally sized chunks.
+            gpu_num = min(torch.cuda.device_count(), batch_size)
+            batch_per_gpu = math.ceil(batch_size / gpu_num)
+            # batch_frames[g] lists the frame count of every sample on GPU g.
+            batch_frames = [[
+                len(frame_sets[i])
+                for i in range(batch_per_gpu * gpu_id, batch_per_gpu * (gpu_id + 1))
+                if i < batch_size
+            ] for gpu_id in range(gpu_num)]
+            if len(batch_frames[-1]) != batch_per_gpu:
+                # Pad the last GPU's count list with zeros to a uniform length.
+                for _ in range(batch_per_gpu - len(batch_frames[-1])):
+                    batch_frames[-1].append(0)
+            max_sum_frame = np.max([np.sum(batch_frames[gpu_id]) for gpu_id in range(gpu_num)])
+            # Concatenate each GPU's sequences along the frame axis ...
+            seqs = [[
+                np.concatenate([
+                    seqs[i][j]
+                    for i in range(batch_per_gpu * gpu_id, batch_per_gpu * (gpu_id + 1))
+                    if i < batch_size
+                ], 0) for gpu_id in range(gpu_num)]
+                for j in range(feature_num)]
+            # ... then zero-pad every GPU chunk to the same total frame count.
+            seqs = [np.asarray([
+                np.pad(seqs[j][gpu_id],
+                       ((0, max_sum_frame - seqs[j][gpu_id].shape[0]), (0, 0), (0, 0)),
+                       'constant',
+                       constant_values=0)
+                for gpu_id in range(gpu_num)])
+                for j in range(feature_num)]
+            batch[4] = np.asarray(batch_frames)
+
+        batch[0] = seqs
+        return batch
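+
+    # Shape walk-through (illustrative only; the 64x44 silhouette size,
+    # frame_num=30 and batch_size=(8, 16) below are assumptions, not values
+    # fixed by this module):
+    #   'random' mode: batch[0][j].shape == (128, 30, 64, 44) and batch[4]
+    #                  stays None (no frame bookkeeping needed).
+    #   'all' mode:    batch[0][j].shape == (gpu_num, max_sum_frame, 64, 44)
+    #                  and batch[4][g] holds the per-sample frame counts on
+    #                  GPU g, letting the network split the padded chunk back
+    #                  into individual sequences.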
+    def fit(self):
+        if self.restore_iter != 0:
+            self.load(self.restore_iter)
+
+        self.encoder.train()
+        self.sample_type = 'random'
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = self.lr
+        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
+        train_loader = tordata.DataLoader(
+            dataset=self.train_source,
+            batch_sampler=triplet_sampler,
+            collate_fn=self.collate_fn,
+            num_workers=self.num_workers)
+
+        train_label_set = list(self.train_source.label_set)
+        train_label_set.sort()
+
+        _time1 = datetime.now()
+        for seq, view, seq_type, label, batch_frame in train_loader:
+            self.restore_iter += 1
+            self.optimizer.zero_grad()
+
+            for i in range(len(seq)):
+                seq[i] = self.np2var(seq[i]).float()
+            if batch_frame is not None:
+                batch_frame = self.np2var(batch_frame).int()
+
+            feature, label_prob = self.encoder(*seq, batch_frame)
+
+            # Map string labels to contiguous class indices.
+            target_label = [train_label_set.index(l) for l in label]
+            target_label = self.np2var(np.array(target_label)).long()
+
+            # Compute the triplet loss independently for every horizontal bin.
+            triplet_feature = feature.permute(1, 0, 2).contiguous()
+            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
+            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
+             ) = self.triplet_loss(triplet_feature, triplet_label)
+            if self.hard_or_full_trip == 'hard':
+                loss = hard_loss_metric.mean()
+            elif self.hard_or_full_trip == 'full':
+                loss = full_loss_metric.mean()
+
+            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
+            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
+            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
+            self.dist_list.append(mean_dist.mean().data.cpu().numpy())
+
+            # Skip the backward pass once the loss has effectively vanished.
+            if loss > 1e-9:
+                loss.backward()
+                self.optimizer.step()
+
+            if self.restore_iter % 1000 == 0:
+                print(datetime.now() - _time1)
+                _time1 = datetime.now()
+
+            if self.restore_iter % 100 == 0:
+                self.save()
+                print('iter {}:'.format(self.restore_iter), end='')
+                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
+                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
+                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
+                self.mean_dist = np.mean(self.dist_list)
+                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
+                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
+                print(', hard or full=%r' % self.hard_or_full_trip)
+                sys.stdout.flush()
+                self.hard_loss_metric = []
+                self.full_loss_metric = []
+                self.full_loss_num = []
+                self.dist_list = []
+
+            # Visualization using t-SNE
+            # if self.restore_iter % 500 == 0:
+            #     pca = TSNE(2)
+            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
+            #     for i in range(self.P):
+            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
+            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
+            #
+            #     plt.show()
+
+            if self.restore_iter == self.total_iter:
+                break
+
+    def ts2var(self, x):
+        # torch.autograd.Variable has been a no-op alias since PyTorch 0.4,
+        # so moving the tensor to the GPU is all that is needed here.
+        return x.cuda()
+
+    def np2var(self, x):
+        return self.ts2var(torch.from_numpy(x))
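+
+    # Conversion sketch (shapes illustrative): fit() and transform() feed
+    # NumPy batches through np2var before calling the encoder, e.g.
+    #   x = np.zeros((2, 30, 64, 44), dtype='float32')
+    #   t = self.np2var(x).float()   # CUDA tensor of shape (2, 30, 64, 44)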
+
+    def transform(self, flag, batch_size=1):
+        self.encoder.eval()
+        source = self.test_source if flag == 'test' else self.train_source
+        self.sample_type = 'all'
+        data_loader = tordata.DataLoader(
+            dataset=source,
+            batch_size=batch_size,
+            sampler=tordata.sampler.SequentialSampler(source),
+            collate_fn=self.collate_fn,
+            num_workers=self.num_workers)
+
+        feature_list = list()
+        view_list = list()
+        seq_type_list = list()
+        label_list = list()
+
+        for i, x in enumerate(data_loader):
+            seq, view, seq_type, label, batch_frame = x
+            for j in range(len(seq)):
+                seq[j] = self.np2var(seq[j]).float()
+            if batch_frame is not None:
+                batch_frame = self.np2var(batch_frame).int()
+
+            feature, _ = self.encoder(*seq, batch_frame)
+            n, num_bin, _ = feature.size()
+            # Flatten the per-bin features into one vector per sequence.
+            feature_list.append(feature.view(n, -1).data.cpu().numpy())
+            view_list += view
+            seq_type_list += seq_type
+            label_list += label
+
+        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list
+
+    def save(self):
+        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
+        torch.save(self.encoder.state_dict(),
+                   osp.join('checkpoint', self.model_name,
+                            '{}-{:0>5}-encoder.ptm'.format(
+                                self.save_name, self.restore_iter)))
+        torch.save(self.optimizer.state_dict(),
+                   osp.join('checkpoint', self.model_name,
+                            '{}-{:0>5}-optimizer.ptm'.format(
+                                self.save_name, self.restore_iter)))
+
+    # restore_iter: iteration index of the checkpoint to load
+    def load(self, restore_iter):
+        self.encoder.load_state_dict(torch.load(osp.join(
+            'checkpoint', self.model_name,
+            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
+        self.optimizer.load_state_dict(torch.load(osp.join(
+            'checkpoint', self.model_name,
+            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
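+
+# Usage sketch (illustrative only; every argument value and the data-source
+# objects below are assumptions, not defined in this module):
+# model = Model(hidden_dim=256, lr=1e-4, hard_or_full_trip='full', margin=0.2,
+#               num_workers=4, batch_size=(8, 16), restore_iter=0,
+#               total_iter=80000, save_name='model', train_pid_num=73,
+#               frame_num=30, model_name='demo',
+#               train_source=train_source, test_source=test_source)
+# model.fit()                                              # train from scratch
+# feats, views, types, labels = model.transform('test')   # extract features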