Diff of /src/inference.py [000000] .. [95f789]

--- a
+++ b/src/inference.py
@@ -0,0 +1,259 @@
+import glob
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as Ftorch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from models import *
+from augmentation import *
+from dataset import *
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def predict(model, loader):
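+    """Run `model` over `loader` and return the sigmoid probabilities
+    for the whole dataset as a single numpy array."""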
+    model.eval()
+    preds = []
+    with torch.no_grad():
+        for dct in tqdm(loader, total=len(loader)):
+            images = dct['images'].to(device)
+            # meta = dct["meta"].to(device)
+            pred = model(images)
+            pred = Ftorch.sigmoid(pred)
+            pred = pred.detach().cpu().numpy()
+            preds.append(pred)
+
+    preds = np.concatenate(preds, axis=0)
+    return preds
+
+
+def get_best_checkpoints(checkpoint_dir, n_best=3, minimize_metric=True):
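+    """Return the `n_best` (path, metric) pairs among the saved `best*.pth`
+    checkpoints, ranked by their stored validation loss."""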
+    files = glob.glob(f"{checkpoint_dir}/checkpoints/best*.pth")
+    files = [file for file in files if 'full' not in file]
+
+    top_best_metrics = []
+    for file in files:
+        ckp = torch.load(file, map_location='cpu')  # only the stored metrics are needed here
+        valid_metric = ckp['valid_metrics']['loss']
+        top_best_metrics.append((file, valid_metric))
+
+    top_best_metrics = sorted(
+        top_best_metrics,
+        key=lambda x: x[1],
+        reverse=not minimize_metric
+    )
+    top_best_metrics = top_best_metrics[:n_best]
+    return top_best_metrics
+
+
+def predict_test_tta_ckp():
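+    """Predict on the test set: for every fold, load that fold's best
+    checkpoint, average predictions over the TTA augmentations, and write
+    a per-fold .npy array plus a submission-style CSV."""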
+    test_csv = "./csv/patient2_kfold/test.csv"
+    # test_root = "/data/stage_1_test_3w/"
+    # test_root = "/data/png/test_stage_1/adjacent-brain-cropped/"
+    test_root = "/data/stage_1_test_images_jpg_preprocessing/"
+    image_type = 'jpg'
+
+    image_size = [512, 512]
+    backbone = "densenet169"
+    normalization = True
+    # fold = 2
+    for fold in [0, 1, 2, 3, 4]:
+        # /logs/rsna/test/resnet50-anju-512-resume-0/checkpoints//train512.13.pth
+        scheme = f"{backbone}-mw-512-resume-{fold}"
+
+        log_dir = f"/logs/rsna/test/{scheme}/"
+
+        with_any = True
+
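+        # With `with_any` the model outputs all 6 target columns directly;
+        # otherwise only the 5 subtypes are predicted and "any" is derived
+        # from their maximum when the submission is built.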
+        if with_any:
+            num_classes = 6
+            target_cols = LABEL_COLS
+        else:
+            num_classes = 5
+            target_cols = LABEL_COLS_WITHOUT_ANY
+
+        top_best_metrics = get_best_checkpoints(log_dir, n_best=1, minimize_metric=True)
+
+        test_preds = 0
+        for best_metric in top_best_metrics:
+
+            checkpoint_path, checkpoint_metric = best_metric
+            print("*" * 50)
+            print(f"checkpoint: {checkpoint_path}")
+            print(f"Metric: {checkpoint_metric}")
+
+            model = CNNFinetuneModels(
+                model_name=backbone,
+                num_classes=num_classes,
+                pretrained=False
+            )
+
+            # Load the checkpoint selected above rather than a hard-coded best.pth.
+            checkpoint = torch.load(checkpoint_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            model = nn.DataParallel(model)
+            model = model.to(device)
+
+            augs = test_tta(image_size, normalization)
+
+            for name, aug in augs.items():
+                print(f"Augmentation: {name}")
+
+                test_dataset = RSNADataset(
+                    csv_file=test_csv,
+                    root=test_root,
+                    with_any=with_any,
+                    transform=aug,
+                    mode="test",
+                    image_type=image_type
+                )
+
+                test_loader = DataLoader(
+                    dataset=test_dataset,
+                    batch_size=64,
+                    shuffle=False,
+                    num_workers=8,
+                )
+
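+                # Accumulate the average over all TTA augmentations and selected checkpoints.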
+                test_preds += predict(model, test_loader) / (len(augs) * len(top_best_metrics))
+
+        os.makedirs(f"/logs/prediction/{scheme}", exist_ok=True)
+        np.save(f"/logs/prediction/{scheme}/test_{fold}_ckp_tta.npy", test_preds)
+
+        test_df = pd.read_csv(test_csv)
+        test_ids = test_df['sop_instance_uid'].values
+
+        ids = []
+        labels = []
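+        # Build submission rows: one (ID_<uid>_<target>, probability) pair per target column.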
+        for i, uid in enumerate(test_ids):
+            if "ID" not in uid:
+                uid = "ID_" + uid
+            pred = test_preds[i]
+            for j, target in enumerate(target_cols):
+                ids.append(uid + "_" + target)
+                labels.append(pred[j])
+            if not with_any:
+                # "any" is not predicted directly, so use the max over the subtype probabilities
+                ids.append(uid + "_" + "any")
+                labels.append(pred.max())
+
+        submission_df = pd.DataFrame({
+            'ID': ids,
+            'Label': labels
+        })
+
+        submission_df.to_csv(f"/logs/prediction/{scheme}/{scheme}_ckp_tta.csv", index=False)
+
+
+def predict_valid_tta_ckp():
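+    """Predict on each fold's validation split with the same TTA/checkpoint
+    averaging as the test routine, producing out-of-fold predictions."""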
+
+    # test_root = "/data/png/train/adjacent-brain-cropped/"
+    # test_root = "/data/stage_1_train_3w/"
+    test_root = "/data/stage_1_test_images_jpg_preprocessing/"
+    image_type = 'jpg'
+
+    image_size = [512, 512]
+    backbone = "densenet169"
+    normalization = True
+    # fold = 2
+    for fold in [0, 1, 2, 3, 4]:
+        test_csv = f"./csv/patient2_kfold/valid_{fold}.csv"
+        # /logs/rsna/test/resnet50-anju-512-resume-0/checkpoints//train512.13.pth
+        scheme = f"{backbone}-mw-512-resume-{fold}"
+
+        log_dir = f"/logs/rsna/test/{scheme}/"
+
+        with_any = True
+
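+        # Same target handling as in predict_test_tta_ckp: 6 columns with
+        # "any", or 5 subtypes with "any" derived at submission time.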
+        if with_any:
+            num_classes = 6
+            target_cols = LABEL_COLS
+        else:
+            num_classes = 5
+            target_cols = LABEL_COLS_WITHOUT_ANY
+
+        top_best_metrics = get_best_checkpoints(log_dir, n_best=1, minimize_metric=True)
+
+        test_preds = 0
+        for best_metric in top_best_metrics:
+
+            checkpoint_path, checkpoint_metric = best_metric
+            print("*" * 50)
+            print(f"checkpoint: {checkpoint_path}")
+            print(f"Metric: {checkpoint_metric}")
+
+            model = CNNFinetuneModels(
+                model_name=backbone,
+                num_classes=num_classes,
+                pretrained=False  # weights are restored from the checkpoint below
+            )
+
+            # Load the checkpoint selected above rather than a hard-coded best.pth.
+            checkpoint = torch.load(checkpoint_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            model = nn.DataParallel(model)
+            model = model.to(device)
+
+            augs = test_tta(image_size, normalization)
+
+            for name, aug in augs.items():
+                print(f"Augmentation: {name}")
+
+                test_dataset = RSNADataset(
+                    csv_file=test_csv,
+                    root=test_root,
+                    with_any=with_any,
+                    transform=aug,
+                    mode="valid",
+                    image_type=image_type
+                )
+
+                test_loader = DataLoader(
+                    dataset=test_dataset,
+                    batch_size=64,
+                    shuffle=False,
+                    num_workers=8,
+                )
+
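+                # Accumulate the average over all TTA augmentations and selected checkpoints.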
+                test_preds += predict(model, test_loader) / (len(augs) * len(top_best_metrics))
+
+        os.makedirs(f"/logs/prediction/{scheme}", exist_ok=True)
+        np.save(f"/logs/prediction/{scheme}/valid_{scheme}.npy", test_preds)
+
+        test_df = pd.read_csv(test_csv)
+        test_ids = test_df['sop_instance_uid'].values
+
+        ids = []
+        labels = []
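+        # Build rows in the same ID_<uid>_<target> format used for the test submission.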
+        for i, uid in enumerate(test_ids):
+            if "ID" not in uid:
+                uid = "ID_" + uid
+            pred = test_preds[i]
+            for j, target in enumerate(target_cols):
+                ids.append(uid + "_" + target)
+                labels.append(pred[j])
+            if not with_any:
+                # "any" is not predicted directly, so use the max over the subtype probabilities
+                ids.append(uid + "_" + "any")
+                labels.append(pred.max())
+
+        submission_df = pd.DataFrame({
+            'ID': ids,
+            'Label': labels
+        })
+
+        submission_df.to_csv(f"/logs/prediction/{scheme}/valid_{scheme}.csv", index=False)
+
+
+if __name__ == '__main__':
+    # predict_test()
+    predict_test_tta_ckp()
+    # predict_valid_tta_ckp()
\ No newline at end of file