Switch to unified view

a b/landmark_extraction/pose-estimate.py
1
import cv2
2
import time
3
import torch
4
import argparse
5
import numpy as np
6
import matplotlib
7
import matplotlib.pyplot as plt
8
from torchvision import transforms
9
from utils.datasets import letterbox
10
from utils.torch_utils import select_device
11
from models.experimental import attempt_load
12
from utils.general import non_max_suppression_kpt,strip_optimizer,xyxy2xywh
13
from utils.plots import output_to_keypoint, plot_skeleton_kpts,colors,plot_one_box_kpt
14
15
matplotlib.use('TkAgg')
16
17
@torch.no_grad()
18
def run(poseweights="yolov7-w6-pose.pt",source="football1.mp4",device='cpu',view_img=False,
19
        save_conf=False,line_thickness = 3,hide_labels=False, hide_conf=True):
20
21
    frame_count = 0  # count no of frames
22
    total_fps = 0  # count total fps
23
    time_list = []   # list to store time
24
    fps_list = []    # list to store fps
25
    fps_graph = []
26
    
27
    device = select_device(opt.device) #select device
28
    half = device.type != 'cpu'
29
30
    model = attempt_load(poseweights, map_location=device)  #Load model
31
    _ = model.eval()
32
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
33
   
34
    if source.isnumeric() :    
35
        cap = cv2.VideoCapture(int(source))    #pass video to videocapture object
36
    else :
37
        cap = cv2.VideoCapture(source)    #pass video to videocapture object
38
   
39
    if (cap.isOpened() == False):   #check if videocapture not opened
40
        print('Error while trying to read video. Please check path again')
41
        raise SystemExit()
42
43
    else:
44
        frame_width = int(cap.get(3))  #get video frame width
45
        frame_height = int(cap.get(4)) #get video frame height
46
47
        
48
        vid_write_image = letterbox(cap.read()[1], (frame_width), stride=64, auto=True)[0] #init videowriter
49
        resize_height, resize_width = vid_write_image.shape[:2]
50
        out_video_name = f"{source.split('/')[-1].split('.')[0]}"
51
        out = cv2.VideoWriter(f"{source}_keypoint.mp4",
52
                            cv2.VideoWriter_fourcc(*'mp4v'), 30,
53
                            (resize_width, resize_height))
54
55
        while(cap.isOpened): #loop until cap opened or video not complete
56
        
57
            # print("Frame {} Processing".format(frame_count+1))
58
59
            ret, frame = cap.read()  #get frame and success from video capture
60
            
61
            if ret: #if success is true, means frame exist
62
                orig_image = frame #store frame
63
                image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB) #convert frame to RGB
64
                image = letterbox(image, (frame_width), stride=64, auto=True)[0]
65
                image_ = image.copy()
66
                image = transforms.ToTensor()(image)
67
                image = torch.tensor(np.array([image.numpy()]))
68
            
69
                image = image.to(device)  #convert image data to device
70
                image = image.float() #convert image to float precision (cpu)
71
                start_time = time.time() #start time for fps calculation
72
            
73
                with torch.no_grad():  #get predictions
74
                    output_data, _ = model(image)
75
                
76
                # keypoint_file = open("keypoints.md", "a")
77
                # keypoint_file.write(str(output_data))
78
                # keypoint_file.close()
79
80
81
82
                output_data = non_max_suppression_kpt(output_data,   #Apply non max suppression
83
                                            0.25,   # Conf. Threshold.
84
                                            0.65, # IoU Threshold.
85
                                            nc=model.yaml['nc'], # Number of classes.
86
                                            nkpt=model.yaml['nkpt'], # Number of keypoints.
87
                                            kpt_label=True)
88
            
89
                output = output_to_keypoint(output_data)
90
91
                im0 = image[0].permute(1, 2, 0) * 255 # Change format [b, c, h, w] to [h, w, c] for displaying the image.
