a b/src/optimize.py
1
# Script for optimizing the ratio parameters
2
3
import torch
4
import torch.nn as nn
5
import os
6
from torch.utils.data import Dataset
7
from torch.utils.data import DataLoader
8
import pickle
9
import numpy as np
10
import time
11
12
# Pose-estimation backend whose outputs are read from
# data/<subject>/<pose>/<POSE_MODEL>/position_data.pickle.
# Choices used in this project: ViTPose_large, ViTPose_base, OpenPose
POSE_MODEL = 'OpenPose'

# NOTE(review): presumably a workaround for the Intel OpenMP "duplicate
# runtime" abort (e.g. torch + MKL both loading libiomp); this flag can hide
# real conflicts, so confirm it is still required.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")
15
16
17
class Planar_Euclidean_Loss(nn.Module):
    '''
    Mean squared error between the x and y coordinates of two point tensors;
    the z coordinate is ignored.

    Despite the class name, this returns the mean of squared differences
    (torch.nn.functional.mse_loss), not the Euclidean distance itself.

    Coordinates are read from the second-to-last axis: column vectors of
    shape (3, k), optionally with leading batch axes such as the (B, 3, k)
    tensors a DataLoader produces.
    NOTE(review): the (3, k) layout is inferred from the np.hstack(...)[0:2, :]
    usage in optimize_front_linear; confirm for any new data source.
    '''

    def __init__(self):
        super(Planar_Euclidean_Loss, self).__init__()

    def forward(self, pred, target):
        '''
        :param pred: predicted points, coordinates on axis -2
        :param target: ground-truth points, same shape as pred
        :return: scalar tensor, MSE over the x and y rows only
        '''
        # Slice the coordinate axis (-2) rather than axis 0: with a leading
        # batch dimension, [0:2, :] would pick batch entries and keep z.
        return torch.nn.functional.mse_loss(pred[..., 0:2, :], target[..., 0:2, :])
28
29
30
class Target4_Model(nn.Module):
    '''
    Differentiable side-view target model with two learnable ratios.

    X3 is placed on the X1 -> X2 line at fraction r1:
        X3 = X1 + r1 * (X2 - X1)
    and the prediction moves from X3 along t2 by r2 times the length of
    (X3 - X1):
        pred = X3 + r2 * ||X3 - X1|| * t2
    The norm is taken over the entire (X3 - X1) tensor with no dim argument,
    so packing several samples into one call would mix their lengths.
    '''

    def __init__(self, r1, r2):
        super().__init__()
        # Registered as trainable parameters; state_dict keys are 'r1', 'r2'.
        self.r1 = nn.Parameter(torch.tensor(r1))
        self.r2 = nn.Parameter(torch.tensor(r2))

    def forward(self, X1, X2, t2):
        X3 = X1 + self.r1 * (X2 - X1)
        offset_scale = self.r2 * torch.norm(X3 - X1)
        return X3 + offset_scale * t2
44
45
46
class PositionDataset(Dataset):
    '''
    Map-style dataset over four parallel sequences.

    Item i is the tuple (X1[i], X2[i], t2[i], target[i]); the dataset length
    is taken from X1.
    '''

    def __init__(self, X1, X2, t2, target):
        self.X1, self.X2, self.t2, self.target = X1, X2, t2, target

    def __len__(self):
        return len(self.X1)

    def __getitem__(self, i):
        columns = (self.X1, self.X2, self.t2, self.target)
        return tuple(column[i] for column in columns)
58
59
60
def optimize_side(data, target, params=(0.35, 0.1), epoch=1000, lr=0.01, use_gpu=False):
    '''
    Fit the side-view ratios (r1, r2) of Target4_Model with per-sample SGD.

    :param data: [X1_list, X2_list, t2_list]; element i of each list belongs
        to sample i
    :param target: list of ground-truth positions, one per sample
    :param params: initial (r1, r2); a tuple so the default is never a
        shared mutable object (any 2-item sequence still works)
    :param epoch: number of passes over the data
    :param lr: SGD learning rate
    :param use_gpu: train on CUDA when requested and available
    :return: (r1, r2) after training, as Python floats
    '''
    use_cuda = use_gpu and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    X1_list, X2_list, t2_list = data
    dataset = PositionDataset(X1_list, X2_list, t2_list, target)
    # batch_size=1: Target4_Model takes torch.norm over the whole tensor, so
    # larger batches would mix samples together.
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    model = Target4_Model(*params)
    model.to(device)
    print('model.state_dict():', model.state_dict())
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = Planar_Euclidean_Loss()

    for ep in range(epoch):
        total_loss = 0
        for batch in data_loader:
            X1, X2, t2, gt = (tensor.to(device) for tensor in batch)

            opt.zero_grad()
            pred = model(X1, X2, t2)   # __call__ rather than .forward so hooks run
            loss = loss_fn(pred, gt)
            total_loss += loss.item()
            loss.backward()
            opt.step()

        # Log every 100 epochs and always the last one, so the final fit is visible.
        if ep % 100 == 0 or ep == epoch - 1:
            print('epoch: {}, loss: {}, r1: {}, r2: {}'.format(ep, total_loss, model.r1.item(), model.r2.item()))

    return model.r1.item(), model.r2.item()
