b/internal/frame_process.py
import sys
import os
from torchvision import transforms
import torch
import cv2
import numpy as np
import requests

yolov7_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../vendor/yolov7')
sys.path.append(yolov7_path)

from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt, strip_optimizer, xyxy2xywh
from utils.plots import colors, output_to_keypoint, plot_one_box_kpt, plot_skeleton_kpts
from utils.torch_utils import select_device


HOST = 'localhost'
PORT = '8420'


def predict_http_request(payload):
    """
    Send one keypoint sequence to the local scoring service.

    Args:
        payload: {'array': nested list of shape (1, 10, 50)}, i.e. one
            sequence of 10 frames of flattened keypoints.

    Returns:
        The service's JSON response as a dict:
        - 'score': value between 0 and 1
        - 'status': good or bad posture (depending on the 0.7 threshold)
        Returns None if the request failed.
    """
    response = requests.post(f"http://{HOST}:{PORT}/predict", json=payload)
    if response.status_code == 200:
        return response.json()
    print("Error:", response.status_code)
    print(response.text)
    return None
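
# A minimal usage sketch for predict_http_request (illustrative only: the
# zero-filled dummy sequence and a service already listening on HOST:PORT are
# assumptions, not part of this module):
#
#   dummy_sequence = np.zeros((1, 10, 50))  # (batch, frames, flattened keypoints)
#   result = predict_http_request({'array': dummy_sequence.tolist()})
#   if result is not None:
#       print(result['score'], result['status'])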


model = None
device = None


def model_initialization(device_ref, w):
    """Load the YOLOv7 pose weights from `w` onto the device named by
    `device_ref` (e.g. 'cpu' or '0' for the first GPU)."""
    global model, device
    device = select_device(device_ref)
    model = attempt_load(w, map_location=device)
    return model
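
# Example call (hypothetical checkpoint path; any YOLOv7 pose checkpoint works):
#
#   model_initialization('cpu', 'vendor/yolov7/yolov7-w6-pose.pt')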


separation = 1    # score every `separation`-th frame
length = 10       # frames per scored sequence
frame_count = 0
multiple = False  # toggle between single- and multi-person logic

# logic for multiple persons
people = {}
next_object_id = 0

# logic for single persons
current_sequence = []
current_score = 0
current_status = 'good'
previous_status = "None"
longevity = 0  # frames spent in the current status
empty = False


# for audio playing
iterations_in_bad_posture = 0
max_iterations_in_bad_posture = 5
should_alert = False


@torch.no_grad()
def on_update(frame, recently_alerted, threshold=0.7):
    global frame_count, current_sequence, empty, current_score, current_status, \
        should_alert, iterations_in_bad_posture, people, next_object_id
    if recently_alerted:
        should_alert = False

    orig_image = frame
    image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
    image = letterbox(image, frame.shape[1], stride=64, auto=True)[0]
    image = transforms.ToTensor()(image)
    image = torch.tensor(np.array([image.numpy()]))
    image = image.to(device)
    image = image.float()

    # get predictions (gradients are already disabled by the decorator)
    output_data, _ = model(image)

    output_data = non_max_suppression_kpt(output_data,  # apply non-max suppression
                                          0.25,  # confidence threshold
                                          0.65,  # IoU threshold
                                          nc=model.yaml['nc'],      # number of classes
                                          nkpt=model.yaml['nkpt'],  # number of keypoints
                                          kpt_label=True)

    output = output_to_keypoint(output_data)
    if multiple:
        if len(output) == 0:
            if not empty:
                print("Wiping data, waiting for objects to appear in frame")
                people = {}
                next_object_id = 0
                empty = True
        else:
            empty = False
    else:
        if output.shape[0] > 0:
            if frame_count % separation == 0:
                landmarks = output[0, 7:].T  # keypoints of the first detection
                current_sequence += [landmarks[:-1]]

                if len(current_sequence) == length:
                    payload = {'array': np.array([current_sequence]).tolist()}
                    response = predict_http_request(payload)

                    if response is not None:
                        current_score = response['score']
                        response['status'] = 'good' if current_score > threshold else 'bad'
                        if response['status'] == 'bad':
                            iterations_in_bad_posture += 1
                        current_status = response['status']
                    current_sequence = []

    im0 = image[0].permute(1, 2, 0) * 255  # change format [b, c, h, w] to [h, w, c] for displaying the image
    im0 = im0.cpu().numpy().astype(np.uint8)
    im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR)  # convert back to BGR
    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh

    for i, pose in enumerate(output_data):  # detections per image
        if empty:
            break
        if len(pose) == 0:
            continue
        for det_index, (*xyxy, conf, cls) in enumerate(reversed(pose[:, :6])):  # loop over poses for drawing on frame
            c = int(cls)  # integer class
            kpts = pose[det_index, 6:]

            if multiple:
                # get the centroid (cx, cy) for the current rectangle
                rect = [tensor.cpu().numpy() for tensor in xyxy]
                cx, cy = (rect[0] + rect[2]) / 2, (rect[1] + rect[3]) / 2
                matched_object_id = None

                # match against known people by nearest centroid
                for object_id, data in people.items():
                    distance = np.sqrt((cx - data['centroid'][0]) ** 2 + (cy - data['centroid'][1]) ** 2)
                    # print(distance)  # debug
                    if distance < 300:  # adjust the threshold as needed
                        matched_object_id = object_id
                        break

                if matched_object_id is None:
                    matched_object_id = next_object_id
                    next_object_id += 1

                if matched_object_id not in people:
                    people[matched_object_id] = {'centroid': (cx, cy), 'yoloid': det_index,
                                                 'status': 'good', 'score': 0, 'sequence': []}
                else:
                    people[matched_object_id]['centroid'] = (cx, cy)
                    people[matched_object_id]['yoloid'] = det_index

                obj = people[matched_object_id]
                label = f"ID #{obj['yoloid']} Score: {obj['score']:.2f}"
                plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                 line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                 cmap=obj['status'])
            else:
                label = f"ID #0 Score: {current_score:.2f}"
                plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True),
                                 line_thickness=3, kpt_label=True, kpts=kpts, steps=3,
                                 cmap=current_status)

    if frame_count % separation == 0 and multiple:
        for _, data in people.items():
            if data['yoloid'] < output.shape[0]:
                yoloid = data['yoloid']
                landmarks = output[yoloid, 7:].T
                data['sequence'] += [landmarks[:-1]]

                if len(data['sequence']) == length:
                    payload = {'array': np.array([data['sequence']]).tolist()}
                    response = predict_http_request(payload)

                    if response is not None:
                        data['score'] = response['score']
                        data['status'] = 'good' if data['score'] > threshold else 'bad'
                    data['sequence'] = []

                    # print(f"{data['yoloid']} -> {data['status']}", end=' ')
            else:
                data['sequence'] = []

        statuses = [(people[p]['yoloid'], people[p]['status']) for p in people]
        # for id, status in statuses:
        #     print(f'{id}: {status}', end='\t')
        # print()

    if iterations_in_bad_posture >= max_iterations_in_bad_posture:
        iterations_in_bad_posture = 0
        should_alert = True

    frame_count += 1
    return im0, current_status, current_score, should_alert
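

# A minimal driver sketch tying the pieces together (not part of the original
# module flow; the checkpoint path and webcam index are assumptions):
if __name__ == '__main__':
    model_initialization('0' if torch.cuda.is_available() else 'cpu',
                         os.path.join(yolov7_path, 'yolov7-w6-pose.pt'))  # hypothetical path
    cap = cv2.VideoCapture(0)  # default webcam
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        annotated, status, score, alert = on_update(frame, recently_alerted=False)
        cv2.imshow('posture', annotated)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()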