Diff of /detect.py [000000] .. [190ca4]

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Run YOLOv5 detection inference on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ python detect.py --weights yolov5s.pt --source 0                               # webcam
                                                     img.jpg                         # image
                                                     vid.mp4                         # video
                                                     screen                          # screenshot
                                                     path/                           # directory
                                                     list.txt                        # list of images
                                                     list.streams                    # list of streams
                                                     'path/*.jpg'                    # glob
                                                     'https://youtu.be/LNwODJXcvt4'  # YouTube
                                                     'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream

Usage - formats:
    $ python detect.py --weights yolov5s.pt                 # PyTorch
                                 yolov5s.torchscript        # TorchScript
                                 yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
                                 yolov5s_openvino_model     # OpenVINO
                                 yolov5s.engine             # TensorRT
                                 yolov5s.mlmodel            # CoreML (macOS-only)
                                 yolov5s_saved_model        # TensorFlow SavedModel
                                 yolov5s.pb                 # TensorFlow GraphDef
                                 yolov5s.tflite             # TensorFlow Lite
                                 yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
                                 yolov5s_paddle_model       # PaddlePaddle
"""

import argparse
import csv
import os
import platform
import sys
from pathlib import Path

import torch
import torch.nn.functional as F


FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from ultralytics.utils.plotting import Annotator, colors, save_one_box

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr,
                           cv2, increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer,
                           xyxy2xywh, get_fixed_xyxy)
from utils.torch_utils import select_device, smart_inference_mode
from utils.my_model import MyCNN
from torchvision.ops import roi_align
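
# NOTE: get_fixed_xyxy (utils.general) and MyCNN (utils.my_model) are additions of
# this fork, not part of upstream YOLOv5. From its usage below, get_fixed_xyxy is
# assumed to clamp and round a scaled box to valid integer coordinates on the given
# feature map.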

@smart_inference_mode()
def run(
    weights=ROOT / "yolov5s.pt",  # model path or triton URL
    source=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam)
    data=ROOT / "data/coco128.yaml",  # dataset.yaml path
    imgsz=(640, 640),  # inference size (height, width)
    conf_thres=0.25,  # confidence threshold
    iou_thres=0.45,  # NMS IOU threshold
    max_det=1000,  # maximum detections per image
    device="",  # cuda device, i.e. 0 or 0,1,2,3 or cpu
    view_img=False,  # show results
    save_txt=False,  # save results to *.txt
    save_csv=False,  # save results in CSV format
    save_conf=False,  # save confidences in --save-txt labels
    save_crop=False,  # save cropped prediction boxes
    nosave=False,  # do not save images/videos
    classes=None,  # filter by class: --class 0, or --class 0 2 3
    agnostic_nms=False,  # class-agnostic NMS
    augment=False,  # augmented inference
    visualize=False,  # visualize features
    update=False,  # update all models
    project=ROOT / "runs/detect",  # save results to project/name
    name="exp",  # save results to project/name
    exist_ok=False,  # existing project/name ok, do not increment
    line_thickness=3,  # bounding box thickness (pixels)
    hide_labels=False,  # hide labels
    hide_conf=False,  # hide confidences
    half=False,  # use FP16 half-precision inference
    dnn=False,  # use OpenCV DNN for ONNX inference
    vid_stride=1,  # video frame-rate stride
):
    source = str(source)
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
    screenshot = source.lower().startswith('screen')
    if is_url and is_file:
        source = check_file(source)  # download

    # Directories
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size(imgsz, s=stride)  # check image size

    # Dataloader
    bs = 1  # batch_size
    if webcam:
        view_img = check_imshow(warn=True)
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
        bs = len(dataset)
    elif screenshot:
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=False, vid_stride=vid_stride)
    vid_path, vid_writer = [None] * bs, [None] * bs

    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
    for path, im, im0s, vid_cap, s, orig_img in dataset:
        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim
            if model.xml and im.shape[0] > 1:
                ims = torch.chunk(im, im.shape[0], 0)

        # Inference
        with dt[1]:
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            if model.xml and im.shape[0] > 1:
                # The modified model returns (detections, intermediate features)
                pred = None
                for image in ims:
                    out, int_feats = model(image, augment=augment, visualize=visualize)
                    pred = out.unsqueeze(0) if pred is None else torch.cat((pred, out.unsqueeze(0)), dim=0)
                pred = [pred, None]
            else:
                pred, int_feats = model(im, augment=augment, visualize=visualize)
        # NMS
        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

            # P2 and P3 feature maps for the first image in the batch, as float32 with a batch dim
            int_feats_p2 = int_feats[0][0].to(torch.float32).unsqueeze(0)
            int_feats_p3 = int_feats[1][0].to(torch.float32).unsqueeze(0)

            # Attribute classifier over the concatenated P2 + P3 channels
            in_channels = int_feats_p2.shape[1] + int_feats_p3.shape[1]
            cell_attribute_model = MyCNN(num_classes=12, dropout_prob=0.5, in_channels=in_channels).to(device)
            folder_name = 'data/WBC_dataset_sample/Attribute_model'
            custom_weights_path = f'{folder_name}/last_weights.pth'
            custom_weights = torch.load(custom_weights_path, map_location=device)
            cell_attribute_model.load_state_dict(custom_weights)
            cell_attribute_model.eval()

            torch.cuda.empty_cache()
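
            # MyCNN is project-specific (utils/my_model.py). From its usage here it is
            # assumed to map a (N, in_channels, 24, 30) ROI tensor to (N, 12) logits,
            # i.e. 6 binary attributes with 2 logits each, roughly:
            #     class MyCNN(nn.Module):
            #         def __init__(self, num_classes, dropout_prob, in_channels): ...
            #         def forward(self, x):  # (N, in_channels, 24, 30) -> (N, num_classes)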
            
            if len(pred) > 0:
                pred_Nuclear_Chromatin_array = []
                pred_Nuclear_Shape_array = []
                pred_Nucleus_array = []
                pred_Cytoplasm_array = []
                pred_Cytoplasmic_Basophilia_array = []
                pred_Cytoplasmic_Vacuoles_array = []
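
                # Per-detection attribute pipeline (the steps implemented below):
                #   1. normalize the xyxy box by the network input size
                #   2. rescale it onto the P2 and P3 feature maps
                #   3. roi_align each feature map to a fixed (24, 30) window
                #   4. concatenate along channels and classify into 6 binary attributes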
                for i in range(len(pred[0])):
                    if pred[0][i].numel() > 0:  # check that the detection row is not empty

                        pred_tensor = pred[0][i][0:4]  # xyxy box in network input coordinates

                        if pred[0][i][5] != 0:  # attributes are predicted only for non-zero classes
                            # (H, W, H, W) scaling tensors; assumes a square inference size
                            img_shape_tensor = torch.tensor([im.shape[2], im.shape[3], im.shape[2], im.shape[3]]).to(device)
                            normalized_xyxy = pred_tensor / img_shape_tensor
                            p2_feature_shape_tensor = torch.tensor([int_feats[0].shape[2], int_feats[0].shape[3],
                                                                    int_feats[0].shape[2], int_feats[0].shape[3]]).to(device)
                            p3_feature_shape_tensor = torch.tensor([int_feats[1].shape[2], int_feats[1].shape[3],
                                                                    int_feats[1].shape[2], int_feats[1].shape[3]]).to(device)

                            # Project the normalized box onto each feature map
                            p2_normalized_xyxy = normalized_xyxy * p2_feature_shape_tensor
                            p3_normalized_xyxy = normalized_xyxy * p3_feature_shape_tensor
                            p2_x_min, p2_y_min, p2_x_max, p2_y_max = get_fixed_xyxy(p2_normalized_xyxy, int_feats_p2)
                            p3_x_min, p3_y_min, p3_x_max, p3_y_max = get_fixed_xyxy(p3_normalized_xyxy, int_feats_p3)

                            p2_roi = torch.tensor([p2_x_min, p2_y_min, p2_x_max, p2_y_max], device=device).float()
                            p3_roi = torch.tensor([p3_x_min, p3_y_min, p3_x_max, p3_y_max], device=device).float()

                            batch_index = torch.tensor([0], dtype=torch.float32, device=device)

                            # roi_align expects each box as (batch_index, x1, y1, x2, y2)
                            p2_roi_with_batch_index = torch.cat([batch_index, p2_roi])
                            p3_roi_with_batch_index = torch.cat([batch_index, p3_roi])
                            p2_resized_object = roi_align(int_feats_p2, p2_roi_with_batch_index.unsqueeze(0).to(device), output_size=(24, 30))
                            p3_resized_object = roi_align(int_feats_p3, p3_roi_with_batch_index.unsqueeze(0).to(device), output_size=(24, 30))
                            concat_box = torch.cat([p2_resized_object, p3_resized_object], dim=1)

                            # 12 logits -> 6 attributes x 2 classes; softmax over each pair
                            output_cell_prediction = cell_attribute_model(concat_box)
                            output_cell_prediction_prob = F.softmax(output_cell_prediction.view(6, 2), dim=1)
                            top_indices_cell_pred = torch.argmax(output_cell_prediction_prob, dim=1)
                            pred_Nuclear_Chromatin_array.append(top_indices_cell_pred[0].item())
                            pred_Nuclear_Shape_array.append(top_indices_cell_pred[1].item())
                            pred_Nucleus_array.append(top_indices_cell_pred[2].item())
                            pred_Cytoplasm_array.append(top_indices_cell_pred[3].item())
                            pred_Cytoplasmic_Basophilia_array.append(top_indices_cell_pred[4].item())
                            pred_Cytoplasmic_Vacuoles_array.append(top_indices_cell_pred[5].item())
                        else:
                            # Class 0 gets all-zero attribute predictions
                            pred_Nuclear_Chromatin_array.append(0)
                            pred_Nuclear_Shape_array.append(0)
                            pred_Nucleus_array.append(0)
                            pred_Cytoplasm_array.append(0)
                            pred_Cytoplasmic_Basophilia_array.append(0)
                            pred_Cytoplasmic_Vacuoles_array.append(0)

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Define the path for the CSV file
        csv_path = save_dir / 'predictions.csv'

258
        def write_to_csv(name, predicts, confid, pred_NC, pred_NS, 
259
                        pred_N, pred_C, pred_CB, pred_CV,
260
                        x_min, y_min, x_max, y_max):
261
            data = {'Image Name': name, 'Prediction': predicts, 'Confidence': confid, 'Nuclear Chromatin': pred_NC,
262
                    'Nuclear Shape': pred_NS, 'Nucleus': pred_N, 'Cytoplasm': pred_C,
263
                    'Cytoplasmic Basophilia': pred_CB, 'Cytoplasmic Vacuoles': pred_CV,
264
                    'x_min': x_min, 'y_min': y_min, 'x_max': x_max, 'y_max': y_max}
265
            
266
            # Check if the CSV file exists
267
            if not os.path.isfile(csv_path):
268
                with open(csv_path, mode='w', newline='') as f:
269
                    writer = csv.DictWriter(f, fieldnames=data.keys())
270
                    writer.writeheader()
271
272
            # Append data to CSV file
273
            with open(csv_path, mode='a', newline='') as f:
274
                writer = csv.DictWriter(f, fieldnames=data.keys())
275
                writer.writerow(data)
276
277
        # Process predictions
278
        for i, det in enumerate(pred):  # per image
279
            seen += 1
280
            if webcam:  # batch_size >= 1
281
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
282
                s += f'{i}: '
283
            else:
284
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
285
286
            p = Path(p)  # to Path
287
            save_path = str(save_dir / p.name)  # im.jpg
288
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
289
            s += '%gx%g ' % im.shape[2:]  # print string
290
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
291
            imc = im0.copy() if save_crop else im0  # for save_crop
292
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
293
            if len(det):
294
                # Rescale boxes from img_size to im0 size
295
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
296
297
                # Print results
298
                for c in det[:, 5].unique():
299
                    n = (det[:, 5] == c).sum()  # detections per class
300
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
301
                # Write results
302
                for count, (*xyxy, conf, cls) in enumerate(det):
303
                    c = int(cls)  # integer class
304
                    label = names[c] if hide_conf else f'{names[c]}'
305
                    confidence = float(conf)
306
                    confidence_str = f'{confidence:.2f}'
307
                    
308
                    if save_csv:
309
                        x_min,y_min,x_max,y_max = xyxy
310
311
                        # Scaling factors
312
                        scale_width = orig_img.shape[1] / 640
313
                        scale_height = orig_img.shape[0] / 640
314
315
                        # Convert bounding box coordinates to 800x448 image
316
                        x_min_new = int(x_min * scale_width)
317
                        y_min_new = int(y_min * scale_height)
318
                        x_max_new = int(x_max * scale_width)
319
                        y_max_new = int(y_max * scale_height)
320
321
                        write_to_csv(p.name, label, confidence_str,
322
                                     pred_Nuclear_Chromatin_array[count],pred_Nuclear_Shape_array[count], 
323
                                     pred_Nucleus_array[count],pred_Cytoplasm_array[count],pred_Cytoplasmic_Basophilia_array[count],
324
                                     pred_Cytoplasmic_Vacuoles_array[count],
325
                                     int(x_min_new),int(y_min_new),
326
                                     int(x_max_new),int(y_max_new))
327
                        
328
329
                    if save_txt:  # Write to file
330
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
331
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
332
                        with open(f'{txt_path}.txt', 'a') as f:
333
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
334
335
                    if save_img or save_crop or view_img:  # Add bbox to image
336
                        c = int(cls)  # integer class
337
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
338
                        annotator.box_label(xyxy, label, color=colors(c, True))
339
                        # annotator.my_box_label(xyxy, label, color=colors(c, True), att1=pred_Nuclear_Chromatin_array[0],
340
                        #                        att2 = pred_Nuclear_Shape_array[0], att3 = pred_Nucleus_array[0],
341
                        #                        att4 = pred_Cytoplasm_array[0], att5 = pred_Cytoplasmic_Basophilia_array[0],
342
                        #                        att6 = pred_Cytoplasmic_Vacuoles_array[0]
343
                        #                        )
344
345
                    if save_crop:
346
                        save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
347
348
            # Stream results
349
            im0 = annotator.result()
350
            if view_img:
351
                if platform.system() == 'Linux' and p not in windows:
352
                    windows.append(p)
353
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
354
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
355
                cv2.imshow(str(p), im0)
356
                cv2.waitKey(1)  # 1 millisecond
357
358
            # Save results (image with detections)
359
            if save_img:
360
                if dataset.mode == 'image':
361
                    cv2.imwrite(save_path, im0)
362
                else:  # 'video' or 'stream'
363
                    if vid_path[i] != save_path:  # new video
364
                        vid_path[i] = save_path
365
                        if isinstance(vid_writer[i], cv2.VideoWriter):
366
                            vid_writer[i].release()  # release previous video writer
367
                        if vid_cap:  # video
368
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
369
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
370
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
371
                        else:  # stream
372
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
373
                        save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
374
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
375
                    vid_writer[i].write(im0)
376
377
        # Print time (inference-only)
378
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
379
380
    # Print results
381
    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
382
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
383
    if save_txt or save_img:
384
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
385
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
386
    if update:
387
        strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)
388
389
390
def parse_opt():
391
    parser = argparse.ArgumentParser()
392
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/yolov5x_300Epochs_training/weights/best.pt', help='model path or triton URL')
393
    parser.add_argument('--source', type=str, default='/home/iml/Desktop/bc_experiment/HCM_V3/HCM_840_attribute/images/test/', help='file/dir/URL/glob/screen/0(webcam)')
394
    parser.add_argument('--data', type=str, default=ROOT / 'data/WBC_v1.yaml', help='(optional) dataset.yaml path')
395
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
396
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
397
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
398
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
399
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
400
    parser.add_argument('--view-img', action='store_true', help='show results')
401
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
402
    parser.add_argument('--save-csv', action='store_true', help='save results in CSV format')
403
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
404
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
405
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
406
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
407
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
408
    parser.add_argument('--augment', action='store_true', help='augmented inference')
409
    parser.add_argument('--visualize', action='store_true', help='visualize features')
410
    parser.add_argument('--update', action='store_true', help='update all models')
411
    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
412
    parser.add_argument('--name', default='exp', help='save results to project/name')
413
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
414
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
415
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
416
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
417
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
418
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
419
    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
420
    opt = parser.parse_args()
421
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
422
    print_args(vars(opt))
423
    return opt
424
425
426
def main(opt):
427
    check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
428
    run(**vars(opt))
429
430
431
if __name__ == '__main__':
432
    opt = parse_opt()
433
    main(opt)
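
# Example invocation (source path illustrative; --save-csv writes predictions.csv
# with the six attribute columns in addition to the usual detection outputs):
#   $ python detect.py --weights runs/train/yolov5x_300Epochs_training/weights/best.pt \
#       --source data/WBC_dataset_sample/images --save-csv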