95
96
97
def optimize_front_linear(data, target):
    '''
    Solve for the front-view ratios (r1, r2) by linear least squares.

    Per-sample model:
        target ~= X1 + r1 * (X2 - X1) + r2 * ||X1 - X2|| * t2
    The length ||X1 - X2|| uses every coordinate row. Only the x and y rows
    of each equation enter the fit.

    :param data: [X1_list, X2_list, t2_list]. Each list is joined with
        np.hstack, so rows are coordinates and columns are samples.
        X1, X2 -- 3D coordinates of the right shoulder and the right hip.
        t2 -- direction vector of the line connecting X3 and the target.
    :param target: list of ground-truth points in the same layout
    :return: (2, 1) array [[r1], [r2]] from the pseudo-inverse solution
    '''
    X1, X2, t2 = data[0], data[1], data[2]
    X1_all = np.hstack(X1)
    X2_all = np.hstack(X2)
    t2_all = np.hstack(t2)

    def xy_column(points):
        # Interleave x and y of every sample into one column: x0, y0, x1, y1, ...
        return points[0:2, :].T.reshape(-1, 1)

    base_xy = xy_column(X1_all)                      # (2n x 1)
    r1_column = xy_column(X2_all) - base_xy          # (2n x 1)
    seg_len = np.sqrt(((X1_all - X2_all) ** 2).sum(axis=0)).reshape(-1, 1)   # (n x 1)
    r2_column = (t2_all[0:2, :].T * seg_len).reshape(-1, 1)                # (2n x 1)

    design = np.hstack([r1_column, r2_column])       # (2n x 2)
    rhs = xy_column(np.hstack(target)) - base_xy     # (2n x 1)
    return np.linalg.pinv(design) @ rhs
117
118
119
if __name__ == '__main__':

    def load_pickle(path):
        '''
        Return the object stored in the pickle file at *path*.

        NOTE: pickle.load can execute arbitrary code; only load trusted,
        locally produced data files.
        '''
        with open(path, 'rb') as f:
            return pickle.load(f)

    DATA_ROOT = 'data'   # resolved relative to the current working directory
    # Subjects treated as outliers for the OpenPose side-view (target 4) fit.
    TARGET4_OUTLIERS = ('charles_xu', 'jingyu_wu') if POSE_MODEL == 'OpenPose' else ()

    print("HPE model: ", POSE_MODEL)
    # collect data
    target12_data = [[], [], []]   # list of list of np.array: [X1_list, X2_list, t2_list]
    target1_GT = []               # list of np.array
    target2_GT = []

    target4_data = [[], [], []]
    target4_GT = []

    # os.listdir order is arbitrary; sort for a deterministic subject order.
    for SUBJECT_NAME in sorted(os.listdir(DATA_ROOT)):
        subject_folder_path = os.path.join(DATA_ROOT, SUBJECT_NAME)
        if os.path.isfile(subject_folder_path):
            continue

        # Front view: inputs and ground truth for targets 1 and 2.
        scan_pose = 'front'
        pose_folder = os.path.join(subject_folder_path, scan_pose)
        position_data = load_pickle(os.path.join(pose_folder, POSE_MODEL, 'position_data.pickle'))
        for k in range(3):   # X1, X2, t2
            target12_data[k].append(position_data[scan_pose][k])

        ground_truth = load_pickle(os.path.join(pose_folder, 'two_cam_gt.pickle'))
        target1_GT.append(ground_truth['target_1'])
        target2_GT.append(ground_truth['target_2'])

        # skip outlier for openpose target 4:
        if SUBJECT_NAME in TARGET4_OUTLIERS:
            continue

        # Side view: inputs and ground truth for target 4.
        scan_pose = 'side'
        pose_folder = os.path.join(subject_folder_path, scan_pose)
        position_data = load_pickle(os.path.join(pose_folder, POSE_MODEL, 'position_data.pickle'))
        for k in range(3):   # X1, X2, t2
            target4_data[k].append(position_data[scan_pose][k])

        ground_truth = load_pickle(os.path.join(pose_folder, 'two_cam_gt.pickle'))
        target4_GT.append(ground_truth['target_4'])

    start_time = time.time()
    optimize_side(target4_data, target4_GT)
    print('training used {:.3f} s'.format(time.time() - start_time))

    target1_ratio = optimize_front_linear(target12_data, target1_GT)
    target2_ratio = optimize_front_linear(target12_data, target2_GT)
    print("target1_ratio: \n", target1_ratio)
    print("target2_ratio: \n", target2_ratio)
172
173
174
175
176
177
178
179
180