Diff of /src/testing.py [000000] .. [f45789]

Switch to unified view

a b/src/testing.py
1
import sys
2
sys.path.append('.')
3
import torch
4
from torch.nn import functional as F
5
import os
6
import yaml
7
from src.new_grad_cam import gc
8
9
10
def test(conf):
11
    device = conf['device']
12
    dataset = conf['test_dataset']
13
    classes = conf['data']['classes']
14
    weights_path = conf['weights_path']
15
    results_dir = conf['results_dir']
16
17
    model = conf['model']
18
    model.load_state_dict(torch.load(weights_path))
19
    model = model.to(device)
20
    model.eval()
21
22
    gc(model=model,
23
       dataset=dataset,
24
       results_dir=results_dir,
25
       classes=classes,
26
       device=device)
27
28
29
if __name__ == '__main__':
30
    from config import get_config
31
    conf = get_config('./conf/testing.yaml')
32
    test(conf)