Switch to side-by-side view

--- a
+++ b/tools/misc/bsn_proposal_generation.py
@@ -0,0 +1,198 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+import os.path as osp
+
+import mmcv
+import numpy as np
+import torch.multiprocessing as mp
+
+from mmaction.localization import (generate_bsp_feature,
+                                   generate_candidate_proposals)
+
+
+def load_video_infos(ann_file):
+    """Load the video annotations.
+
+    Args:
+        ann_file (str): A json file path of the annotation file.
+
+    Returns:
+        list[dict]: A list containing annotations for videos.
+    """
+    video_infos = []
+    anno_database = mmcv.load(ann_file)
+    for video_name in anno_database:
+        video_info = anno_database[video_name]
+        video_info['video_name'] = video_name
+        video_infos.append(video_info)
+    return video_infos
+
+
+def generate_proposals(ann_file, tem_results_dir, pgm_proposals_dir,
+                       pgm_proposals_thread, **kwargs):
+    """Generate proposals using multi-process.
+
+    Args:
+        ann_file (str): A json file path of the annotation file for
+            all videos to be processed.
+        tem_results_dir (str): Directory to read tem results
+        pgm_proposals_dir (str): Directory to save generated proposals.
+        pgm_proposals_thread (int): Total number of threads.
+        kwargs (dict): Keyword arguments for "generate_candidate_proposals".
+    """
+    video_infos = load_video_infos(ann_file)
+    num_videos = len(video_infos)
+    num_videos_per_thread = num_videos // pgm_proposals_thread
+    processes = []
+    manager = mp.Manager()
+    result_dict = manager.dict()
+    kwargs['result_dict'] = result_dict
+    for tid in range(pgm_proposals_thread - 1):
+        tmp_video_list = range(tid * num_videos_per_thread,
+                               (tid + 1) * num_videos_per_thread)
+        p = mp.Process(
+            target=generate_candidate_proposals,
+            args=(
+                tmp_video_list,
+                video_infos,
+                tem_results_dir,
+            ),
+            kwargs=kwargs)
+        p.start()
+        processes.append(p)
+
+    tmp_video_list = range((pgm_proposals_thread - 1) * num_videos_per_thread,
+                           num_videos)
+    p = mp.Process(
+        target=generate_candidate_proposals,
+        args=(
+            tmp_video_list,
+            video_infos,
+            tem_results_dir,
+        ),
+        kwargs=kwargs)
+    p.start()
+    processes.append(p)
+
+    for p in processes:
+        p.join()
+
+    # save results
+    os.makedirs(pgm_proposals_dir, exist_ok=True)
+    prog_bar = mmcv.ProgressBar(num_videos)
+    header = 'tmin,tmax,tmin_score,tmax_score,score,match_iou,match_ioa'
+    for video_name in result_dict:
+        proposals = result_dict[video_name]
+        proposal_path = osp.join(pgm_proposals_dir, video_name + '.csv')
+        np.savetxt(
+            proposal_path,
+            proposals,
+            header=header,
+            delimiter=',',
+            comments='')
+        prog_bar.update()
+
+
+def generate_features(ann_file, tem_results_dir, pgm_proposals_dir,
+                      pgm_features_dir, pgm_features_thread, **kwargs):
+    """Generate proposals features using multi-process.
+
+    Args:
+        ann_file (str): A json file path of the annotation file for
+            all videos to be processed.
+        tem_results_dir (str): Directory to read tem results.
+        pgm_proposals_dir (str): Directory to read generated proposals.
+        pgm_features_dir (str): Directory to save generated features.
+        pgm_features_thread (int): Total number of threads.
+        kwargs (dict): Keyword arguments for "generate_bsp_feature".
+    """
+    video_infos = load_video_infos(ann_file)
+    num_videos = len(video_infos)
+    num_videos_per_thread = num_videos // pgm_features_thread
+    processes = []
+    manager = mp.Manager()
+    feature_return_dict = manager.dict()
+    kwargs['result_dict'] = feature_return_dict
+    for tid in range(pgm_features_thread - 1):
+        tmp_video_list = range(tid * num_videos_per_thread,
+                               (tid + 1) * num_videos_per_thread)
+        p = mp.Process(
+            target=generate_bsp_feature,
+            args=(
+                tmp_video_list,
+                video_infos,
+                tem_results_dir,
+                pgm_proposals_dir,
+            ),
+            kwargs=kwargs)
+        p.start()
+        processes.append(p)
+    tmp_video_list = range((pgm_features_thread - 1) * num_videos_per_thread,
+                           num_videos)
+    p = mp.Process(
+        target=generate_bsp_feature,
+        args=(
+            tmp_video_list,
+            video_infos,
+            tem_results_dir,
+            pgm_proposals_dir,
+        ),
+        kwargs=kwargs)
+    p.start()
+    processes.append(p)
+
+    for p in processes:
+        p.join()
+
+    # save results
+    os.makedirs(pgm_features_dir, exist_ok=True)
+    prog_bar = mmcv.ProgressBar(num_videos)
+    for video_name in feature_return_dict.keys():
+        bsp_feature = feature_return_dict[video_name]
+        feature_path = osp.join(pgm_features_dir, video_name + '.npy')
+        np.save(feature_path, bsp_feature)
+        prog_bar.update()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Proposal generation module')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument(
+        '--mode',
+        choices=['train', 'test'],
+        default='test',
+        help='train or test')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    print('Begin Proposal Generation Module')
+    args = parse_args()
+    cfg = mmcv.Config.fromfile(args.config)
+    tem_results_dir = cfg.tem_results_dir
+    pgm_proposals_dir = cfg.pgm_proposals_dir
+    pgm_features_dir = cfg.pgm_features_dir
+    if args.mode == 'test':
+        generate_proposals(cfg.ann_file_val, tem_results_dir,
+                           pgm_proposals_dir, **cfg.pgm_proposals_cfg)
+        print('\nFinish proposal generation')
+        generate_features(cfg.ann_file_val, tem_results_dir, pgm_proposals_dir,
+                          pgm_features_dir, **cfg.pgm_features_test_cfg)
+        print('\nFinish feature generation')
+
+    elif args.mode == 'train':
+        generate_proposals(cfg.ann_file_train, tem_results_dir,
+                           pgm_proposals_dir, **cfg.pgm_proposals_cfg)
+        print('\nFinish proposal generation')
+        generate_features(cfg.ann_file_train, tem_results_dir,
+                          pgm_proposals_dir, pgm_features_dir,
+                          **cfg.pgm_features_train_cfg)
+        print('\nFinish feature generation')
+
+    print('Finish Proposal Generation Module')
+
+
+if __name__ == '__main__':
+    main()