# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import mmcv
import numpy as np
import torch.multiprocessing as mp
from mmaction.localization import (generate_bsp_feature,
generate_candidate_proposals)
def load_video_infos(ann_file):
    """Read the annotation file and flatten it into a list of records.

    Args:
        ann_file (str): Path of the annotation file (a json file), read
            with ``mmcv.load``.

    Returns:
        list[dict]: One annotation dict per entry of the loaded database,
            in the database's iteration order. Each dict also stores its
            own key under ``'video_name'``.
    """
    database = mmcv.load(ann_file)
    records = []
    for name, record in database.items():
        # The entry is tagged in place, so the loaded database object itself
        # is modified rather than copied.
        record['video_name'] = name
        records.append(record)
    return records
def generate_proposals(ann_file, tem_results_dir, pgm_proposals_dir,
                       pgm_proposals_thread, **kwargs):
    """Generate candidate proposals with several worker processes.

    The videos listed in ``ann_file`` are split into contiguous index ranges,
    one per worker. Every worker runs ``generate_candidate_proposals`` and
    writes into a manager-backed shared dict; once all workers have been
    joined, each entry of that dict is saved as ``<video_name>.csv``.

    Args:
        ann_file (str): A json file path of the annotation file for
            all videos to be processed.
        tem_results_dir (str): Directory to read tem results.
        pgm_proposals_dir (str): Directory to save generated proposals.
            Created if it does not exist.
        pgm_proposals_thread (int): Total number of worker processes.
        kwargs (dict): Keyword arguments for "generate_candidate_proposals".
            The ``'result_dict'`` entry is set here to the shared dict.
    """
    video_infos = load_video_infos(ann_file)
    total = len(video_infos)
    chunk = total // pgm_proposals_thread

    # Keep a reference to the manager for the whole function so that the
    # proxy dict below stays usable while results are collected.
    manager = mp.Manager()
    shared_results = manager.dict()
    kwargs['result_dict'] = shared_results

    # All workers but the last get exactly ``chunk`` videos; the last one
    # takes everything up to the end, i.e. it also absorbs the remainder.
    last = pgm_proposals_thread - 1
    index_ranges = [range(i * chunk, (i + 1) * chunk) for i in range(last)]
    index_ranges.append(range(last * chunk, total))

    workers = []
    for index_range in index_ranges:
        worker = mp.Process(
            target=generate_candidate_proposals,
            args=(index_range, video_infos, tem_results_dir),
            kwargs=kwargs)
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()

    # save results
    os.makedirs(pgm_proposals_dir, exist_ok=True)
    # The bar is sized by the number of videos but only advances once per
    # entry actually present in the shared dict.
    prog_bar = mmcv.ProgressBar(total)
    csv_header = 'tmin,tmax,tmin_score,tmax_score,score,match_iou,match_ioa'
    for video_name in shared_results:
        video_proposals = shared_results[video_name]
        csv_path = osp.join(pgm_proposals_dir, video_name + '.csv')
        np.savetxt(
            csv_path,
            video_proposals,
            header=csv_header,
            delimiter=',',
            comments='')
        prog_bar.update()
def generate_features(ann_file, tem_results_dir, pgm_proposals_dir,
                      pgm_features_dir, pgm_features_thread, **kwargs):
    """Generate proposal features with several worker processes.

    The videos listed in ``ann_file`` are split into contiguous index ranges,
    one per worker. Every worker runs ``generate_bsp_feature`` and writes
    into a manager-backed shared dict; once all workers have been joined,
    each entry of that dict is saved as ``<video_name>.npy``.

    Args:
        ann_file (str): A json file path of the annotation file for
            all videos to be processed.
        tem_results_dir (str): Directory to read tem results.
        pgm_proposals_dir (str): Directory to read generated proposals.
        pgm_features_dir (str): Directory to save generated features.
            Created if it does not exist.
        pgm_features_thread (int): Total number of worker processes.
        kwargs (dict): Keyword arguments for "generate_bsp_feature".
            The ``'result_dict'`` entry is set here to the shared dict.
    """
    video_infos = load_video_infos(ann_file)
    total = len(video_infos)
    chunk = total // pgm_features_thread

    # Keep a reference to the manager for the whole function so that the
    # proxy dict below stays usable while results are collected.
    manager = mp.Manager()
    shared_features = manager.dict()
    kwargs['result_dict'] = shared_features

    # All workers but the last get exactly ``chunk`` videos; the last one
    # takes everything up to the end, i.e. it also absorbs the remainder.
    last = pgm_features_thread - 1
    index_ranges = [range(i * chunk, (i + 1) * chunk) for i in range(last)]
    index_ranges.append(range(last * chunk, total))

    workers = []
    for index_range in index_ranges:
        worker = mp.Process(
            target=generate_bsp_feature,
            args=(index_range, video_infos, tem_results_dir,
                  pgm_proposals_dir),
            kwargs=kwargs)
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()

    # save results
    os.makedirs(pgm_features_dir, exist_ok=True)
    # The bar is sized by the number of videos but only advances once per
    # entry actually present in the shared dict.
    prog_bar = mmcv.ProgressBar(total)
    for video_name in shared_features.keys():
        video_feature = shared_features[video_name]
        npy_path = osp.join(pgm_features_dir, video_name + '.npy')
        np.save(npy_path, video_feature)
        prog_bar.update()
def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace: Holds ``config`` (required positional path) and
            ``mode`` (either ``'train'`` or ``'test'``, ``'test'`` when the
            option is omitted).
    """
    cli = argparse.ArgumentParser(description='Proposal generation module')
    cli.add_argument('config', help='test config file path')
    cli.add_argument(
        '--mode',
        default='test',
        choices=['train', 'test'],
        help='train or test')
    return cli.parse_args()
def main():
    """Run proposal generation followed by feature generation.

    Reads the config file named on the command line, then, depending on
    ``--mode``, processes either the validation annotations (``'test'``) or
    the training annotations (``'train'``): proposals are generated first,
    and features are generated from them afterwards.
    """
    print('Begin Proposal Generation Module')
    args = parse_args()
    cfg = mmcv.Config.fromfile(args.config)
    tem_dir = cfg.tem_results_dir
    proposals_dir = cfg.pgm_proposals_dir
    features_dir = cfg.pgm_features_dir

    # mode -> (annotation-file attribute, feature-kwargs attribute) on cfg.
    # Attributes are looked up only when needed, so a missing feature config
    # is not noticed before proposal generation has run.
    per_mode_attrs = {
        'test': ('ann_file_val', 'pgm_features_test_cfg'),
        'train': ('ann_file_train', 'pgm_features_train_cfg'),
    }
    if args.mode in per_mode_attrs:
        ann_attr, feature_cfg_attr = per_mode_attrs[args.mode]
        generate_proposals(
            getattr(cfg, ann_attr), tem_dir, proposals_dir,
            **cfg.pgm_proposals_cfg)
        print('\nFinish proposal generation')
        generate_features(
            getattr(cfg, ann_attr), tem_dir, proposals_dir, features_dir,
            **getattr(cfg, feature_cfg_attr))
        print('\nFinish feature generation')
    print('Finish Proposal Generation Module')
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()