Diff of /model/model.py [000000] .. [40f229]

import math
import os
import os.path as osp
import random
import sys
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.utils.data as tordata

from .network import TripletLoss, SetNet
from .utils import TripletSampler


class Model:
    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64):
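        # `batch_size` is a (P, M) pair: P subject IDs per batch and M
        # sequences per subject (unpacked below as self.P / self.M).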

        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        # SetNet encoder and triplet loss, both wrapped in DataParallel
        # for multi-GPU training (a CUDA device is required).
        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = nn.DataParallel(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = nn.DataParallel(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        # Running logs, reset every time statistics are printed.
        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        # 'random' during training (fixed-length frame sampling),
        # 'all' during feature extraction (keep every frame).
        self.sample_type = 'all'
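
    # collate_fn assembles a list of dataset items into one batch. Each
    # item appears to be a (seq, frame_set, view, seq_type, label) tuple,
    # where seq is a list of per-feature pandas DataFrames indexed by
    # frame id (inferred from the .loc/.values access below).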
    def collate_fn(self, batch):
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs = [batch[i][0] for i in range(batch_size)]
        frame_sets = [batch[i][1] for i in range(batch_size)]
        view = [batch[i][2] for i in range(batch_size)]
        seq_type = [batch[i][3] for i in range(batch_size)]
        label = [batch[i][4] for i in range(batch_size)]
        batch = [seqs, view, seq_type, label, None]

        def select_frame(index):
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                # Training: draw a fixed number of frames with replacement.
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                _ = [feature.loc[frame_id_list].values for feature in sample]
            else:
                # Evaluation: keep every frame of the sequence.
                _ = [feature.values for feature in sample]
            return _

        seqs = list(map(select_frame, range(len(seqs))))

        if self.sample_type == 'random':
            # Fixed-length sequences stack directly into one array per feature.
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
        else:
            # Variable-length sequences: split the batch across GPUs and
            # zero-pad each GPU's chunk to a common frame count so the
            # chunks stack into one array for DataParallel. batch_frames
            # keeps the true per-sequence lengths so padding can be undone.
            gpu_num = min(torch.cuda.device_count(), batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)
            batch_frames = [[
                len(frame_sets[i])
                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                if i < batch_size
            ] for _ in range(gpu_num)]
            if len(batch_frames[-1]) != batch_per_gpu:
                # The last GPU may receive fewer sequences; pad its length
                # list with zeros so every GPU reports batch_per_gpu entries.
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)
            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
            # Concatenate each GPU's sequences along the frame axis...
            seqs = [[
                np.concatenate([
                    seqs[i][j]
                    for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                    if i < batch_size
                ], 0) for _ in range(gpu_num)]
                for j in range(feature_num)]
            # ...then zero-pad every chunk to max_sum_frame frames.
            seqs = [np.asarray([
                np.pad(seqs[j][_],
                       ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                       'constant',
                       constant_values=0)
                for _ in range(gpu_num)])
                for j in range(feature_num)]
            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        return batch
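
    # fit() is the training loop: triplet loss ('hard' or 'full') over
    # per-bin SetNet embeddings, with statistics printed and a checkpoint
    # written every 100 iterations.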
    def fit(self):
        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()

            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()

            feature, label_prob = self.encoder(*seq, batch_frame)

            # Map string labels to integer indices in the sorted label set.
            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()

            # The loss is computed per horizontal bin: (n, bins, dim) ->
            # (bins, n, dim), with the labels repeated for every bin.
            triplet_feature = feature.permute(1, 0, 2).contiguous()
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            elif self.hard_or_full_trip == 'full':
                loss = full_loss_metric.mean()
            else:
                raise ValueError(
                    "hard_or_full_trip must be 'hard' or 'full', got %r"
                    % self.hard_or_full_trip)

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            # Skip the update when the loss has effectively collapsed to zero.
            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:
                print(datetime.now() - _time1)
                _time1 = datetime.now()

            if self.restore_iter % 100 == 0:
                self.save()
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            # Visualization using t-SNE
            # if self.restore_iter % 500 == 0:
            #     pca = TSNE(2)
            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
            #     for i in range(self.P):
            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
            #
            #     plt.show()

            if self.restore_iter == self.total_iter:
                break
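
    # Tensor/ndarray -> CUDA helpers. autograd.Variable has been a no-op
    # wrapper since PyTorch 0.4 and is kept only for compatibility with
    # older PyTorch versions.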
    def ts2var(self, x):
        return autograd.Variable(x).cuda()

    def np2var(self, x):
        return self.ts2var(torch.from_numpy(x))
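
    # transform() extracts features for a whole split ('test' selects
    # test_source, anything else train_source) and returns them together
    # with the matching view / seq_type / label lists.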
    def transform(self, flag, batch_size=1):
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = list()
        view_list = list()
        seq_type_list = list()
        label_list = list()

        for i, x in enumerate(data_loader):
            seq, view, seq_type, label, batch_frame = x
            for j in range(len(seq)):
                seq[j] = self.np2var(seq[j]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # print(batch_frame, np.sum(batch_frame))

            feature, _ = self.encoder(*seq, batch_frame)
            n, num_bin, _ = feature.size()
            # Flatten (n, num_bin, dim) into one vector per sample.
            feature_list.append(feature.view(n, -1).data.cpu().numpy())
            view_list += view
            seq_type_list += seq_type
            label_list += label

        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list
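
    # Checkpoints live under checkpoint/<model_name>/ and are named
    # <save_name>-<iteration, zero-padded to 5 digits>-encoder.ptm and
    # -optimizer.ptm; load() expects the same layout.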
    def save(self):
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
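

# Usage sketch -- the hyper-parameter values below are illustrative
# assumptions, not defaults defined by this module. `train_source` and
# `test_source` are the project's dataset objects, and a CUDA device is
# required (the model is wrapped in DataParallel and moved to GPU in
# __init__):
#
#   m = Model(hidden_dim=256, lr=1e-4, hard_or_full_trip='full',
#             margin=0.2, num_workers=4, batch_size=(8, 16),
#             restore_iter=0, total_iter=80000, save_name='GaitSet',
#             train_pid_num=73, frame_num=30, model_name='GaitSet',
#             train_source=train_source, test_source=test_source)
#   m.fit()                                  # train with periodic checkpoints
#   feats, views, types, labels = m.transform('test')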