--- a +++ b/tools/misc/bsn_proposal_generation.py @@ -0,0 +1,198 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +import mmcv +import numpy as np +import torch.multiprocessing as mp + +from mmaction.localization import (generate_bsp_feature, + generate_candidate_proposals) + + +def load_video_infos(ann_file): + """Load the video annotations. + + Args: + ann_file (str): A json file path of the annotation file. + + Returns: + list[dict]: A list containing annotations for videos. + """ + video_infos = [] + anno_database = mmcv.load(ann_file) + for video_name in anno_database: + video_info = anno_database[video_name] + video_info['video_name'] = video_name + video_infos.append(video_info) + return video_infos + + +def generate_proposals(ann_file, tem_results_dir, pgm_proposals_dir, + pgm_proposals_thread, **kwargs): + """Generate proposals using multi-process. + + Args: + ann_file (str): A json file path of the annotation file for + all videos to be processed. + tem_results_dir (str): Directory to read tem results + pgm_proposals_dir (str): Directory to save generated proposals. + pgm_proposals_thread (int): Total number of threads. + kwargs (dict): Keyword arguments for "generate_candidate_proposals". + """ + video_infos = load_video_infos(ann_file) + num_videos = len(video_infos) + num_videos_per_thread = num_videos // pgm_proposals_thread + processes = [] + manager = mp.Manager() + result_dict = manager.dict() + kwargs['result_dict'] = result_dict + for tid in range(pgm_proposals_thread - 1): + tmp_video_list = range(tid * num_videos_per_thread, + (tid + 1) * num_videos_per_thread) + p = mp.Process( + target=generate_candidate_proposals, + args=( + tmp_video_list, + video_infos, + tem_results_dir, + ), + kwargs=kwargs) + p.start() + processes.append(p) + + tmp_video_list = range((pgm_proposals_thread - 1) * num_videos_per_thread, + num_videos) + p = mp.Process( + target=generate_candidate_proposals, + args=( + tmp_video_list, + video_infos, + tem_results_dir, + ), + kwargs=kwargs) + p.start() + processes.append(p) + + for p in processes: + p.join() + + # save results + os.makedirs(pgm_proposals_dir, exist_ok=True) + prog_bar = mmcv.ProgressBar(num_videos) + header = 'tmin,tmax,tmin_score,tmax_score,score,match_iou,match_ioa' + for video_name in result_dict: + proposals = result_dict[video_name] + proposal_path = osp.join(pgm_proposals_dir, video_name + '.csv') + np.savetxt( + proposal_path, + proposals, + header=header, + delimiter=',', + comments='') + prog_bar.update() + + +def generate_features(ann_file, tem_results_dir, pgm_proposals_dir, + pgm_features_dir, pgm_features_thread, **kwargs): + """Generate proposals features using multi-process. + + Args: + ann_file (str): A json file path of the annotation file for + all videos to be processed. + tem_results_dir (str): Directory to read tem results. + pgm_proposals_dir (str): Directory to read generated proposals. + pgm_features_dir (str): Directory to save generated features. + pgm_features_thread (int): Total number of threads. + kwargs (dict): Keyword arguments for "generate_bsp_feature". + """ + video_infos = load_video_infos(ann_file) + num_videos = len(video_infos) + num_videos_per_thread = num_videos // pgm_features_thread + processes = [] + manager = mp.Manager() + feature_return_dict = manager.dict() + kwargs['result_dict'] = feature_return_dict + for tid in range(pgm_features_thread - 1): + tmp_video_list = range(tid * num_videos_per_thread, + (tid + 1) * num_videos_per_thread) + p = mp.Process( + target=generate_bsp_feature, + args=( + tmp_video_list, + video_infos, + tem_results_dir, + pgm_proposals_dir, + ), + kwargs=kwargs) + p.start() + processes.append(p) + tmp_video_list = range((pgm_features_thread - 1) * num_videos_per_thread, + num_videos) + p = mp.Process( + target=generate_bsp_feature, + args=( + tmp_video_list, + video_infos, + tem_results_dir, + pgm_proposals_dir, + ), + kwargs=kwargs) + p.start() + processes.append(p) + + for p in processes: + p.join() + + # save results + os.makedirs(pgm_features_dir, exist_ok=True) + prog_bar = mmcv.ProgressBar(num_videos) + for video_name in feature_return_dict.keys(): + bsp_feature = feature_return_dict[video_name] + feature_path = osp.join(pgm_features_dir, video_name + '.npy') + np.save(feature_path, bsp_feature) + prog_bar.update() + + +def parse_args(): + parser = argparse.ArgumentParser(description='Proposal generation module') + parser.add_argument('config', help='test config file path') + parser.add_argument( + '--mode', + choices=['train', 'test'], + default='test', + help='train or test') + args = parser.parse_args() + return args + + +def main(): + print('Begin Proposal Generation Module') + args = parse_args() + cfg = mmcv.Config.fromfile(args.config) + tem_results_dir = cfg.tem_results_dir + pgm_proposals_dir = cfg.pgm_proposals_dir + pgm_features_dir = cfg.pgm_features_dir + if args.mode == 'test': + generate_proposals(cfg.ann_file_val, tem_results_dir, + pgm_proposals_dir, **cfg.pgm_proposals_cfg) + print('\nFinish proposal generation') + generate_features(cfg.ann_file_val, tem_results_dir, pgm_proposals_dir, + pgm_features_dir, **cfg.pgm_features_test_cfg) + print('\nFinish feature generation') + + elif args.mode == 'train': + generate_proposals(cfg.ann_file_train, tem_results_dir, + pgm_proposals_dir, **cfg.pgm_proposals_cfg) + print('\nFinish proposal generation') + generate_features(cfg.ann_file_train, tem_results_dir, + pgm_proposals_dir, pgm_features_dir, + **cfg.pgm_features_train_cfg) + print('\nFinish feature generation') + + print('Finish Proposal Generation Module') + + +if __name__ == '__main__': + main()