b/autoposture_kivy.py

import argparse
import os
import sys
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests


import kivy
from kivy.app import App
from kivy.uix.label import Label
from kivy.uix.button import Button
from kivy.uix.boxlayout import BoxLayout
from kivy.uix.checkbox import CheckBox

yolov7_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vendor/yolov7')
sys.path.append(yolov7_path)

from models.experimental import attempt_load
import torch
from torchvision import transforms
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt, strip_optimizer, xyxy2xywh
from utils.plots import colors, output_to_keypoint, plot_one_box_kpt, plot_skeleton_kpts
from utils.torch_utils import select_device
# from tts.tttest import generate_audios, play_audio
import asyncio
import threading
import websockets
import json


HOST = 'localhost'
PORT = '8000'
POSEWEIGHTS = 'src_models/yolov7-w6-pose.pt'


async def predict_request(payload):
    """
    Args:
    - payload: {'array': array of shape (1, 10, 50)} (10 frames)
    Returns:
    - score: value between 0 and 1
    - status: good or bad posture (depending on threshold: 0.7)
    """
    uri = f"ws://{HOST}:{PORT}"
    try:
        async with websockets.connect(uri) as ws:
            payload_json = json.dumps(payload)
            await ws.send(payload_json)
            raw_prediction = await ws.recv()
            prediction = json.loads(raw_prediction)
            score = prediction['score']
            status = prediction['status']
            return score, status
    except Exception:
        return None, 'server-error'


def predict_http_request(payload):
    response = requests.post(f"http://{HOST}:{PORT}/predict", json=payload)
    if response.status_code == 200:
        return response.json()
    else:
        print("Error:", response.status_code)
        print(response.text)

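
# A minimal usage sketch for the helper above, assuming the prediction server is
# running on HOST:PORT and answers with a JSON body containing 'score' and 'status'
# (the zero-filled sequence below is hypothetical, shaped as in the docstring above):
#
#   sequence = np.zeros((1, 10, 50))                       # 10 frames x 50 keypoint values
#   response = predict_http_request({'array': sequence.tolist()})
#   if response is not None:                               # None is returned on a non-200 reply
#       print(response['score'], response['status'])
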
@torch.no_grad()
def run(source, device, separation, length, multiple):
    current_score = 0
    current_status = 'good'
    # global ap_model
    separation = int(separation)
    length = int(length)

    frame_count = 0  # count of frames
    total_fps = 0    # total fps

    device = select_device(device)                            # select device
    model = attempt_load(POSEWEIGHTS, map_location=device)    # load model
    _ = model.eval()
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names

    if source.isnumeric():
        cap = cv2.VideoCapture(int(source))  # pass webcam index to videocapture object
    else:
        cap = cv2.VideoCapture(source)       # pass video path to videocapture object

    if not cap.isOpened():  # check if videocapture could not be opened
        print('Error while trying to read video. Please check path again')
        raise SystemExit()

    else:
        frame_width = int(cap.get(3))  # get video frame width
        # state for multiple persons
        people = {}
        next_object_id = 0
        # state for a single person
        current_sequence = []

        previous_status = "None"
        longevity = 0  # frames spent in the current status

        # generate_audios("good"); generate_audios("bad")
        # bad_audio_thread = threading.Thread(target=play_audio, args=["bad"])

        empty = False
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                orig_image = frame
                image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
                image = letterbox(image, (frame_width), stride=64, auto=True)[0]
                image = transforms.ToTensor()(image)
                image = torch.tensor(np.array([image.numpy()]))
                image = image.to(device)
                image = image.float()

                with torch.no_grad():  # get predictions
                    output_data, _ = model(image)

                output_data = non_max_suppression_kpt(output_data,              # apply non-max suppression
                                                      0.25,                     # confidence threshold
                                                      0.65,                     # IoU threshold
                                                      nc=model.yaml['nc'],      # number of classes
                                                      nkpt=model.yaml['nkpt'],  # number of keypoints
                                                      kpt_label=True)

                output = output_to_keypoint(output_data)
                if multiple:
                    if len(output) == 0:
                        if not empty:
                            print("Wiping data, waiting for objects to appear in frame")
                            people = {}
                            next_object_id = 0
                            empty = True
                    else:
                        empty = False
                else:
                    if output.shape[0] > 0:
                        if frame_count % separation == 0:
                            landmarks = output[0, 7:].T
                            current_sequence += [landmarks[:-1]]

                        if len(current_sequence) == length:
                            current_sequence = np.array([current_sequence])
                            payload = {'array': current_sequence.tolist()}
                            response = predict_http_request(payload)

                            current_score = response['score']

                            previous_status = current_status
                            current_status = response['status']
                            # score, status = asyncio.run(predict_request(payload))
                            # if status == 'server-error':
                            #     print('Server error or server not launched')
                            # print(score, status)
                            current_sequence = []

                        # if current_status == previous_status:
                        #     if not bad_audio_thread.is_alive() and longevity < 30:
                        #         longevity += 1
                        #     else:
                        #         longevity = 0
                        # else:
                        #     longevity = 0

                        # if longevity == 30 and current_status == "bad":
                        #     try:
                        #         if not bad_audio_thread.is_alive():
                        #             bad_audio_thread = threading.Thread(target=play_audio("bad"))
                        #             bad_audio_thread.start()
                        #     except Exception as e:
                        #         pass

                im0 = image[0].permute(1, 2, 0) * 255  # change format [b, c, h, w] to [h, w, c] for displaying the image
                im0 = im0.cpu().numpy().astype(np.uint8)

                im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR)   # convert image back to BGR
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]   # normalization gain whwh

                for i, pose in enumerate(output_data):  # detections per image
                    if empty: break

                    if len(output_data) == 0:
                        continue
                    for det_index, (*xyxy, conf, cls) in enumerate(reversed(pose[:, :6])):  # loop over poses for drawing on frame
                        c = int(cls)  # integer class
                        kpts = pose[det_index, 6:]

                        if multiple:
                            # get the centroid (cx, cy) for the current rectangle
                            rect = [tensor.cpu().numpy() for tensor in xyxy]
                            cx, cy = (rect[0] + rect[2]) / 2, (rect[1] + rect[3]) / 2
                            matched_object_id = None

                            # iterate through known people
                            for object_id, data in people.items():
                                distance = np.sqrt((cx - data['centroid'][0]) ** 2 + (cy - data['centroid'][1]) ** 2)
                                print(distance)
                                if distance < 300:  # adjust the threshold as needed
                                    matched_object_id = object_id
                                    break

                            if matched_object_id is None:
                                matched_object_id = next_object_id
                                next_object_id += 1

                            if matched_object_id not in people:
                                people[matched_object_id] = {'centroid': (cx, cy), 'yoloid': det_index, 'status': 'good', 'score': 0, 'sequence': []}
                            else:
                                people[matched_object_id]['centroid'] = (cx, cy)
                                people[matched_object_id]['yoloid'] = det_index

                            obj = people[matched_object_id]
                            label = f"ID #{obj['yoloid']} Score: {obj['score']:.2f}"
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                             line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                             cmap=people[matched_object_id]['status'])
                        else:
                            label = f"ID #{0} Score: {current_score:.2f}"
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                             line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                             cmap=current_status)

                if frame_count % separation == 0 and multiple:
                    for _, data in people.items():
                        if data['yoloid'] < output.shape[0]:
                            yoloid = data['yoloid']
                            landmarks = output[yoloid, 7:].T
                            data['sequence'] += [landmarks[:-1]]

                            if len(data['sequence']) == length:
                                payload = {'array': np.array([data['sequence']]).tolist()}
                                response = predict_http_request(payload)

                                data['score'] = response['score']
                                data['status'] = response['status']
                                data['sequence'] = []

                                # print(f"{data['yoloid']} -> {data['status']}", end=' ')
                        else:
                            data['sequence'] = []

                    statuses = [(people[p]['yoloid'], people[p]['status']) for p in people]
                    # for id, status in statuses:
                    #     print(f'{id}: {status}', end='\t')
                    # print()

                frame_count += 1

                cv2.imshow("YOLOv7 Pose Estimation Demo", im0)
                key = cv2.waitKey(1) & 0xFF  # wait for 1 millisecond and get the pressed key
                if key == ord('q'):
                    cv2.destroyAllWindows()  # close the window if 'q' is pressed
                    break
            else:
                break

        cap.release()


class AutoPostureApp(App):

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.source = opt.source
        self.device = opt.device
        self.separation = opt.separation
        self.length = opt.length
        self.multiple = opt.multiple

    def build(self):
        # Create a layout for the GUI
        layout = BoxLayout(orientation='vertical')

        exit_button = Button(text='Exit Pose Estimation')
        layout.add_widget(exit_button)
        exit_button.bind(on_press=self.exit_app)

        # Create a title label
        title_label = Label(text='AutoPosture', font_size=20)
        layout.add_widget(title_label)

        # Create a checkbox to toggle the webcam display
        self.webcam_checkbox = CheckBox()
        self.webcam_checkbox.text = "View webcam"
        layout.add_widget(self.webcam_checkbox)

        # Create a button to start the pose estimation process
        start_button = Button(text='Start')
        layout.add_widget(start_button)

        # self.start_button.bind(on_press=lambda instance: self.runPoseEstimation(self.opt))

        start_button.bind(on_press=self.runPoseEstimation)  # Bind the button to the action

        return layout

    def runPoseEstimation(self, instance):
        run(self.source, self.device, self.separation, self.length, self.multiple)

    def exit_app(self, instance):
        App.get_running_app().stop()


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', type=str, default='video/0', help='video/0 for webcam')  # video source
    parser.add_argument('-d', '--device', type=str, default='cpu', help='cpu/0,1,2,3(gpu)')  # device argument
    parser.add_argument('-sep', '--separation', type=str, default='1', help='Run the prediction every N frames. Defaults to 1, increase for performance')  # separation argument
    parser.add_argument('-l', '--length', type=str, default='10', help='Length of the keypoint sequence. Defaults to 10, decrease for performance')  # length argument
    parser.add_argument('-m', '--multiple', default=False, action='store_true', help='Enable multiple-person detection')  # boolean for multiple-person detection
    opt = parser.parse_args()
    return opt


def main(opt):
    app = AutoPostureApp(opt)
    app.run()


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)
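
# Example invocation (a sketch using the flags defined in parse_opt above;
# a numeric --source such as 0 selects the corresponding webcam):
#   python autoposture_kivy.py --source 0 --device cpu --separation 2 --length 10 --multiple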