autoposture.py

import argparse
import os
import sys
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
from internal.prediction_client import predict_http_request


# make the vendored YOLOv7 modules importable
yolov7_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vendor/yolov7')
sys.path.append(yolov7_path)

from models.experimental import attempt_load
import torch
from torchvision import transforms
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt, strip_optimizer, xyxy2xywh
from utils.plots import colors, output_to_keypoint, plot_one_box_kpt, plot_skeleton_kpts
from utils.torch_utils import select_device
# from tts.tttest import generate_audios, play_audio
import asyncio
import threading
import websockets
import json


POSEWEIGHTS = 'src_models/yolov7-w6-pose.pt'


@torch.no_grad()
def run(source, device, separation, length, multiple):
    # global ap_model
    separation = int(separation)
    length = int(length)

    frame_count = 0  # count of processed frames
    total_fps = 0  # accumulated fps

    device = select_device(device)  # select device from the function argument
    model = attempt_load(POSEWEIGHTS, map_location=device)  # load model
    _ = model.eval()
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names

    if source.isnumeric():
        cap = cv2.VideoCapture(int(source))  # webcam index
    else:
        cap = cv2.VideoCapture(source)  # video file path

    if not cap.isOpened():  # check that the video capture opened
        print('Error while trying to read video. Please check path again')
        raise SystemExit()

    else:
        frame_width = int(cap.get(3))  # get video frame width
        # tracking state for multiple persons
        people = {}
        next_object_id = 0
        # state for a single person
        current_sequence = []
        current_score = 0
        current_status = 'good'
        previous_status = "None"
        longevity = 0  # frames spent in the current status

        # generate_audios("good"); generate_audios("bad")
        # bad_audio_thread = threading.Thread(target=play_audio, args=["bad"])

        empty = False
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                orig_image = frame
                image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
                image = letterbox(image, frame_width, stride=64, auto=True)[0]
                image = transforms.ToTensor()(image)
                image = torch.tensor(np.array([image.numpy()]))
                image = image.to(device)
                image = image.float()

                with torch.no_grad():  # get predictions
                    output_data, _ = model(image)

                output_data = non_max_suppression_kpt(output_data,               # apply non-max suppression
                                                      0.25,                      # confidence threshold
                                                      0.65,                      # IoU threshold
                                                      nc=model.yaml['nc'],       # number of classes
                                                      nkpt=model.yaml['nkpt'],   # number of keypoints
                                                      kpt_label=True)

                output = output_to_keypoint(output_data)
                if multiple:
                    if len(output) == 0:
                        if not empty:
                            print("Wiping data, waiting for objects to appear in frame")
                            people = {}
                            next_object_id = 0
                            empty = True
                    else:
                        empty = False
                else:
                    if output.shape[0] > 0:
                        if frame_count % separation == 0:
                            landmarks = output[0, 7:].T
                            current_sequence += [landmarks[:-1]]

                            if len(current_sequence) == length:
                                current_sequence = np.array([current_sequence])
                                payload = {'array': current_sequence.tolist()}
                                response = predict_http_request(payload)

                                current_score = response['score']

                                previous_status = current_status
                                current_status = response['status']
                                # score, status = asyncio.run(predict_request(payload))
                                # if status == 'server-error':
                                #     print('Server error or server not launched')
                                # print(score, status)
                                current_sequence = []

                        # if current_status == previous_status:
                        #     if not bad_audio_thread.is_alive() and longevity < 30:
                        #         longevity += 1
                        #     else:
                        #         longevity = 0
                        # else:
                        #     longevity = 0

                        # if longevity == 30 and current_status == "bad":
                        #     try:
                        #         if not bad_audio_thread.is_alive():
                        #             bad_audio_thread = threading.Thread(target=play_audio("bad"))
                        #             bad_audio_thread.start()
                        #     except Exception as e:
                        #         pass

                im0 = image[0].permute(1, 2, 0) * 255  # change format [b, c, h, w] to [h, w, c] for displaying the image
                im0 = im0.cpu().numpy().astype(np.uint8)

                im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR)  # convert back to BGR for OpenCV display
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh

                for i, pose in enumerate(output_data):  # detections per image
                    if empty:
                        break

                    if len(output_data) == 0:
                        continue
                    for det_index, (*xyxy, conf, cls) in enumerate(reversed(pose[:, :6])):  # loop over poses for drawing on frame
                        c = int(cls)  # integer class
                        kpts = pose[det_index, 6:]

                        if multiple:
                            # get the centroid (cx, cy) of the current bounding box
                            rect = [tensor.cpu().numpy() for tensor in xyxy]
                            cx, cy = (rect[0] + rect[2]) / 2, (rect[1] + rect[3]) / 2
                            matched_object_id = None

                            # match the detection against known people by centroid distance
                            for object_id, data in people.items():
                                distance = np.sqrt((cx - data['centroid'][0]) ** 2 + (cy - data['centroid'][1]) ** 2)
                                print(distance)
                                if distance < 300:  # adjust the threshold as needed
                                    matched_object_id = object_id
                                    break

                            if matched_object_id is None:
                                matched_object_id = next_object_id
                                next_object_id += 1

                            if matched_object_id not in people:
                                people[matched_object_id] = {'centroid': (cx, cy), 'yoloid': det_index, 'status': 'good', 'score': 0, 'sequence': []}
                            else:
                                people[matched_object_id]['centroid'] = (cx, cy)
                                people[matched_object_id]['yoloid'] = det_index

                            obj = people[matched_object_id]
                            label = f"ID #{obj['yoloid']} Score: {obj['score']:.2f}"
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                             line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                             cmap=people[matched_object_id]['status'])
                        else:
                            label = f"ID #{0} Score: {current_score:.2f}"
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                             line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                             cmap=current_status)

                if frame_count % separation == 0 and multiple:
                    for _, data in people.items():
                        if data['yoloid'] < output.shape[0]:
                            yoloid = data['yoloid']
                            landmarks = output[yoloid, 7:].T
                            data['sequence'] += [landmarks[:-1]]

                            if len(data['sequence']) == length:
                                payload = {'array': np.array([data['sequence']]).tolist()}
                                response = predict_http_request(payload)

                                data['score'] = response['score']
                                data['status'] = response['status']
                                data['sequence'] = []

                                # print(f"{data['yoloid']} -> {data['status']}", end=' ')
                        else:
                            data['sequence'] = []

                    statuses = [(people[p]['yoloid'], people[p]['status']) for p in people]
                    # for id, status in statuses:
                    #     print(f'{id}: {status}', end='\t')
                    # print()

                frame_count += 1

                cv2.imshow("YOLOv7 Pose Estimation Demo", im0)
                key = cv2.waitKey(1) & 0xFF  # wait 1 ms and read the pressed key
                if key == ord('q'):
                    cv2.destroyAllWindows()  # close the window if 'q' is pressed
                    break
            else:
                break

        cap.release()

def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', type=str, default='video/0', help='video/0 for webcam')  # video source
    parser.add_argument('-d', '--device', type=str, default='cpu', help='cpu/0,1,2,3(gpu)')  # device argument
    parser.add_argument('-sep', '--separation', type=str, default='1', help='Run the prediction every N frames. Defaults to 1; increase for performance')  # separation argument
    parser.add_argument('-l', '--length', type=str, default='10', help='Length of the keypoint sequence sent for prediction. Defaults to 10; decrease for performance')  # length argument
    parser.add_argument('-m', '--multiple', default=False, action='store_true', help='Enable multiple-person detection')  # boolean for multiple-person detection
    opt = parser.parse_args()
    return opt


def main(opt):
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    strip_optimizer(opt.device, POSEWEIGHTS)
    main(opt)