--- a
+++ b/2DNet/src/predict.py
@@ -0,0 +1,521 @@
+#============ Basic imports ============#
+import os
+import sys
+sys.path.insert(0, '..')
+import time
+import gc
+import pandas as pd
+import cv2
+import csv
+from torch.utils.data import DataLoader
+from dataset.dataset import *
+from tuils.tools import *
+from tqdm import tqdm
+import torch.nn as nn
+import numpy as np
+import random
+import math
+import argparse
+
+def randomHorizontalFlip(image, u=0.5):
+    """Mirror ``image`` left-right with probability ``u``.
+
+    Exactly one value is drawn from ``np.random``; when it is not below ``u``
+    the input object is returned unchanged.
+    """
+    should_flip = np.random.random() < u
+    return cv2.flip(image, 1) if should_flip else image
+
+def randomVerticleFlip(image, u=0.5):
+    """Mirror ``image`` top-bottom with probability ``u`` (one ``np.random`` draw)."""
+    if not np.random.random() < u:
+        return image
+    return cv2.flip(image, 0)
+
+def randomRotate90(image, u=0.5):
+    """Rotate the first three channels of ``image`` by 90 degrees with probability ``u``.
+
+    The rotated pixels are written back into ``image`` in place and the same
+    array is returned. Because the result is assigned into the original
+    slice, NumPy raises a shape-mismatch error for non-square inputs.
+    """
+    if not np.random.random() < u:
+        return image
+    rgb = image[:, :, 0:3]
+    image[:, :, 0:3] = np.rot90(rgb)
+    return image
+
+def random_cropping(image, ratio=0.8, is_random = True):
+    """Crop a ``ratio``-sized window from ``image`` and resize it back to full size.
+
+    With ``is_random`` the window's top-left corner is drawn with
+    ``random.randint`` (column first, then row); otherwise the window is
+    centred. ``image`` must be 3-D ``(H, W, C)``.
+    """
+    full_h, full_w, _ = image.shape
+    crop_h, crop_w = int(full_h * ratio), int(full_w * ratio)
+
+    if is_random:
+        left = random.randint(0, full_w - crop_w)
+        top = random.randint(0, full_h - crop_h)
+    else:
+        left, top = (full_w - crop_w) // 2, (full_h - crop_h) // 2
+
+    window = image[top:top + crop_h, left:left + crop_w, :]
+    return cv2.resize(window, (full_w, full_h))
+
+def cropping(image, ratio=0.8, code = 0):
+    """Deterministically crop a ``ratio``-sized window and resize it back to full size.
+
+    ``code`` selects the window position:
+        -1: no crop, ``image`` is returned as-is
+         0: centre
+         1: top-left      2: top-right
+         3: bottom-left   4: bottom-right
+
+    Raises:
+        ValueError: for any other ``code`` (this used to surface as an
+            UnboundLocalError on ``start_x``).
+    """
+    height, width, _ = image.shape
+    target_h = int(height * ratio)
+    target_w = int(width * ratio)
+
+    if code == -1:
+        return image
+
+    # (start_x, start_y) of the crop window for each position code.
+    offsets = {
+        0: ((width - target_w) // 2, (height - target_h) // 2),
+        1: (0, 0),
+        2: (width - target_w, 0),
+        3: (0, height - target_h),
+        4: (width - target_w, height - target_h),
+    }
+    if code not in offsets:
+        raise ValueError('cropping: unsupported code {!r}; expected one of -1, 0, 1, 2, 3, 4'.format(code))
+    start_x, start_y = offsets[code]
+
+    crop = image[start_y:start_y + target_h, start_x:start_x + target_w, :]
+    return cv2.resize(crop, (width, height))
+
+def random_erasing(img, probability=0.5, sl=0.02, sh=0.4, r1=0.3):
+    """Zero out one random rectangle of ``img`` in place (Random Erasing).
+
+    With probability ``1 - probability`` the image is returned untouched.
+    Otherwise up to 100 attempts are made to sample a rectangle whose area is
+    ``uniform(sl, sh)`` of the image and whose aspect ratio is
+    ``uniform(r1, 1 / r1)``. The first one that fits strictly inside the image
+    is set to 0 across all channels. ``img`` is modified in place and
+    returned in every case.
+    """
+    if random.uniform(0, 1) > probability:
+        return img
+
+    rows, cols = img.shape[0], img.shape[1]
+    area = rows * cols
+    for _ in range(100):
+        target_area = random.uniform(sl, sh) * area
+        aspect_ratio = random.uniform(r1, 1 / r1)
+
+        h = int(round(math.sqrt(target_area * aspect_ratio)))
+        w = int(round(math.sqrt(target_area / aspect_ratio)))
+
+        if w < cols and h < rows:
+            top = random.randint(0, rows - h)
+            left = random.randint(0, cols - w)
+            # Erase every channel (``...`` also covers 2-D images). A
+            # 3-channel-only check here used to return None for other
+            # layouts, which crashed the next augmentation step.
+            img[top:top + h, left:left + w, ...] = 0.0
+            return img
+
+    return img
+
+def randomShiftScaleRotate(image,
+                           shift_limit=(-0.0, 0.0),
+                           scale_limit=(-0.0, 0.0),
+                           rotate_limit=(-0.0, 0.0),
+                           aspect_limit=(-0.0, 0.0),
+                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
+    """Randomly shift, scale, rotate and stretch ``image`` with probability ``u``.
+
+    Each ``*_limit`` is a ``(low, high)`` range sampled with
+    ``np.random.uniform``:
+        - ``shift_limit`` is a fraction of width/height;
+        - ``scale_limit`` and ``aspect_limit`` are offsets from 1;
+        - ``rotate_limit`` is in degrees.
+    The four image corners are mapped through the resulting affine map and
+    the image is warped with the matching perspective transform. Uncovered
+    pixels are filled according to ``borderMode`` (border value 0). The
+    output keeps the input's width and height.
+    """
+    if np.random.random() < u:
+        height, width = image.shape[:2]
+
+        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
+        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
+        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
+        sx = scale * aspect / (aspect ** 0.5)
+        sy = scale / (aspect ** 0.5)
+        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
+        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)
+
+        # ``np.math`` was only an alias of the stdlib ``math`` module; it was
+        # deprecated in NumPy 1.25 and removed in NumPy 2.0.
+        cc = math.cos(angle / 180 * math.pi) * sx
+        ss = math.sin(angle / 180 * math.pi) * sy
+        rotate_matrix = np.array([[cc, -ss], [ss, cc]])
+
+        # Rotate/scale the corners about the image centre, then shift them.
+        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height]])
+        box1 = box0 - np.array([width / 2, height / 2])
+        box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])
+
+        mat = cv2.getPerspectiveTransform(box0.astype(np.float32), box1.astype(np.float32))
+        image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR,
+                                    borderMode=borderMode, borderValue=(0, 0, 0))
+    return image
+
+
+def aug_image(image, is_infer=False, augment = (0, 0)):
+    """Apply inference-time or training-time augmentation to an image.
+
+    Inference (``is_infer=True``) runs two steps:
+        1. a horizontal flip with probability ``augment[0]``;
+        2. ``cropping`` with ratio 0.8 and position code ``augment[1]``
+           (0 = centre, -1 = no crop).
+    With the default ``(0, 0)`` this is a deterministic centre crop.
+
+    Training applies, in order:
+        - a random flip;
+        - a random shift/scale/rotate/aspect warp;
+        - random erasing;
+        - a random crop of 60-99% of each side, resized back to the input size.
+    """
+    # ``augment`` defaults to a tuple so no mutable object is shared between
+    # calls; any indexable pair still works.
+    if is_infer:
+        image = randomHorizontalFlip(image, u=augment[0])
+        image = np.asarray(image)
+        image = cropping(image, ratio=0.8, code=augment[1])
+        return image
+
+    image = randomHorizontalFlip(image)
+    height, width, _ = image.shape
+    image = randomShiftScaleRotate(image,
+                                   shift_limit=(-0.1, 0.1),
+                                   scale_limit=(-0.1, 0.1),
+                                   aspect_limit=(-0.1, 0.1),
+                                   rotate_limit=(-30, 30))
+
+    image = cv2.resize(image, (width, height))
+    image = random_erasing(image, probability=0.5, sl=0.02, sh=0.4, r1=0.3)
+
+    ratio = random.uniform(0.6, 0.99)
+    image = random_cropping(image, ratio=ratio, is_random=True)
+    return image
+
+# Inference-time preprocessing: normalize 0-255 pixels with one grey mean/std
+# replicated over the three channels. The two pipelines are currently identical.
+valid_transform_aug = albumentations.Compose([
+    albumentations.Normalize(
+        mean=(0.456, 0.456, 0.456),
+        std=(0.224, 0.224, 0.224),
+        max_pixel_value=255.0,
+        p=1.0,
+    ),
+])
+
+valid_transform_pure = albumentations.Compose([
+    albumentations.Normalize(
+        mean=(0.456, 0.456, 0.456),
+        std=(0.224, 0.224, 0.224),
+        max_pixel_value=255.0,
+        p=1.0,
+    ),
+])
+
+class PredictionDatasetPure:
+    """Deterministic (centre-crop only) inference dataset over concatenated 3-slice images.
+
+    Each of the ``len(name_list)`` images is served ``n_test_aug`` times:
+    index ``idx`` maps to ``name_list[idx % len(name_list)]``. Items are
+    ``(filename, normalized CHW image, label tensor)``. In ``'test'`` mode
+    the label is a dummy all-zero vector of length 6.
+    """
+
+    _IMAGE_DIRS = {
+        'val': '/home1/kaggle_rsna2019/process/train_concat_3images_256/',
+        'test': '/home1/kaggle_rsna2019/process/stage2_test_concat_3images/',
+    }
+
+    def __init__(self, name_list, df_train, df_test, n_test_aug, mode):
+        if mode not in ('val', 'test'):
+            raise ValueError("mode must be 'val' or 'test', got {!r}".format(mode))
+        self.name_list = name_list
+        source = df_train if mode == 'val' else df_test
+        self.df = source[source['filename'].isin(name_list)]
+        self.n_test_aug = n_test_aug
+        self.mode = mode
+        # Group label rows once so __getitem__ is a dict lookup. Filtering the
+        # whole DataFrame per sample was O(N) per item. Row order within each
+        # group matches the boolean filter it replaces.
+        self._labels = {}
+        if mode == 'val':
+            for name, group in self.df.groupby('filename', sort=False):
+                self._labels[name] = group.loc[:, 'any':'subdural'].values
+
+    def __len__(self):
+        return len(self.name_list) * self.n_test_aug
+
+    def __getitem__(self, idx):
+        filename = self.name_list[idx % len(self.name_list)]
+        path = self._IMAGE_DIRS[self.mode] + filename
+        image_cat = cv2.imread(path)
+        if image_cat is None:
+            # cv2.imread returns None for a missing/unreadable file.
+            raise FileNotFoundError('could not read image: ' + path)
+
+        if self.mode == 'val':
+            # KeyError here means the filename has no label row in df_train.
+            label = torch.FloatTensor(self._labels[filename])
+        else:
+            image_cat = cv2.resize(image_cat, (256, 256))
+            label = torch.FloatTensor([0, 0, 0, 0, 0, 0])
+
+        # Default inference augmentation: no flip, centre crop at ratio 0.8.
+        image_cat = aug_image(image_cat, is_infer=True)
+        image_cat = valid_transform_pure(image=image_cat)['image'].transpose(2, 0, 1)
+        return filename, image_cat, label
+
+class PredictionDatasetAug:
+    """Randomly augmented (test-time augmentation) inference dataset.
+
+    Each of the ``len(name_list)`` images is served ``n_test_aug`` times:
+    index ``idx`` maps to ``name_list[idx % len(name_list)]``. Each time the
+    image is resized to 256x256 and freshly augmented:
+        - BGR->RGB swap with probability 0.5;
+        - horizontal flip with probability 0.5;
+        - a random crop with side ratio in [0.6, 0.99), resized back.
+    Items are ``(filename, normalized CHW image, label tensor)``. In
+    ``'test'`` mode the label is a dummy all-zero vector of length 6.
+    """
+
+    _IMAGE_DIRS = {
+        'val': '/home1/kaggle_rsna2019/process/train_concat_3images_256/',
+        'test': '/home1/kaggle_rsna2019/process/stage2_test_concat_3images/',
+    }
+
+    def __init__(self, name_list, df_train, df_test, n_test_aug, mode):
+        if mode not in ('val', 'test'):
+            raise ValueError("mode must be 'val' or 'test', got {!r}".format(mode))
+        self.name_list = name_list
+        source = df_train if mode == 'val' else df_test
+        self.df = source[source['filename'].isin(name_list)]
+        self.n_test_aug = n_test_aug
+        self.mode = mode
+        # Group label rows once so __getitem__ is a dict lookup. Filtering the
+        # whole DataFrame per sample was O(N) per item. Row order within each
+        # group matches the boolean filter it replaces.
+        self._labels = {}
+        if mode == 'val':
+            for name, group in self.df.groupby('filename', sort=False):
+                self._labels[name] = group.loc[:, 'any':'subdural'].values
+
+    def __len__(self):
+        return len(self.name_list) * self.n_test_aug
+
+    def __getitem__(self, idx):
+        filename = self.name_list[idx % len(self.name_list)]
+        path = self._IMAGE_DIRS[self.mode] + filename
+        image_cat = cv2.imread(path)
+        if image_cat is None:
+            # cv2.imread returns None for a missing/unreadable file.
+            raise FileNotFoundError('could not read image: ' + path)
+        image_cat = cv2.resize(image_cat, (256, 256))
+
+        if self.mode == 'val':
+            # KeyError here means the filename has no label row in df_train.
+            label = torch.FloatTensor(self._labels[filename])
+        else:
+            label = torch.FloatTensor([0, 0, 0, 0, 0, 0])
+
+        if random.random() < 0.5:
+            image_cat = cv2.cvtColor(image_cat, cv2.COLOR_BGR2RGB)
+        image_cat = randomHorizontalFlip(image_cat, u=0.5)
+        ratio = random.uniform(0.6, 0.99)
+        image_cat = random_cropping(image_cat, ratio=ratio, is_random=True)
+        image_cat = valid_transform_aug(image=image_cat)['image'].transpose(2, 0, 1)
+        return filename, image_cat, label
+
+def _backbone_features(model, inputs):
+    """Return ``(features, logits)`` for a batch: pooled flattened features and the head's output.
+
+    Dispatches on the module-level ``backbone`` name (set in ``__main__``)
+    because each supported architecture exposes its layers under different
+    attributes of ``model.module``.
+    """
+    net = model.module
+    if backbone == 'DenseNet121_change_avg':
+        pooled = net.avgpool(net.relu(net.densenet121(inputs)))
+        head = net.mlp
+    elif backbone == 'DenseNet169_change_avg':
+        pooled = net.avgpool(net.relu(net.densenet169(inputs)))
+        head = net.mlp
+    elif backbone == 'se_resnext101_32x4d':
+        trunk = net.model_ft
+        x = inputs
+        for layer in (trunk.layer0, trunk.layer1, trunk.layer2, trunk.layer3, trunk.layer4):
+            x = layer(x)
+        pooled = trunk.avg_pool(x)
+        head = trunk.last_linear
+    else:
+        raise ValueError('unsupported backbone: {!r}'.format(backbone))
+    features = pooled.view(pooled.size(0), -1)
+    return features, head(features)
+
+
+def predict(model, name_list, df_all, df_test, batch_size: int, n_test_aug: int, aug=False, mode='val', fold=0):
+    """Run (optionally test-time-augmented) inference over ``name_list``.
+
+    Args:
+        model: ``nn.DataParallel``-wrapped network (layers are reached via ``model.module``).
+        name_list: image filenames to predict.
+        df_all: label table used in ``'val'`` mode.
+        df_test: table used in ``'test'`` mode.
+        batch_size: DataLoader batch size.
+        n_test_aug: number of passes each image is served.
+        aug: use random TTA (``PredictionDatasetAug``) instead of the
+            deterministic centre crop (``PredictionDatasetPure``).
+        mode: ``'val'`` or ``'test'``.
+        fold: fold index, used in the test-mode feature filenames.
+
+    Side effects:
+        For each image, saves the pooled backbone features averaged over its
+        ``n_test_aug`` passes, as a float16 ``.npy`` file. Val features go to
+        ``model_snapshot_path + 'prediction/npy_train/'``; test features go to
+        ``.../npy_test/`` with a ``_<fold>`` suffix.
+
+    Returns:
+        ``(predictions, names, truth)``:
+            - ``predictions``: per-pass sigmoid outputs;
+            - ``names``: the matching filenames;
+            - ``truth``: labels reshaped to ``(-1, 6)`` (zeros in test mode).
+    """
+    dataset_cls = PredictionDatasetAug if aug else PredictionDatasetPure
+    loader = DataLoader(
+        dataset=dataset_cls(name_list, df_all, df_test, n_test_aug, mode),
+        shuffle=False,
+        batch_size=batch_size,
+        num_workers=16,
+        pin_memory=True
+    )
+
+    model.eval()
+
+    all_names = []
+    output_batches = []
+    truth_batches = []
+    features_list = {}
+    for names, inputs, labels in tqdm(loader, desc='Predict'):
+        # ``cuda(async=True)`` is a SyntaxError on Python 3.7+ (``async`` is a
+        # keyword); ``non_blocking=True`` is the supported spelling.
+        labels = labels.view(-1, 6).contiguous().cuda(non_blocking=True)
+        truth_batches.append(labels)
+
+        # The whole forward pass belongs under no_grad. Previously only the
+        # host-to-device copy was, so autograd graphs were built for inference.
+        with torch.no_grad():
+            inputs = inputs.cuda(non_blocking=True)
+            features, logits = _backbone_features(model, inputs)
+            probs = logits.sigmoid()
+
+        # Accumulate each image's mean feature over its n_test_aug passes.
+        # The divisor used to be a hard-coded 10; the result is identical
+        # when n_test_aug == 10. Features are copied to host once per batch.
+        batch_features = features.cpu().numpy() / n_test_aug
+        for index, name in enumerate(names):
+            if name not in features_list:
+                features_list[name] = batch_features[index, :].copy()
+            else:
+                features_list[name] += batch_features[index, :]
+
+        output_batches.append(probs)
+        all_names.extend(names)
+
+    for key, value in features_list.items():
+        if mode == 'val':
+            out_path = model_snapshot_path + 'prediction/npy_train/' + key.replace('.png', '.npy')
+        else:
+            out_path = model_snapshot_path + 'prediction/npy_test/' + key.replace('.png', '_' + str(fold) + '.npy')
+        np.save(out_path, value.astype(np.float16))
+
+    # Concatenate once instead of re-growing a tensor every batch (quadratic copying).
+    all_truth = torch.cat(truth_batches, 0) if truth_batches else torch.FloatTensor()
+    all_outputs = torch.cat(output_batches, 0) if output_batches else torch.FloatTensor()
+    datanpGT = all_truth.cpu().numpy()
+    datanpPRED = all_outputs.cpu().numpy()
+    return datanpPRED, all_names, datanpGT
+
+def group_aug(val_p_aug, val_names_aug, val_truth_aug):
+    """Average augmented predictions and labels per image id.
+
+    :param val_p_aug: per-pass predictions, array-like of shape ``(n, k)``.
+    :param val_names_aug: the ``n`` image ids, one per row.
+    :param val_truth_aug: per-pass labels, array-like of shape ``(n, k)``.
+    :return: ``(mean_predictions, ids, mean_labels)``, one row per unique id,
+        ordered by id.
+    """
+    df_prob = pd.DataFrame(val_p_aug)
+    df_prob['id'] = val_names_aug
+
+    df_truth = pd.DataFrame(val_truth_aug)
+    df_truth['id'] = val_names_aug
+
+    # Both frames are sorted by id, so their rows line up one-to-one.
+    g_prob = df_prob.groupby('id').mean().reset_index().sort_values(by='id')
+    g_truth = df_truth.groupby('id').mean().reset_index().sort_values(by='id')
+
+    # ``drop(label, axis)`` with a positional axis was removed in pandas 2.0;
+    # the ``columns=`` keyword works on old and new pandas alike.
+    return g_prob.drop(columns='id').values, g_truth['id'].values, g_truth.drop(columns='id').values
+
+
+_LABEL_COLUMNS = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
+
+
+def _to_frame(predictions, names):
+    """Wrap ``(n, 6)`` predictions and their filenames in a labelled DataFrame."""
+    df = pd.DataFrame(data=predictions, columns=_LABEL_COLUMNS)
+    df['filename'] = names
+    return df
+
+
+def _predict_fold(model_name, fold):
+    """Predict one fold: write its per-fold CSVs and append its losses to the summary CSV.
+
+    Reads these module-level settings assigned in ``__main__``:
+    ``model_snapshot_path``, ``kfold_path``, ``is_center``, ``is_aug``,
+    ``num_aug``, ``batch_size``, ``df_all``, ``df_test``, ``c_test``.
+    """
+    print(fold)
+    prediction_path = model_snapshot_path + 'prediction/fold_{fold}'.format(fold=fold)
+
+    with open(kfold_path + 'fold{fold}/val.txt'.format(fold=fold), 'r') as f_val:
+        c_val = [s.replace('\n', '') for s in f_val.readlines()]
+
+    # NOTE(review): eval() on a CLI-supplied name executes arbitrary code;
+    # an explicit name -> model-class mapping would be safer.
+    model = eval(model_name + '()')
+    model = nn.DataParallel(model).cuda()
+
+    state = torch.load(model_snapshot_path + 'model_epoch_best_{fold}.pth'.format(fold=fold))
+    epoch = state['epoch']
+    best_valid_loss = state['valLoss']
+    model.load_state_dict(state['state_dict'])
+    print(epoch, best_valid_loss)
+    model.eval()
+
+    rows = []
+    if is_center:
+        val_p, val_names, val_truth = predict(model, c_val, df_all, df_test, batch_size, 1, False, 'val', fold)
+        val_predictions, val_image_names, val_truth = group_aug(val_p, val_names, val_truth)
+        val_loss, val_loss_sum = weighted_log_loss_numpy(val_predictions, val_truth)
+        print('val_loss = ', val_loss, 'val_loss_sum = ', val_loss_sum)
+        _to_frame(val_predictions, val_image_names).to_csv(prediction_path + '_val_center.csv')
+
+        # The aggregation step reads fold_<k>_test_center.csv, which was never
+        # written before (FileNotFoundError). Predict the test set without
+        # augmentation as well.
+        test_p, test_names, test_truth = predict(model, c_test, df_all, df_test, batch_size, 1, False, 'test', fold)
+        test_predictions, test_image_names, _ = group_aug(test_p, test_names, test_truth)
+        _to_frame(test_predictions, test_image_names).to_csv(prediction_path + '_test_center.csv')
+        rows.append([fold, 1, val_loss, val_loss_sum])
+
+    if is_aug:
+        val_p_aug, val_names_aug, val_truth_aug = predict(model, c_val, df_all, df_test, batch_size, num_aug, True, 'val', fold)
+        val_predictions_aug, val_image_names_aug, val_truth_aug = group_aug(val_p_aug, val_names_aug, val_truth_aug)
+        val_loss_aug, val_loss_sum_aug = weighted_log_loss_numpy(val_predictions_aug, val_truth_aug)
+        print('val_loss_aug = ', val_loss_aug, 'val_loss_sum_aug = ', val_loss_sum_aug)
+        _to_frame(val_predictions_aug, val_image_names_aug).to_csv(
+            prediction_path + '_val_aug_{num_aug}.csv'.format(num_aug=num_aug))
+
+        test_p_aug, test_names_aug, test_truth_aug = predict(model, c_test, df_all, df_test, batch_size, num_aug, True, 'test', fold)
+        test_predictions_aug, test_image_names_aug, _ = group_aug(test_p_aug, test_names_aug, test_truth_aug)
+        _to_frame(test_predictions_aug, test_image_names_aug).to_csv(
+            prediction_path + '_test_aug_{num_aug}.csv'.format(num_aug=num_aug))
+        rows.append([fold, num_aug, val_loss_aug, val_loss_sum_aug])
+
+    with open(model_snapshot_path + 'prediction/{model_name}.csv'.format(model_name=model_name), 'a', newline='') as f:
+        csv.writer(f).writerows(rows)
+
+    # Release this fold's network before the next fold builds its own.
+    del model, state
+    gc.collect()
+
+
+def _aggregate_folds(model_name):
+    """Merge per-fold CSVs into out-of-fold val predictions and fold-averaged test predictions."""
+    prediction_path = model_snapshot_path + 'prediction/'
+
+    val_lists_center, test_lists_center = [], []
+    val_lists_aug, test_lists_aug = [], []
+    for fold in range(5):
+        if is_center:
+            val_lists_center.append(pd.read_csv(
+                prediction_path + 'fold_{fold}_val_center.csv'.format(fold=fold), index_col=0))
+            test_lists_center.append(pd.read_csv(
+                prediction_path + 'fold_{fold}_test_center.csv'.format(fold=fold), index_col=0))
+        if is_aug:
+            val_lists_aug.append(pd.read_csv(
+                prediction_path + 'fold_{fold}_val_aug_{num_aug}.csv'.format(fold=fold, num_aug=num_aug), index_col=0))
+            test_lists_aug.append(pd.read_csv(
+                prediction_path + 'fold_{fold}_test_aug_{num_aug}.csv'.format(fold=fold, num_aug=num_aug), index_col=0))
+
+    summary_rows = []
+    if is_center:
+        df_val_center = pd.concat(val_lists_center).sort_values(by='filename').reset_index(drop=True)
+        df_val_center.to_csv(prediction_path + 'val_center.csv', index=False)
+        val_loss_center, val_loss_sum_center = weighted_log_loss_numpy(
+            df_val_center.loc[:, _LABEL_COLUMNS].values, val_truth_oof)
+        print('center: ', val_loss_center, val_loss_sum_center)
+        summary_rows.append(['center: ', val_loss_center, val_loss_sum_center])
+
+    if is_aug:
+        df_val_aug = pd.concat(val_lists_aug).sort_values(by='filename').reset_index(drop=True)
+        df_val_aug.to_csv(prediction_path + 'val_aug_{num_aug}.csv'.format(num_aug=num_aug), index=False)
+        val_loss_aug, val_loss_sum_aug = weighted_log_loss_numpy(
+            df_val_aug.loc[:, _LABEL_COLUMNS].values, val_truth_oof)
+        print('aug: ', val_loss_aug, val_loss_sum_aug)
+        summary_rows.append(['aug: ', val_loss_aug, val_loss_sum_aug])
+
+    with open(prediction_path + '{model_name}.csv'.format(model_name=model_name), 'a', newline='') as f:
+        csv.writer(f).writerows(summary_rows)
+
+    # Test predictions: average each image over the five fold models.
+    if is_center:
+        df_test_center = pd.concat(test_lists_center).groupby('filename').mean()
+        df_test_center.to_csv(prediction_path + 'test_center.csv')
+    if is_aug:
+        df_test_aug = pd.concat(test_lists_aug).groupby('filename').mean()
+        df_test_aug.to_csv(prediction_path + 'test_aug_{num_aug}.csv'.format(num_aug=num_aug))
+
+
+def predict_all(model_name, image_size):
+    """Predict all five folds with their best checkpoints, then aggregate the results.
+
+    ``image_size`` is accepted for CLI compatibility but is not used here.
+    """
+    # Output directories do not depend on the fold, so create them once.
+    for sub_dir in ('prediction/', 'prediction/npy_train/', 'prediction/npy_test/'):
+        os.makedirs(model_snapshot_path + sub_dir, exist_ok=True)
+
+    for fold in [0, 1, 2, 3, 4]:
+        _predict_fold(model_name, fold)
+
+    _aggregate_folds(model_name)
+    
+    
+if __name__ == '__main__':
+    csv_path = '../data/stage1_train_cls.csv'
+    test_csv_path = '../data/stage2_test_cls.csv'
+
+    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("-backbone", "--backbone", type=str, default='DenseNet121_change_avg', help='backbone')
+    parser.add_argument("-img_size", "--Image_size", type=int, default=1024, help='image_size')
+    parser.add_argument("-tbs", "--train_batch_size", type=int, default=4, help='train_batch_size')
+    parser.add_argument("-vbs", "--val_batch_size", type=int, default=4, help='val_batch_size')
+    parser.add_argument("-spth", "--snapshot_path", type=str,
+                        default='DenseNet121_change_avg',
+                        help='directory holding model_epoch_best_<fold>.pth; predictions are written under it')
+    parser.add_argument("-num_aug", "--num_aug", type=int, default=10,
+                        help='number of test-time augmentation passes per image')
+
+    args = parser.parse_args()
+
+    Image_size = args.Image_size
+    train_batch_size = args.train_batch_size
+    val_batch_size = args.val_batch_size
+    batch_size = val_batch_size
+    print(Image_size)
+    print(train_batch_size)
+    print(val_batch_size)
+
+    model_snapshot_path = args.snapshot_path.replace('\n', '').replace('\r', '') + '/'
+    kfold_path = '../data/fold_5_by_study_image/'
+
+    df_test = pd.read_csv(test_csv_path)
+    # sorted() gives a reproducible order; iterating a set of str varies with hash randomization.
+    c_test = sorted(set(df_test['filename'].values.tolist()))
+    df_all = pd.read_csv(csv_path)
+    is_center = False
+    is_aug = True
+    num_aug = args.num_aug
+    val_truth_oof = df_all.sort_values(by='filename').reset_index(drop = True).loc[:, ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']].values
+
+    backbone = args.backbone
+    print(backbone)
+    predict_all(backbone, Image_size)
+