|
a |
|
b/landmark_extraction/pose-estimate.py |
|
|
1 |
import cv2 |
|
|
2 |
import time |
|
|
3 |
import torch |
|
|
4 |
import argparse |
|
|
5 |
import numpy as np |
|
|
6 |
import matplotlib |
|
|
7 |
import matplotlib.pyplot as plt |
|
|
8 |
from torchvision import transforms |
|
|
9 |
from utils.datasets import letterbox |
|
|
10 |
from utils.torch_utils import select_device |
|
|
11 |
from models.experimental import attempt_load |
|
|
12 |
from utils.general import non_max_suppression_kpt,strip_optimizer,xyxy2xywh |
|
|
13 |
from utils.plots import output_to_keypoint, plot_skeleton_kpts,colors,plot_one_box_kpt |
|
|
14 |
|
|
|
15 |
# Select the 'TkAgg' matplotlib backend at import time, before run() creates
# any figures with pyplot.
matplotlib.use('TkAgg')
|
|
16 |
|
|
|
17 |
@torch.no_grad()
def run(poseweights="yolov7-w6-pose.pt", source="football1.mp4", device='cpu', view_img=False,
        save_conf=False, line_thickness=3, hide_labels=False, hide_conf=True):
    """Run pose estimation on a video/webcam stream, preview it and plot per-frame FPS.

    Each frame is letterboxed, passed through the model, filtered with
    keypoint-aware non-max suppression and drawn (boxes + skeleton keypoints)
    into a preview window. When the stream ends or 'q' is pressed, the average
    FPS is printed and a per-frame FPS graph is saved to ``FPS_graph.png`` and
    shown.

    Args:
        poseweights: Weights path(s) handed to ``attempt_load``.
        source: Video file path, or a numeric string used as a camera index.
        device: Device string handed to ``select_device``.
        view_img: Accepted for CLI compatibility; currently unused (the
            preview window is always shown).
        save_conf: Accepted for CLI compatibility; currently unused.
        line_thickness: Line thickness passed to ``plot_one_box_kpt``.
        hide_labels: When true, boxes are drawn without a label.
        hide_conf: When true, labels contain only the class name; otherwise
            the confidence is appended.

    Raises:
        SystemExit: If the source cannot be opened or its first frame cannot
            be read.
    """
    frame_count = 0  # number of frames processed
    total_fps = 0    # running sum of per-frame FPS, used for the average
    fps_graph = []   # per-frame FPS values for the final plot

    # Use the function's own arguments (not the module-level ``opt``) so that
    # run() also works when it is imported and called directly.
    device = select_device(device)

    model = attempt_load(poseweights, map_location=device)
    model.eval()
    names = model.module.names if hasattr(model, 'module') else model.names  # class names

    # A purely numeric source is treated as a camera index.
    cap = cv2.VideoCapture(int(source) if source.isnumeric() else source)
    if not cap.isOpened():
        print('Error while trying to read video. Please check path again')
        raise SystemExit()

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

    # The first frame is consumed only to learn the letterboxed output size;
    # it is not run through the model.
    ok, first_frame = cap.read()
    if not ok:
        cap.release()
        print('Error while trying to read video. Please check path again')
        raise SystemExit()
    vid_write_image = letterbox(first_frame, (frame_width), stride=64, auto=True)[0]
    resize_height, resize_width = vid_write_image.shape[:2]

    # NOTE(review): the writer is opened but no frames are written to it in
    # this function; confirm whether annotated frames should be saved.
    out = cv2.VideoWriter(f"{source}_keypoint.mp4",
                          cv2.VideoWriter_fourcc(*'mp4v'), 30,
                          (resize_width, resize_height))

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # end of stream or read failure
                break

            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = letterbox(image, (frame_width), stride=64, auto=True)[0]
            image = transforms.ToTensor()(image).unsqueeze(0)  # add batch dimension
            image = image.to(device).float()

            # perf_counter is monotonic and high resolution, which avoids a
            # zero elapsed time on platforms with a coarse wall clock.
            start_time = time.perf_counter()

            output_data, _ = model(image)
            output_data = non_max_suppression_kpt(output_data,
                                                  0.25,  # confidence threshold
                                                  0.65,  # IoU threshold
                                                  nc=model.yaml['nc'],
                                                  nkpt=model.yaml['nkpt'],
                                                  kpt_label=True)

            # NOTE(review): the result is currently unused; the call is kept
            # so per-frame work (and therefore the measured FPS) is unchanged.
            output = output_to_keypoint(output_data)

            # Turn the network input back into a BGR uint8 image for drawing.
            im0 = image[0].permute(1, 2, 0) * 255  # [c, h, w] -> [h, w, c]
            im0 = im0.cpu().numpy().astype(np.uint8)
            im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR)

            for pose in output_data:  # detections per image
                if not len(pose):
                    continue

                for c in pose[:, 5].unique():
                    n = (pose[:, 5] == c).sum()  # detections per class
                    print("No of Objects in Current Frame : {}".format(n))

                # Iterate whole rows so every box is drawn with its own
                # keypoints (the box and keypoint columns stay paired).
                for det in reversed(pose):
                    *xyxy, conf, cls = det[:6]
                    c = int(cls)
                    kpts = det[6:]
                    label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                    plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                     line_thickness=line_thickness, kpt_label=True, kpts=kpts, steps=3,
                                     orig_shape=im0.shape[:2])

            elapsed = time.perf_counter() - start_time
            fps = 1 / elapsed if elapsed > 0 else 0.0
            total_fps += fps
            frame_count += 1
            fps_graph.append(fps)

            cv2.imshow("YOLOv7 Pose Estimation Demo", im0)
            # Press Q on keyboard to exit
            if cv2.waitKey(25) & 0xFF == ord('q'):
                break
    finally:
        # Release capture, writer and windows even if a frame fails mid-stream.
        cap.release()
        out.release()
        cv2.destroyAllWindows()

    if not frame_count:
        # Nothing was processed, so there is no average and nothing to plot.
        print("No frames were processed.")
        return

    avg_fps = total_fps / frame_count
    print(f"Average FPS: {avg_fps:.3f}")

    # Plot the per-frame FPS graph.
    frame_indices = range(len(fps_graph))
    fig, ax = plt.subplots()
    ax.plot(frame_indices, fps_graph, color='blue', linestyle='-')
    ax.set_xlabel('Frame Index')
    ax.set_ylabel('FPS')
    ax.set_title('Frames per Second (FPS)')
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_facecolor('#f0f0f0')
    plt.savefig("FPS_graph.png")
    plt.show()