92
                im0 = im0.cpu().numpy().astype(np.uint8)
93
                
94
                im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR) #reshape image format to (BGR)
95
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
96
97
                for i, pose in enumerate(output_data):  # detections per image
98
                
99
                    if len(output_data):  #check if no pose
100
                        for c in pose[:, 5].unique(): # Print results
101
                            n = (pose[:, 5] == c).sum()  # detections per class
102
                            print("No of Objects in Current Frame : {}".format(n))
103
                        
104
                        for det_index, (*xyxy, conf, cls) in enumerate(reversed(pose[:,:6])): #loop over poses for drawing on frame
105
                            c = int(cls)  # integer class
106
                            kpts = pose[det_index, 6:]
107
                            label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}')
108
                            plot_one_box_kpt(xyxy, im0, label=label, color=colors(c, True), 
109
                                        line_thickness=opt.line_thickness,kpt_label=True, kpts=kpts, steps=3, 
110
                                        orig_shape=im0.shape[:2])
111
112
                
113
                end_time = time.time()  #Calculatio for FPS
114
                fps = 1 / (end_time - start_time)
115
                total_fps += fps
116
                frame_count += 1
117
                
118
                fps_list.append(total_fps) #append FPS in list
119
                time_list.append(end_time - start_time) #append time in list
120
                fps_graph.append(fps)
121
                
122
                cv2.imshow("YOLOv7 Pose Estimation Demo", im0)
123
                # Press Q on keyboard to  exit
124
                if cv2.waitKey(25) & 0xFF == ord('q'):
125
                  break
126
127
            else:
128
                break
129
130
        cap.release()
131
        cv2.destroyAllWindows()
132
        avg_fps = total_fps / frame_count
133
        print(f"Average FPS: {avg_fps:.3f}")
134
        
135
        #plot the comparision graph
136
        frame_indices = range(len(fps_graph))
137
        fig, ax = plt.subplots()
138
        ax.plot(frame_indices, fps_graph, color='blue', linestyle='-')
139
        ax.set_xlabel('Frame Index')
140
        ax.set_ylabel('FPS')
141
        ax.set_title('Frames per Second (FPS)')
142
        ax.grid(True, linestyle='--', alpha=0.5)
143
        ax.set_facecolor('#f0f0f0')
144
        plt.savefig("FPS_graph.png")
145
        plt.show()
146
        # plot_fps_time_comparision(time_list=time_list,fps_list=fps_graph)
147
148
149
def parse_opt():
150
    parser = argparse.ArgumentParser()
151
    parser.add_argument('--poseweights', nargs='+', type=str, default='yolov7-w6-pose.pt', help='model path(s)')
152
    parser.add_argument('--source', type=str, default='football1.mp4', help='video/0 for webcam') #video source
153
    parser.add_argument('--device', type=str, default='cpu', help='cpu/0,1,2,3(gpu)')   #device arugments
154
    parser.add_argument('--view-img', action='store_true', help='display results')  #display results
155
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') #save confidence in txt writing
156
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)') #box linethickness
157
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels') #box hidelabel
158
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences') #boxhideconf
159
    opt = parser.parse_args()
160
    return opt
161
162
#function for plot fps and time comparision graph
163
def plot_fps_time_comparision(time_list,fps_list):
164
    plt.figure()
165
    plt.xlabel('Time (s)')
166
    plt.ylabel('FPS')
167
    plt.title('FPS and Time Comparision Graph')
168
    plt.plot(time_list, fps_list,'b',label="FPS & Time")
169
    plt.savefig("FPS_and_Time_Comparision_pose_estimate.png")
170
    
171
172
#main function
173
def main(opt):
174
    run(**vars(opt))
175
176
if __name__ == "__main__":
177
    opt = parse_opt()
178
    dev = int(opt.device) if opt.device.isnumeric() else opt.device
179
    strip_optimizer(dev,opt.poseweights)
180
    main(opt)