Diff of /evaluate.py [000000] .. [5ba3a6]

Switch to unified view

a b/evaluate.py
1
# -*- coding: utf-8 -*-
2
"""
3
@File    : trian_res34.py
4
@Time    : 2019/6/23 15:40
5
@Author  : Parker
6
@Email   : now_cherish@163.com
7
@Software: PyCharm
8
@Des     : 
9
"""
10
11
import torch
12
import torch.nn as nn
13
import torch.nn.functional as F
14
from torch.utils.data import DataLoader
15
import torch.optim as optim
16
from tensorboardX import SummaryWriter
17
18
import numpy as np
19
import time
20
import datetime
21
import argparse
22
import os
23
import os.path as osp
24
25
from rs_dataset import RSDataset
26
from get_logger import get_logger
27
from res_network import Resnet18, Resnet34
28
29
30
def parse_args():
31
    parse = argparse.ArgumentParser()
32
    parse.add_argument('--epoch', type=int, default=15)
33
    parse.add_argument('--schedule_step', type=int, default=2)
34
35
    parse.add_argument('--batch_size', type=int, default=48)
36
    parse.add_argument('--test_batch_size', type=int, default=256)
37
    parse.add_argument('--num_workers', type=int, default=16)
38
39
    parse.add_argument('--eval_fre', type=int, default=1)
40
    parse.add_argument('--msg_fre', type=int, default=10)
41
    parse.add_argument('--save_fre', type=int, default=2)
42
43
    parse.add_argument('--name', type=str, default='res34_baseline',
44
                       help='unique out file name of this task include log/model_out/tensorboard log')
45
    parse.add_argument('--data_dir', type=str, default='/home/tiger/projects/rscup2019_classifier/data')
46
    parse.add_argument('--log_dir', type=str, default='./logs')
47
    parse.add_argument('--tensorboard_dir', type=str, default='./tensorboard')
48
    parse.add_argument('--model_out_dir', type=str, default='./model_out')
49
    parse.add_argument('--model_out_name', type=str, default='final_model.pth')
50
    parse.add_argument('--seed', type=int, default=5, help='random seed')
51
    parse.add_argument('--eval_model_path', type=str,
52
                       default='/home/tiger/projects/rscup2019_classifier/model_out/logistic_out_6.pth')
53
    return parse.parse_args()
54
55
56
def evalute(args):
57
    val_set = RSDataset(rootpth=args.data_dir, mode='val')
58
    val_loader = DataLoader(val_set,
59
                            batch_size=args.test_batch_size,
60
                            drop_last=True,
61
                            shuffle=True,
62
                            pin_memory=True,
63
                            num_workers=args.num_workers)
64
    net = Resnet34()
65
    net.eval()
66
    net.load_state_dict(torch.load(args.eval_model_path))
67
    net.cuda()
68
69
    total = [0 for i in range(45)]
70
    correct = [0 for i in range(45)]
71
    with torch.no_grad():
72
        for img, lb in val_loader:
73
            img, lb = img.cuda(), lb.cuda()
74
            outputs = net(img)
75
            outputs = torch.sigmoid(outputs)
76
            predicted = torch.max(outputs, dim=1)[1]
77
            res = predicted == lb
78
79
            for label_idx in range(args.test_batch_size):
80
                label_single = lb[label_idx]
81
                correct[label_single] += res[label_idx].item()
82
                total[label_single] += 1
83
            # print(correct, total)
84
85
        acc_str = 'Accuracy: {}\n'.format(sum(correct)/sum(total))
86
        for acc_idx in range(45):
87
            try:
88
                acc = correct[acc_idx] / total[acc_idx]
89
            except:
90
                acc = 0
91
            finally:
92
                acc_str += 'classID: {},\taccuracy: {}\n'.format(acc_idx+1, acc)
93
    print(acc_str)
94
95
96
if __name__ == '__main__':
97
    args = parse_args()
98
99
    torch.manual_seed(args.seed)
100
    torch.cuda.manual_seed(args.seed)
101
102
    evalute(args)