--- /dev/null
+++ b/inference/total_video_inference.py
@@ -0,0 +1,275 @@
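+"""Untrimmed-video inference for workflow recognition (e.g. active-bleeding
+detection): frames are sampled at a fixed interval, each sampled frame is
+scored by an mmaction model over three temporal clip alignments, and
+per-frame class probabilities are written to JSON. Launch with
+torch.distributed.launch when --multigpu > 1."""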
+import argparse
+import os
+import os.path as osp
+import json
+import time
+
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+
+import torch
+import torch.utils.data as data
+from torch.utils.data.distributed import DistributedSampler
+from torchvision import transforms
+
+import mmcv
+from mmaction.models import build_model
+# Multi GPU
+from mmcv.parallel import MMDataParallel
+from mmcv.runner import get_dist_info
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='workflow recognition')
+    parser.add_argument(
+        '--local_rank', type=int,
+        help='local rank; necessary for the torch.distributed.launch utility')
+    parser.add_argument('--data_path', default='.', help='dataset prefix')
+    parser.add_argument('--output_prefix', default='', help='output prefix')
+    parser.add_argument('--task_type', default='active_bleeding', help='active_bleeding')
+    parser.add_argument(
+        '--data_list', default=None,
+        help='video list of the dataset; one frame-directory name per line')
+    parser.add_argument(
+        '--frame_interval',
+        type=int,
+        default=30,
+        help='the sampling frequency of frames in the untrimmed video')
+    parser.add_argument(
+        '--temporal_stride',
+        type=int,
+        default=30,
+        help='stride (in frames) between consecutive frames within a clip')
+    parser.add_argument('--ckpt', default='.pth', help='path to the model checkpoint')
+    parser.add_argument(
+        '--batch_size', type=int, default=256, help='input batch size')
+    parser.add_argument(
+        '--numModel', type=int, default=1, help='number of models')
+    parser.add_argument(
+        '--config_file', type=str, default='ckpt/phase.py', help='config file')
+    parser.add_argument(
+        '--num_class', type=int, default=2, help='active bleeding: 2')
+    parser.add_argument(
+        '--kfold', type=int, default=1, help='cross-validation fold index (1, 2, or 3)')
+    parser.add_argument(
+        '--multigpu', type=int, default=1, help='number of GPUs; >1 enables distributed inference')
+    args = parser.parse_args()
+    return args
+
+
+class mmaction_inference:
+    def __init__(self, args):
+        self.args = args
+        self.config = mmcv.Config.fromfile(args.config_file)
+
+        self.ckpt = args.ckpt
+        self.num_class = args.num_class
+        self.kfold = args.kfold
+        self.data_path = args.data_path
+
+        with open(args.data_list) as f:
+            self.datalist = [x.strip() for x in f]
+
+        # multi gpu
+        self.device = torch.device(
+            'cuda:{}'.format(args.local_rank) if torch.cuda.is_available() else 'cpu')
+
+        self.data_set()
+        self.model_set()
+
+    def data_set(self):
+        # Data Setting
+        self.img_norm_cfg = self.config['img_norm_cfg']
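+        # the config gives mean/std in the 0-255 pixel range, while
+        # transforms.ToTensor scales images to [0, 1], so rescale them here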
+        self.img_norm_cfg['mean'] = [i / 255.0 for i in self.img_norm_cfg['mean']]
+        self.img_norm_cfg['std'] = [i / 255.0 for i in self.img_norm_cfg['std']]
+      
+        self.clip_len = self.config['data']['test']['pipeline'][0]['clip_len'] # clip len
+
+        self.transform = transforms.Compose([
+            transforms.Resize(224),  # transforms.Scale is deprecated; Resize replaces it
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=self.img_norm_cfg['mean'], std=self.img_norm_cfg['std']),
+        ])
+
+        # held-out patient IDs for each cross-validation fold
+        kfold = self.kfold
+        if kfold == 1:
+            self.test_patient = [3, 4, 6, 13, 17, 18, 22, 116, 208, 303]
+        elif kfold == 2:
+            self.test_patient = [1, 7, 10, 19, 56, 74, 100, 117, 203, 304]
+        elif kfold == 3:
+            self.test_patient = [5, 48, 76, 94, 202, 204, 206, 209, 301, 305]
+
+    def model_set(self):
+        # Model Setting
+        model = build_model(self.config['model'])
+        # load to CPU first so every rank does not stage weights on GPU 0
+        state_dict = torch.load(self.ckpt, map_location='cpu')['state_dict']
+        model.load_state_dict(state_dict)
+        if self.args.multigpu > 1:
+            self.model = MMDataParallel(
+                model.to(self.device),
+                device_ids=[self.args.local_rank],
+                output_device=self.args.local_rank)
+        else:
+            self.model = model.cuda()
+
+    def forward(self):
+        prog_bar = mmcv.ProgressBar(len(self.datalist))
+        probability = dict()
+        for videoID in self.datalist:
+            videoID = videoID.strip()
+            frame_dir = os.path.join(self.data_path, videoID)
+            output_dir = os.path.join(self.args.output_prefix, videoID)
+
+            os.makedirs(output_dir, exist_ok=True)
+            start = time.time()
+
+            print('\nstart', videoID)
+            inference_time_output_file = self.args.task_type + '_time.txt'
+            inference_time_output_file = osp.join(output_dir, inference_time_output_file)
+            output_file = self.args.task_type + '.json'
+            output_file = osp.join(output_dir, output_file)
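+
+            # Each sampled frame is evaluated with three temporal alignments of
+            # its clip: one starting at the frame (first), one centred on it
+            # (middle), and one ending at it (last); the start of the video is
+            # padded by repeating the first frame.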
+            # first: the clip starts at the sampled frame
+            base_framelist = sorted(os.listdir(frame_dir))[::self.args.frame_interval]
+            probability[videoID] = np.zeros((len(base_framelist), self.num_class))
+            first_dataset = GastricDataset(
+                data_path=frame_dir, datalist=base_framelist,
+                temporal_stride=self.args.temporal_stride,
+                windows=self.clip_len, transform=self.transform)
+
+            # middle: the clip is centred on the sampled frame
+            half = (self.clip_len - 1) // 2
+            framelist = [base_framelist[0]] * half + base_framelist[:-half]
+            middle_dataset = GastricDataset(
+                data_path=frame_dir, datalist=framelist,
+                temporal_stride=self.args.temporal_stride,
+                windows=self.clip_len, transform=self.transform)
+
+            # last: the clip ends at the sampled frame
+            framelist = [base_framelist[0]] * (self.clip_len - 1) \
+                + base_framelist[:-(self.clip_len - 1)]
+            last_dataset = GastricDataset(
+                data_path=frame_dir, datalist=framelist,
+                temporal_stride=self.args.temporal_stride,
+                windows=self.clip_len, transform=self.transform)
+
+            # ConcatDataset (below) zips the three datasets, so each batch
+            # yields the three alignments of the same sampled frames
+            dataset = ConcatDataset(first_dataset, middle_dataset, last_dataset)
+            if self.args.multigpu > 1:
+                rank, world_size = get_dist_info()
+                sampler = DistributedSampler(
+                    dataset, world_size, rank, shuffle=False, seed=None)
+            else:
+                sampler = None
+            data_loader = torch.utils.data.DataLoader(
+                dataset, sampler=sampler, batch_size=self.args.batch_size,
+                shuffle=False, num_workers=6, pin_memory=True)
+            
+            with torch.no_grad():
+                self.model.eval()
+                idx = 0
+                for images1, images2, images3 in tqdm(data_loader):
+                    # stack the three alignments along a new axis: (B, 3, C, T, H, W)
+                    images = torch.cat(
+                        (images1.unsqueeze(1), images2.unsqueeze(1), images3.unsqueeze(1)),
+                        dim=1).cuda()
+
+                    prob = self.model(images, return_loss=False, infer_3d=True)
+
+                    probability[videoID][idx:idx + images1.size(0)] = prob.squeeze().cpu().numpy()
+                    idx += images1.size(0)
+                results = self.save2json(probability[videoID])
+
+            with open(output_file, "w") as json_file:
+                json.dump(results, json_file)
+            end = time.time()
+            with open(inference_time_output_file, 'w') as f:
+                f.write(str(end - start))
+            prog_bar.update()
+
+    def save2json(self, prob):
+        result = dict()
+        result['labels'] = {0: 'Normal', 1: 'Active Bleeding'}
+        result['prediction'] = dict()
+
+        # argmax class index per frame, as strings for JSON
+        result['result'] = [str(np.argmax(p)) for p in prob]
+        # .tolist() converts numpy scalars to JSON-serializable Python floats
+        result['prob'] = [p.tolist() for p in prob]
+        result['frameInterval'] = str(self.args.frame_interval)
+        return result
+
+def main():
+    args = parse_args()
+    # enumerate the untrimmed videos in the data list and run inference on each
+    inference = mmaction_inference(args)
+    inference.forward()
+
+
+class ConcatDataset(data.Dataset):
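+    "Zips datasets: one index returns the matching sample from each; length is that of the shortest"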
+    def __init__(self, *datasets):
+        self.datasets = datasets
+
+    def __getitem__(self, i):
+        return tuple(d[i] for d in self.datasets)
+
+    def __len__(self):
+        return min(len(d) for d in self.datasets)
+    
+
+class GastricDataset(data.Dataset):
+    "Characterizes a dataset for PyTorch"
+    def __init__(self, data_path, datalist, temporal_stride, windows, transform=None):
+        "Initialization"
+        self.data_path = data_path
+        self.transform = transform
+        self.temporal_stride = temporal_stride
+        self.windows = windows
+        self.datalist = datalist
+
+    def __len__(self):
+        "Denotes the total number of samples"
+        return len(self.datalist)
+
+    def read_images(self, data):
+        "Reads a clip of `windows` frames, `temporal_stride` apart, starting at the frame named by data"
+        X = []
+        ts = self.temporal_stride
+        # frame files are named 'frame' + 10-digit zero-padded index + '.jpg'
+        frameNum = int(data[-14:-4])
+
+        image = None
+        for i in range(0, self.windows):
+            image_path = os.path.join(
+                self.data_path, 'frame' + str(i * ts + frameNum).zfill(10) + '.jpg')
+            if os.path.exists(image_path):
+                image = Image.open(image_path)
+                if self.transform is not None:
+                    image = self.transform(image)
+            # a missing frame (e.g. past the end of the video) repeats the previous one
+            X.append(image)
+        X = torch.stack(X, dim=0)
+
+        return X
+
+    def __getitem__(self, index):
+        "Generates one sample of data"
+        # Select sample
+
+        data = self.datalist[index]
+
+        # Load data
+        X = self.read_images(data)     # (input) spatial images, shape (T, C, H, W)
+        # reorder to (Channel, Temporal, Height, Width) for the 3D model
+        X = X.permute(1, 0, 2, 3)
+        return X
+    
+
+if __name__ == '__main__':
+    main()