|
|
147 |
|
|
|
148 |
|
|
|
149 |
def parse_opt():
    """Parse the command-line options of the pose-estimation script.

    Returns:
        argparse.Namespace with the attributes ``poseweights``, ``source``,
        ``device``, ``view_img``, ``save_conf``, ``line_thickness``,
        ``hide_labels`` and ``hide_conf``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    arg_specs = (
        ('--poseweights', dict(nargs='+', type=str, default='yolov7-w6-pose.pt', help='model path(s)')),
        ('--source', dict(type=str, default='football1.mp4', help='video/0 for webcam')),
        ('--device', dict(type=str, default='cpu', help='cpu/0,1,2,3(gpu)')),
        ('--view-img', dict(action='store_true', help='display results')),
        ('--save-conf', dict(action='store_true', help='save confidences in --save-txt labels')),
        ('--line-thickness', dict(default=3, type=int, help='bounding box thickness (pixels)')),
        ('--hide-labels', dict(default=False, action='store_true', help='hide labels')),
        ('--hide-conf', dict(default=False, action='store_true', help='hide confidences')),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in arg_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
161 |
|
|
|
162 |
def plot_fps_time_comparision(time_list, fps_list):
    """Plot FPS against time and save the graph as a PNG file.

    Args:
        time_list: Values plotted on the x axis (labelled 'Time (s)').
        fps_list: Values plotted on the y axis (labelled 'FPS').

    The figure is saved as ``FPS_and_Time_Comparision_pose_estimate.png``.
    """
    plt.figure()
    axes = plt.gca()
    axes.set_xlabel('Time (s)')
    axes.set_ylabel('FPS')
    axes.set_title('FPS and Time Comparision Graph')
    axes.plot(time_list, fps_list, 'b', label="FPS & Time")
    plt.savefig("FPS_and_Time_Comparision_pose_estimate.png")
|
|
170 |
|
|
|
171 |
|
|
|
172 |
def main(opt):
    """Unpack the parsed options into keyword arguments and call ``run``."""
    options = vars(opt)
    run(**options)
|
|
175 |
|
|
|
176 |
if __name__ == "__main__":
    opt = parse_opt()
    # A purely numeric device string is passed to strip_optimizer as an int;
    # any other value is passed through unchanged.
    if opt.device.isnumeric():
        dev = int(opt.device)
    else:
        dev = opt.device
    strip_optimizer(dev, opt.poseweights)
    main(opt)