--- a +++ b/src/run_grad_cam.py @@ -0,0 +1,19 @@ +from grad_cam import gc +from model import initialize_model +import torch +import os +from shutil import rmtree + +model_name = 'resnet' +num_classes = 2 +feature_extract = False + +model, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True) +model.load_state_dict(torch.load('./IH_resnet_weights.pt')) +data_path = './imgs_to_test_grad_cam/' +classes = ['noIH', 'IH'] +output_dir = './results' +if os.path.exists(output_dir): + rmtree(output_dir) +os.makedirs(output_dir) +gc(data_path, output_dir, model, classes) \ No newline at end of file