Diff of /internal/frame_process.py [000000] .. [a5e8ec]

Switch to unified view

a b/internal/frame_process.py
1
import sys
2
import os
3
from torchvision import transforms
4
import torch
5
import cv2
6
import numpy as np
7
import requests
8
9
yolov7_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../vendor/yolov7')
10
sys.path.append(yolov7_path)
11
12
from models.experimental import attempt_load
13
from utils.datasets import letterbox
14
from utils.general import non_max_suppression_kpt, strip_optimizer, xyxy2xywh
15
from utils.plots import colors, output_to_keypoint, plot_one_box_kpt, plot_skeleton_kpts
16
from utils.torch_utils import select_device
17
18
19
20
21
HOST = 'localhost'
22
PORT = '8420'
23
def predict_http_request(payload):
24
    """
25
    Args:
26
        - payload: {'array': (1, 10, 50) shape (10 frames)}
27
    Returns:
28
        - score: Value between 0 and 1
29
        - status: Good or bad posture (depending on threshold:0.7)
30
    """
31
    response = requests.post(f"http://{HOST}:{PORT}/predict", json=payload)
32
    if response.status_code == 200:
33
        return response.json()
34
    else:
35
        print("Error:", response.status_code)
36
        print(response.text)
37
38
39
model = None
40
device = None
41
def model_initialization(device_ref, w):
42
    global model, device
43
    device = select_device(device_ref)
44
    model = attempt_load(w, map_location=device)
45
    return model
46
47
48
separation =0
49
length = 10
50
frame_count = 0
51
# logic for multiple persons
52
people = {}
53
next_object_id = 0
54
# logic for single persons
55
current_sequence = []
56
current_score = 0
57
current_status = 'good'
58
previous_status = "None"
59
longevity = 0 # frames spent in the current status
60
separation = 1
61
multiple = False
62
frame_count = 0
63
empty = False
64
65
66
# for audio playing
67
iterations_in_bad_posture = 0 
68
max_iterations_in_bad_posture = 5
69
should_alert = False
70
71
72
@torch.no_grad()
73
def on_update(frame, recently_alerted, threshold = 0.7):
74
    global separation, frame_count, current_sequence, empty, current_score, current_status, should_alert,\
75
            iterations_in_bad_posture, max_iterations_in_bad_posture
76
    if recently_alerted:
77
        should_alert = False
78
79
    orig_image = frame 
80
    image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB) 
81
    image = letterbox(image, (frame.shape[1]), stride=64, auto=True)[0]
82
    image = transforms.ToTensor()(image)
83
    image = torch.tensor(np.array([image.numpy()]))
84
    image = image.to(device)
85
    image = image.float()
86
87
    with torch.no_grad():  #get predictions
88
        output_data, _ = model(image)
89
90
    output_data = non_max_suppression_kpt(output_data,   #Apply non max suppression
91
                                0.25,   # Conf. Threshold.
92
                                0.65, # IoU Threshold.
93
                                nc=model.yaml['nc'], # Number of classes.
94
                                nkpt=model.yaml['nkpt'], # Number of keypoints.
95
                                kpt_label=True)
96
97
    output = output_to_keypoint(output_data)
98
    if multiple:
99
        if len(output) == 0:
100
            if not empty:
101
                print("Wiping data, waiting for objects to appear in frame")
102
            people = {}
103
            next_object_id = 0
104
            empty = True
105
        else:
106
            empty = False
107
    else:
108
        if output.shape[0] > 0:
109
            if frame_count % separation == 0:
110
                landmarks = output[0, 7:].T
111
                current_sequence += [landmarks[:-1]]
112
113
            if len(current_sequence) == 10:
114
                current_sequence = np.array([current_sequence])
115
                payload = {'array': current_sequence.tolist() }
116
                response = predict_http_request(payload)
117
118
                current_score = response['score']
119
                response['status'] = 'good' if current_score > threshold else 'bad'
120
                if response['status'] == 'bad':
121
                    iterations_in_bad_posture += 1
122
                current_status = response['status']
123
                current_sequence = []
124
125
    im0 = image[0].permute(1, 2, 0) * 255 # Change format [b, c, h, w] to [h, w, c] for displaying the image.
126
    im0 = im0.cpu().numpy().astype(np.uint8)
127
    
128
    im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR) #reshape image format to (BGR)
129
    gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
130
131
    for i, pose in enumerate(output_data):  # detections per image
132
        if empty: break
133
    
134
        if len(output_data) == 0:
135
            continue
136
        for det_index, (*xyxy, conf, cls) in enumerate(reversed(pose[:,:6])): #loop over poses for drawing on frame
137
            c = int(cls)  # integer class
138
            kpts = pose[det_index, 6:]
139
140
141
            if multiple:
142
                # get the centroid (cx, cy) for the current rectangle
143
                rect = [tensor.cpu().numpy() for tensor in xyxy]
144
                cx, cy = (rect[0] + rect[2]) / 2, (rect[1] + rect[3]) / 2
145
                matched_object_id = None
146
147
                # iterating through known people
148
                for object_id, data in people.items():
149
                    distance = np.sqrt((cx - data['centroid'][0]) ** 2 + (cy - data['centroid'][1]) ** 2)
150
                    print(distance)
151
                    if distance < 300:  # Adjust the threshold as needed
152
                        matched_object_id = object_id
153
                        break
154
155
                if matched_object_id is None:
156
                    matched_object_id = next_object_id
157
                    next_object_id += 1
158
159
                if matched_object_id not in people:
160
                    people[matched_object_id] = {'centroid': (cx, cy), 'yoloid': det_index, 'status': 'good', 'score': 0, 'sequence' : []}
161
                else:
162
                    people[matched_object_id]['centroid'] = (cx, cy)
163
                    people[matched_object_id]['yoloid'] = det_index
164
165
                obj = people[matched_object_id]
166
                label = f"ID #{obj['yoloid']} Score: {obj['score']:.2f}"
167
                plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True), 
168
                            line_thickness=3, kpt_label=True, kpts=kpts, steps=3, 
169
                            cmap=people[matched_object_id]['status'])
170
            else:
171
                label = f"ID #{0} Score: {current_score:.2f}"
172
                plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True), 
173
                            line_thickness=3,kpt_label=True, kpts=kpts, steps=3, 
174
                            cmap=current_status)
175
176
177
    if frame_count % separation == 0 and multiple:
178
        for _, data in people.items():
179
            if data['yoloid'] < output.shape[0]:
180
                yoloid = data['yoloid']
181
                landmarks = output[yoloid, 7:].T
182
                data['sequence'] += [landmarks[:-1]]
183
            
184
                if len(data['sequence']) == length:
185
                    payload = {'array': np.array([data['sequence']]).tolist()}
186
                    response = predict_http_request(payload)
187
188
                    data['score'] = response['score']
189
                    data['status'] = 'good' if score > THRESHOLD else 'bad'
190
                    data['sequence'] = []
191
192
                # print(f"{data['yoloid']} -> {data['status']}", end=' ')
193
            else:
194
                data['sequence'] = []
195
196
        
197
        statuses = [(people[p]['yoloid'], people[p]['status']) for p in people]
198
        # for id, status in statuses:
199
        #     print(f'{id}: {status}', end='\t')
200
        # print()
201
    if iterations_in_bad_posture >= max_iterations_in_bad_posture:
202
        iterations_in_bad_posture = 0
203
        should_alert = True
204
205
    frame_count += 1
206
    return im0, current_status, current_score, should_alert