Diff of /autoposture_kivy.py [000000] .. [a5e8ec]

import argparse
import os
import sys
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests


import kivy
from kivy.app import App
from kivy.uix.label import Label
from kivy.uix.button import Button
from kivy.uix.boxlayout import BoxLayout
from kivy.uix.checkbox import CheckBox

yolov7_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vendor/yolov7')
sys.path.append(yolov7_path)

from models.experimental import attempt_load
import torch
from torchvision import transforms
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt, strip_optimizer, xyxy2xywh
from utils.plots import colors, output_to_keypoint, plot_one_box_kpt, plot_skeleton_kpts
from utils.torch_utils import select_device
# from tts.tttest import generate_audios, play_audio
import asyncio
import threading
import websockets
import json


HOST = 'localhost'
PORT = '8000'
POSEWEIGHTS = 'src_models/yolov7-w6-pose.pt'


async def predict_request(payload):
    """
    Args:
        - payload: {'array': nested list of shape (1, 10, 50)} (10 frames of 50 keypoint values each)
    Returns:
        - score: value between 0 and 1
        - status: good or bad posture (depending on threshold: 0.7)
    """
    uri = f"ws://{HOST}:{PORT}"
    try:
        async with websockets.connect(uri) as ws:
            payload_json = json.dumps(payload)
            await ws.send(payload_json)
            raw_prediction = await ws.recv()
            prediction = json.loads(raw_prediction)
            score = prediction['score']
            status = prediction['status']
            return score, status
    except Exception:
        return None, 'server-error'
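
# A minimal (hypothetical) usage sketch for the websocket client above; the
# code below currently uses the HTTP endpoint instead:
#
#     payload = {'array': np.zeros((1, 10, 50)).tolist()}
#     score, status = asyncio.run(predict_request(payload))
#     if status == 'server-error':
#         print('Server unreachable or not launched')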


def predict_http_request(payload):
    response = requests.post(f"http://{HOST}:{PORT}/predict", json=payload)
    if response.status_code == 200:
        return response.json()
    else:
        print("Error:", response.status_code)
        print(response.text)
        return None  # signal failure to the caller
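
# The prediction server is assumed to reply with JSON of the form
# {'score': <float in [0, 1]>, 'status': 'good' | 'bad'}; callers below read
# response['score'] and response['status'] directly, so they guard against the
# None returned on HTTP errors.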


@torch.no_grad()
def run(source, device, separation, length, multiple):
    current_score = 0
    current_status = 'good'
    separation = int(separation)
    length = int(length)

    frame_count = 0  # count of processed frames
    total_fps = 0  # total fps (currently unused)

    device = select_device(device)  # select the device passed in as an argument
    model = attempt_load(POSEWEIGHTS, map_location=device)  # load model weights
    _ = model.eval()
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
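    # Note: yolov7-w6-pose detects a single class (person) with 17 keypoints;
    # model.yaml['nc'] and model.yaml['nkpt'] are read again below when applying
    # keypoint-aware non-max suppression.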

    if source.isnumeric():
        cap = cv2.VideoCapture(int(source))  # webcam index
    else:
        cap = cv2.VideoCapture(source)  # video file path

    if not cap.isOpened():  # check that the video capture opened
        print('Error while trying to read video. Please check path again')
        raise SystemExit()

    else:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # get video frame width
        # state for multiple-person tracking
        people = {}
        next_object_id = 0
        # state for single-person tracking
        current_sequence = []

        previous_status = "None"
        longevity = 0  # frames spent in the current status

        # generate_audios("good"); generate_audios("bad")
        # bad_audio_thread = threading.Thread(target=play_audio, args=["bad"])

        empty = False
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                orig_image = frame
                image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
                image = letterbox(image, frame_width, stride=64, auto=True)[0]
                image = transforms.ToTensor()(image)
                image = torch.tensor(np.array([image.numpy()]))  # add batch dimension
                image = image.to(device)
                image = image.float()

                with torch.no_grad():  # get predictions
                    output_data, _ = model(image)

                output_data = non_max_suppression_kpt(output_data,  # apply non-max suppression
                                            0.25,  # confidence threshold
                                            0.65,  # IoU threshold
                                            nc=model.yaml['nc'],  # number of classes
                                            nkpt=model.yaml['nkpt'],  # number of keypoints
                                            kpt_label=True)

                output = output_to_keypoint(output_data)
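                # output_to_keypoint flattens detections to rows of
                # [batch_id, class_id, x, y, w, h, conf, kpt_x, kpt_y, kpt_conf, ...],
                # so output[i, 7:] holds the 17 keypoints as (x, y, conf) triplets;
                # dropping the final value below leaves the 50 features per frame
                # that the prediction payload expects.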
                if multiple:
                    if len(output) == 0:
                        if not empty:
                            print("Wiping data, waiting for objects to appear in frame")
                        people = {}
                        next_object_id = 0
                        empty = True
                    else:
                        empty = False
                else:
                    if output.shape[0] > 0:
                        if frame_count % separation == 0:
                            landmarks = output[0, 7:].T  # keypoints of the first detection
                            current_sequence += [landmarks[:-1]]

                        if len(current_sequence) == length:
                            current_sequence = np.array([current_sequence])
                            payload = {'array': current_sequence.tolist()}
                            response = predict_http_request(payload)

                            if response is not None:
                                current_score = response['score']
                                previous_status = current_status
                                current_status = response['status']
                            # score, status = asyncio.run(predict_request(payload))
                            # if status == 'server-error':
                            #     print('Server error or server not launched')
                            # print(score, status)
                            current_sequence = []

                        # if current_status == previous_status:
                        #     if not bad_audio_thread.is_alive() and longevity < 30:
                        #         longevity += 1
                        #     else:
                        #         longevity = 0
                        # else:
                        #     longevity = 0

                        # if longevity == 30 and current_status == "bad":
                        #     try:
                        #         if not bad_audio_thread.is_alive():
                        #             bad_audio_thread = threading.Thread(target=play_audio("bad"))
                        #             bad_audio_thread.start()
                        #     except Exception as e:
                        #         pass

                im0 = image[0].permute(1, 2, 0) * 255  # change format [b, c, h, w] to [h, w, c] for displaying the image
                im0 = im0.cpu().numpy().astype(np.uint8)

                im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR)  # convert back to BGR for OpenCV
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
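
                # The multiple-person branch below does greedy nearest-centroid
                # tracking: each detection's box centre is matched to the closest
                # known person within a fixed pixel radius (300 here), and any
                # unmatched detection registers a fresh ID. The matching rule:
                #
                #     distance = sqrt((cx - px)**2 + (cy - py)**2)
                #     match if distance < 300, else assign next_object_id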
                for i, pose in enumerate(output_data):  # detections per image
                    if empty:
                        break

                    for det_index, (*xyxy, conf, cls) in enumerate(reversed(pose[:, :6])):  # loop over poses for drawing on frame
                        c = int(cls)  # integer class
                        kpts = pose[det_index, 6:]

                        if multiple:
                            # get the centroid (cx, cy) of the current bounding box
                            rect = [tensor.cpu().numpy() for tensor in xyxy]
                            cx, cy = (rect[0] + rect[2]) / 2, (rect[1] + rect[3]) / 2
                            matched_object_id = None

                            # match against known people by nearest centroid
                            for object_id, data in people.items():
                                distance = np.sqrt((cx - data['centroid'][0]) ** 2 + (cy - data['centroid'][1]) ** 2)
                                # print(distance)  # debug
                                if distance < 300:  # adjust the threshold as needed
                                    matched_object_id = object_id
                                    break

                            if matched_object_id is None:
                                matched_object_id = next_object_id
                                next_object_id += 1

                            if matched_object_id not in people:
                                people[matched_object_id] = {'centroid': (cx, cy), 'yoloid': det_index, 'status': 'good', 'score': 0, 'sequence': []}
                            else:
                                people[matched_object_id]['centroid'] = (cx, cy)
                                people[matched_object_id]['yoloid'] = det_index

                            obj = people[matched_object_id]
                            label = f"ID #{obj['yoloid']} Score: {obj['score']:.2f}"
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                             line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                             cmap=people[matched_object_id]['status'])
                        else:
                            label = f"ID #{0} Score: {current_score:.2f}"
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                             line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                             cmap=current_status)

                if frame_count % separation == 0 and multiple:
                    for _, data in people.items():
                        if data['yoloid'] < output.shape[0]:
                            yoloid = data['yoloid']
                            landmarks = output[yoloid, 7:].T
                            data['sequence'] += [landmarks[:-1]]

                            if len(data['sequence']) == length:
                                payload = {'array': np.array([data['sequence']]).tolist()}
                                response = predict_http_request(payload)

                                if response is not None:
                                    data['score'] = response['score']
                                    data['status'] = response['status']
                                data['sequence'] = []

                            # print(f"{data['yoloid']} -> {data['status']}", end=' ')
                        else:
                            data['sequence'] = []

                    statuses = [(people[p]['yoloid'], people[p]['status']) for p in people]
                    # for id, status in statuses:
                    #     print(f'{id}: {status}', end='\t')
                    # print()

                frame_count += 1

                cv2.imshow("YOLOv7 Pose Estimation Demo", im0)
                key = cv2.waitKey(1) & 0xFF  # wait 1 ms and read the pressed key
                if key == ord('q'):
                    cv2.destroyAllWindows()  # close the window if 'q' is pressed
                    break
            else:
                break

        cap.release()


class AutoPostureApp(App):

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.source = opt.source
        self.device = opt.device
        self.separation = opt.separation
        self.length = opt.length
        self.multiple = opt.multiple

    def build(self):
        # Create a layout for the GUI
        layout = BoxLayout(orientation='vertical')

        exit_button = Button(text='Exit Pose Estimation')
        layout.add_widget(exit_button)
        exit_button.bind(on_press=self.exit_app)

        # Create a title label
        title_label = Label(text='AutoPosture', font_size=20)
        layout.add_widget(title_label)

        # Create a checkbox to toggle the webcam display
        # (Kivy's CheckBox has no text property, so a Label provides the caption)
        webcam_row = BoxLayout(orientation='horizontal')
        self.webcam_checkbox = CheckBox()
        webcam_row.add_widget(self.webcam_checkbox)
        webcam_row.add_widget(Label(text='View webcam'))
        layout.add_widget(webcam_row)

        # Create a button to start the pose estimation process
        start_button = Button(text='Start')
        layout.add_widget(start_button)

        start_button.bind(on_press=self.runPoseEstimation)  # bind the button to the action

        return layout

    def runPoseEstimation(self, instance):
        run(self.source, self.device, self.separation, self.length, self.multiple)

    def exit_app(self, instance):
        App.get_running_app().stop()


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', type=str, default='video/0', help='video path, or a numeric webcam index such as 0')  # video source
    parser.add_argument('-d', '--device', type=str, default='cpu', help='cpu/0,1,2,3 (gpu)')  # device argument
    parser.add_argument('-sep', '--separation', type=str, default='1', help='Run the prediction every N frames. Defaults to 1; increase for performance')  # separation argument
    parser.add_argument('-l', '--length', type=str, default='10', help='Length of the keypoint sequence sent for prediction. Defaults to 10; decrease for performance')  # length argument
    parser.add_argument('-m', '--multiple', default=False, action='store_true', help='Enable multiple-person detection')  # boolean for multiple-person detection
    opt = parser.parse_args()
    return opt
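
# Example invocation (assuming the prediction server is already listening on
# localhost:8000 and a webcam is available at index 0):
#
#     python autoposture_kivy.py --source 0 --device cpu --separation 2 --multiple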


def main(opt):
    app = AutoPostureApp(opt)
    app.run()


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)