Diff of /autoposture.py [000000] .. [a5e8ec]

import argparse
import os
import sys
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
from internal.prediction_client import predict_http_request


yolov7_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vendor/yolov7')
sys.path.append(yolov7_path)

from models.experimental import attempt_load
import torch
from torchvision import transforms
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt, strip_optimizer, xyxy2xywh
from utils.plots import colors, output_to_keypoint, plot_one_box_kpt, plot_skeleton_kpts
from utils.torch_utils import select_device
# from tts.tttest import generate_audios, play_audio
import asyncio
import threading
import websockets
import json


POSEWEIGHTS = 'src_models/yolov7-w6-pose.pt'
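
# Assumed contract for predict_http_request (inferred from the call sites below, not from
# internal.prediction_client's own documentation): a blocking call that receives a payload
# shaped like {'array': [[keypoints_frame_1, ..., keypoints_frame_N]]} for one person and
# returns a dict with a numeric 'score' and a 'status' string such as 'good' or 'bad',
# which is also passed to plot_one_box_kpt as its cmap argument.

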
@torch.no_grad()
def run(source, device, separation, length, multiple):
    # global ap_model
    separation = int(separation)
    length = int(length)

    frame_count = 0  # count of processed frames
    total_fps = 0  # accumulated fps

    device = select_device(device)  # select device
    model = attempt_load(POSEWEIGHTS, map_location=device)  # load model
    _ = model.eval()
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names

    if source.isnumeric():
        cap = cv2.VideoCapture(int(source))  # webcam index
    else:
        cap = cv2.VideoCapture(source)  # video file path

    if not cap.isOpened():  # check that the capture opened correctly
        print('Error while trying to read video. Please check path again')
        raise SystemExit()

    else:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # get video frame width

        # state for multiple-person mode
        people = {}
        next_object_id = 0

        # state for single-person mode
        current_sequence = []
        current_score = 0
        current_status = 'good'
        previous_status = "None"
        longevity = 0  # frames spent in the current status

        # generate_audios("good"); generate_audios("bad")
        # bad_audio_thread = threading.Thread(target=play_audio, args=["bad"])

        empty = False
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                orig_image = frame
                image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
                image = letterbox(image, frame_width, stride=64, auto=True)[0]
                image = transforms.ToTensor()(image)
                image = torch.tensor(np.array([image.numpy()]))
                image = image.to(device)
                image = image.float()

                with torch.no_grad():  # get predictions
                    output_data, _ = model(image)

                output_data = non_max_suppression_kpt(output_data,  # apply non-max suppression
                                            0.25,   # confidence threshold
                                            0.65,   # IoU threshold
                                            nc=model.yaml['nc'],  # number of classes
                                            nkpt=model.yaml['nkpt'],  # number of keypoints
                                            kpt_label=True)

                output = output_to_keypoint(output_data)
                if multiple:
                    if len(output) == 0:
                        if not empty:
                            print("Wiping data, waiting for objects to appear in frame")
                        people = {}
                        next_object_id = 0
                        empty = True
                    else:
                        empty = False
                else:
                    if output.shape[0] > 0:
                        if frame_count % separation == 0:
                            landmarks = output[0, 7:].T  # keypoints start at column 7 of each output_to_keypoint row
                            current_sequence += [landmarks[:-1]]

                        if len(current_sequence) == length:
                            current_sequence = np.array([current_sequence])
                            payload = {'array': current_sequence.tolist()}
                            response = predict_http_request(payload)

                            current_score = response['score']

                            previous_status = current_status
                            current_status = response['status']
                            # score, status = asyncio.run(predict_request(payload))
                            # if status == 'server-error':
                            #     print('Server error or server not launched')
                            # print(score, status)
                            current_sequence = []

                        # if current_status == previous_status:
                        #     if not bad_audio_thread.is_alive() and longevity < 30:
                        #         longevity += 1
                        #     else:
                        #         longevity = 0
                        # else:
                        #     longevity = 0

                        # if longevity == 30 and current_status == "bad":
                        #     try:
                        #         if not bad_audio_thread.is_alive():
                        #             bad_audio_thread = threading.Thread(target=play_audio("bad"))
                        #             bad_audio_thread.start()
                        #     except Exception as e:
                        #         pass

                im0 = image[0].permute(1, 2, 0) * 255  # change format [b, c, h, w] to [h, w, c] for displaying the image
                im0 = im0.cpu().numpy().astype(np.uint8)

                im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR)  # convert back to BGR for OpenCV display
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh

                for i, pose in enumerate(output_data):  # detections per image
                    if empty:
                        break
                    if len(output_data) == 0:
                        continue
                    for det_index, (*xyxy, conf, cls) in enumerate(reversed(pose[:, :6])):  # loop over poses for drawing on frame
                        c = int(cls)  # integer class
                        kpts = pose[det_index, 6:]

                        if multiple:
                            # centroid (cx, cy) of the current bounding box
                            rect = [tensor.cpu().numpy() for tensor in xyxy]
                            cx, cy = (rect[0] + rect[2]) / 2, (rect[1] + rect[3]) / 2
                            matched_object_id = None

                            # greedy nearest-centroid match against known people
                            for object_id, data in people.items():
                                distance = np.sqrt((cx - data['centroid'][0]) ** 2 + (cy - data['centroid'][1]) ** 2)
                                # print(distance)  # debug output
                                if distance < 300:  # adjust the threshold as needed
                                    matched_object_id = object_id
                                    break

                            if matched_object_id is None:
                                matched_object_id = next_object_id
                                next_object_id += 1

                            if matched_object_id not in people:
                                people[matched_object_id] = {'centroid': (cx, cy), 'yoloid': det_index, 'status': 'good', 'score': 0, 'sequence': []}
                            else:
                                people[matched_object_id]['centroid'] = (cx, cy)
                                people[matched_object_id]['yoloid'] = det_index

                            obj = people[matched_object_id]
                            label = f"ID #{obj['yoloid']} Score: {obj['score']:.2f}"
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                        line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                        cmap=people[matched_object_id]['status'])
                        else:
                            label = f"ID #{0} Score: {current_score:.2f}"
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                        line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                        cmap=current_status)

                if frame_count % separation == 0 and multiple:
                    # every `separation` frames, extend each tracked person's keypoint sequence
                    # and query the prediction service once the sequence reaches `length`
                    for _, data in people.items():
                        if data['yoloid'] < output.shape[0]:
                            yoloid = data['yoloid']
                            landmarks = output[yoloid, 7:].T
                            data['sequence'] += [landmarks[:-1]]

                            if len(data['sequence']) == length:
                                payload = {'array': np.array([data['sequence']]).tolist()}
                                response = predict_http_request(payload)

                                data['score'] = response['score']
                                data['status'] = response['status']
                                data['sequence'] = []

                            # print(f"{data['yoloid']} -> {data['status']}", end=' ')
                        else:
                            data['sequence'] = []

                    statuses = [(people[p]['yoloid'], people[p]['status']) for p in people]
                    # for id, status in statuses:
                    #     print(f'{id}: {status}', end='\t')
                    # print()

                frame_count += 1

                cv2.imshow("YOLOv7 Pose Estimation Demo", im0)
                key = cv2.waitKey(1) & 0xFF  # wait for 1 millisecond and get the pressed key
                if key == ord('q'):
                    cv2.destroyAllWindows()  # close the window if 'q' is pressed
                    break
            else:
                break

        cap.release()
        cv2.destroyAllWindows()

def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', type=str, default='video/0', help='path to a video file, or a webcam index such as 0')  # video source
    parser.add_argument('-d', '--device', type=str, default='cpu', help='cpu or GPU index: 0,1,2,3')  # device argument
    parser.add_argument('-sep', '--separation', type=str, default='1', help='Run the prediction every N frames. Defaults to 1; increase for performance')  # separation argument
    parser.add_argument('-l', '--length', type=str, default='10', help='Length of the keypoint sequence sent for prediction. Defaults to 10; decrease for performance')  # length argument
    parser.add_argument('-m', '--multiple', default=False, action='store_true', help='Enable multiple-person detection')  # boolean for multiple-person detection
    opt = parser.parse_args()
    return opt


def main(opt):
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    strip_optimizer(opt.device, POSEWEIGHTS)
    main(opt)
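
A typical invocation, as a sketch (it assumes the YOLOv7 pose weights are present at src_models/yolov7-w6-pose.pt, the vendor/yolov7 checkout is populated, and the prediction service used by internal.prediction_client is reachable; the video path below is a placeholder):

    python autoposture.py --source 0 --device cpu             # single-person mode on webcam 0
    python autoposture.py --source clips/demo.mp4 --multiple  # multiple-person mode on a video file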