"""Checkpoint-ensemble + TTA inference for the RSNA intracranial-hemorrhage task.

Loads the best checkpoint(s) per fold, averages sigmoid predictions over all
test-time augmentations, saves raw predictions as ``.npy`` and writes a
Kaggle-format submission CSV (one row per image/target pair).
"""

import glob
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from models import *
from augmentation import *
from dataset import *


device = torch.device('cuda')


def predict(model, loader):
    """Run the model over *loader* and return an ``(N, num_classes)`` array
    of sigmoid probabilities, in dataset order.
    """
    model.eval()
    batch_preds = []
    with torch.no_grad():
        for dct in tqdm(loader, total=len(loader)):
            images = dct['images'].to(device)
            logits = model(images)
            probs = torch.sigmoid(logits)
            batch_preds.append(probs.detach().cpu().numpy())
    return np.concatenate(batch_preds, axis=0)


def get_best_checkpoints(checkpoint_dir, n_best=3, minimize_metric=True):
    """Rank ``best*.pth`` checkpoints under *checkpoint_dir* by validation loss.

    Args:
        checkpoint_dir: run directory containing a ``checkpoints/`` folder.
        n_best: number of top checkpoints to return.
        minimize_metric: if True, lower loss ranks first.

    Returns:
        List of ``(checkpoint_path, valid_loss)`` tuples, best first.
    """
    files = glob.glob(f"{checkpoint_dir}/checkpoints/best*.pth")
    # 'full' checkpoints bundle optimizer state etc. — skip them here.
    files = [path for path in files if 'full' not in path]

    scored = []
    for path in files:
        # map_location lets us read checkpoint metadata without a GPU.
        ckp = torch.load(path, map_location='cpu')
        scored.append((path, ckp['valid_metrics']['loss']))

    scored.sort(key=lambda item: item[1], reverse=not minimize_metric)
    return scored[:n_best]


def _predict_ensemble_tta(log_dir, backbone, num_classes, test_csv, test_root,
                          with_any, image_size, normalization, image_type, mode):
    """Average predictions over the top checkpoints x all TTA augmentations.

    Returns an ``(N, num_classes)`` numpy array of mean probabilities.
    """
    top_best_metrics = get_best_checkpoints(log_dir, n_best=1, minimize_metric=True)

    test_preds = 0
    for checkpoint_path, checkpoint_metric in top_best_metrics:
        print("*" * 50)
        print(f"checkpoint: {checkpoint_path}")
        print(f"Metric: {checkpoint_metric}")

        model = CNNFinetuneModels(
            model_name=backbone,
            num_classes=num_classes,
            pretrained=False  # weights come from the checkpoint below
        )

        # Load the checkpoint that was actually ranked best (the previous
        # version hard-coded "checkpoints/best.pth" and ignored the ranking).
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        model = nn.DataParallel(model)
        model = model.to(device)

        augs = test_tta(image_size, normalization)
        for name, aug in augs.items():
            print("Augmentation: {}".format(name))

            test_dataset = RSNADataset(
                csv_file=test_csv,
                root=test_root,
                with_any=with_any,
                transform=aug,
                mode=mode,
                image_type=image_type
            )
            test_loader = DataLoader(
                dataset=test_dataset,
                batch_size=64,
                shuffle=False,
                num_workers=8,
            )

            # Accumulate the mean over (checkpoints x augmentations).
            test_preds += predict(model, test_loader) / (len(augs) * len(top_best_metrics))

    return test_preds


def _make_submission(test_csv, test_preds, target_cols, with_any, out_path):
    """Write a Kaggle submission CSV: one ``ID_<uid>_<target>`` row per class.

    When *with_any* is False the 'any' probability is derived as the max over
    the five subtype probabilities.
    """
    test_df = pd.read_csv(test_csv)
    test_ids = test_df['sop_instance_uid'].values

    ids = []
    labels = []
    for i, uid in enumerate(test_ids):
        if "ID" not in uid:
            uid = "ID_" + uid
        pred = test_preds[i]
        for j, target in enumerate(target_cols):
            ids.append(uid + "_" + target)
            labels.append(pred[j])
        if not with_any:
            ids.append(uid + "_" + "any")
            labels.append(pred.max())

    submission_df = pd.DataFrame({
        'ID': ids,
        'Label': labels
    })
    submission_df.to_csv(out_path, index=False)


def predict_test_tta_ckp():
    """Per-fold test-set inference with checkpoint ensembling and TTA.

    Saves raw predictions to ``/logs/prediction/<scheme>/test_<fold>_ckp_tta.npy``
    and a submission CSV next to it.
    """
    test_csv = "./csv/patient2_kfold/test.csv"
    test_root = "/data/stage_1_test_images_jpg_preprocessing/"
    image_type = 'jpg'

    image_size = [512, 512]
    backbone = "densenet169"
    normalization = True
    with_any = True

    if with_any:
        num_classes = 6
        target_cols = LABEL_COLS
    else:
        num_classes = 5
        target_cols = LABEL_COLS_WITHOUT_ANY

    for fold in [0, 1, 2, 3, 4]:
        scheme = f"{backbone}-mw-512-resume-{fold}"
        log_dir = f"/logs/rsna/test/{scheme}/"

        test_preds = _predict_ensemble_tta(
            log_dir=log_dir,
            backbone=backbone,
            num_classes=num_classes,
            test_csv=test_csv,
            test_root=test_root,
            with_any=with_any,
            image_size=image_size,
            normalization=normalization,
            image_type=image_type,
            mode="test",
        )

        os.makedirs(f"/logs/prediction/{scheme}", exist_ok=True)
        np.save(f"/logs/prediction/{scheme}/test_{fold}_ckp_tta.npy", test_preds)

        _make_submission(
            test_csv, test_preds, target_cols, with_any,
            out_path=f"/logs/prediction/{scheme}/{scheme}_ckp_tta.csv",
        )


def predict_valid_tta_ckp():
    """Per-fold validation-set (out-of-fold) inference with TTA.

    Mirrors :func:`predict_test_tta_ckp` but reads each fold's validation CSV
    and writes ``valid_<scheme>`` prediction/submission files.
    """
    test_root = "/data/stage_1_test_images_jpg_preprocessing/"
    image_type = 'jpg'

    image_size = [512, 512]
    backbone = "densenet169"
    normalization = True
    with_any = True

    if with_any:
        num_classes = 6
        target_cols = LABEL_COLS
    else:
        num_classes = 5
        target_cols = LABEL_COLS_WITHOUT_ANY

    for fold in [0, 1, 2, 3, 4]:
        test_csv = f"./csv/patient2_kfold/valid_{fold}.csv"
        scheme = f"{backbone}-mw-512-resume-{fold}"
        log_dir = f"/logs/rsna/test/{scheme}/"

        test_preds = _predict_ensemble_tta(
            log_dir=log_dir,
            backbone=backbone,
            num_classes=num_classes,
            test_csv=test_csv,
            test_root=test_root,
            with_any=with_any,
            image_size=image_size,
            normalization=normalization,
            image_type=image_type,
            mode="valid",
        )

        os.makedirs(f"/logs/prediction/{scheme}", exist_ok=True)
        np.save(f"/logs/prediction/{scheme}/valid_{scheme}.npy", test_preds)

        _make_submission(
            test_csv, test_preds, target_cols, with_any,
            out_path=f"/logs/prediction/{scheme}/valid_{scheme}.csv",
        )


if __name__ == '__main__':
    predict_test_tta_ckp()
    # predict_valid_tta_ckp()