[5ba3a6]: / evaluate.py

Download this file

103 lines (85 with data), 3.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
# -*- coding: utf-8 -*-
"""
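@File : evaluate.py
@Time : 2019/6/23 15:40
@Author : Parker
@Email : now_cherish@163.com
@Software: PyCharm
@Des : Evaluate a trained Resnet34 checkpoint on the validation split and print overall and per-class accuracy.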
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from tensorboardX import SummaryWriter
import numpy as np
import time
import datetime
import argparse
import os
import os.path as osp
from rs_dataset import RSDataset
from get_logger import get_logger
from res_network import Resnet18, Resnet34
def parse_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Several options (epoch, schedule_step, batch_size, log/tensorboard/model
    output settings, ...) mirror the training script; the evaluation code in
    this file only reads data_dir, test_batch_size, num_workers, seed and
    eval_model_path.

    Returns:
        argparse.Namespace holding one attribute per option.
    """
    parser = argparse.ArgumentParser()

    # Integer-valued options without help text, in declaration order.
    int_options = (
        ('--epoch', 15),
        ('--schedule_step', 2),
        ('--batch_size', 48),
        ('--test_batch_size', 256),
        ('--num_workers', 16),
        ('--eval_fre', 1),
        ('--msg_fre', 10),
        ('--save_fre', 2),
    )
    for flag, default_value in int_options:
        parser.add_argument(flag, type=int, default=default_value)

    parser.add_argument(
        '--name',
        type=str,
        default='res34_baseline',
        help='unique out file name of this task include log/model_out/tensorboard log',
    )

    # String-valued path/name options without help text, in declaration order.
    str_options = (
        ('--data_dir', '/home/tiger/projects/rscup2019_classifier/data'),
        ('--log_dir', './logs'),
        ('--tensorboard_dir', './tensorboard'),
        ('--model_out_dir', './model_out'),
        ('--model_out_name', 'final_model.pth'),
    )
    for flag, default_value in str_options:
        parser.add_argument(flag, type=str, default=default_value)

    parser.add_argument('--seed', type=int, default=5, help='random seed')
    parser.add_argument(
        '--eval_model_path',
        type=str,
        default='/home/tiger/projects/rscup2019_classifier/model_out/logistic_out_6.pth',
    )
    return parser.parse_args()
def evalute(args):
    """Evaluate a trained Resnet34 checkpoint on the validation split.

    Loads the weights stored at ``args.eval_model_path`` into a ``Resnet34``,
    runs it over every sample of ``RSDataset(mode='val')`` and prints the
    overall accuracy followed by one ``classID``/``accuracy`` line per class
    (class IDs are printed 1-based; labels are expected 0-based).

    Args:
        args: Parsed command-line namespace. Reads ``data_dir``,
            ``test_batch_size``, ``num_workers`` and ``eval_model_path``.
            An optional ``num_classes`` attribute (default 45) sets how many
            classes are counted and reported.

    Raises:
        ValueError: if a batch contains a label outside ``[0, num_classes)``.

    Note:
        Requires a CUDA device: the model and every batch are moved to GPU.
        A class with no validation samples is reported with accuracy 0.
    """
    num_classes = getattr(args, 'num_classes', 45)

    val_set = RSDataset(rootpth=args.data_dir, mode='val')
    # Evaluation must see every sample exactly once: no shuffling and no
    # drop_last (which silently discarded the final partial batch and
    # biased the reported accuracy).
    val_loader = DataLoader(val_set,
                            batch_size=args.test_batch_size,
                            drop_last=False,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers)

    net = Resnet34()
    # map_location='cpu' lets a checkpoint saved on any device load before
    # the model is moved to the GPU.
    net.load_state_dict(torch.load(args.eval_model_path, map_location='cpu'))
    net.cuda()
    net.eval()

    total = torch.zeros(num_classes, dtype=torch.long)
    correct = torch.zeros(num_classes, dtype=torch.long)
    with torch.no_grad():
        for img, lb in val_loader:
            img, lb = img.cuda(), lb.cuda()
            # sigmoid is monotonic, so the argmax of the raw logits is the
            # same prediction minus the ties float32 sigmoid saturation can
            # introduce for large logits.
            predicted = net(img).argmax(dim=1).reshape(-1)
            labels = lb.reshape(-1)
            hits = (predicted == labels).cpu()
            labels = labels.long().cpu()
            if labels.numel() and (int(labels.min()) < 0
                                   or int(labels.max()) >= num_classes):
                raise ValueError(
                    'label out of range [0, {}) in validation batch'.format(num_classes))
            # Vectorised per-class counting over the actual batch length,
            # so the last (possibly smaller) batch is counted correctly.
            total += torch.bincount(labels, minlength=num_classes)
            correct += torch.bincount(labels[hits], minlength=num_classes)

    n_seen = int(total.sum())
    overall = int(correct.sum()) / n_seen if n_seen else 0
    lines = ['Accuracy: {}\n'.format(overall)]
    for acc_idx in range(num_classes):
        n_cls = int(total[acc_idx])
        acc = int(correct[acc_idx]) / n_cls if n_cls else 0
        lines.append('classID: {},\taccuracy: {}\n'.format(acc_idx + 1, acc))
    print(''.join(lines))
if __name__ == '__main__':
    # Script entry point: parse the CLI options, seed the CPU and current
    # CUDA device RNGs from --seed, then run the validation evaluation.
    args = parse_args()
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    evalute(args)