# benchmark.py
1
import torch
2
from torchmetrics.image.inception import InceptionScore
3
from glob import glob
4
from natsort import natsorted
5
from tqdm import tqdm
6
import os
7
import cv2
8
import numpy as np
9
import torch
10
from torch import nn
11
from torch.nn import functional as F
12
from torch.utils.data import DataLoader
13
from torchvision import transforms
14
from torchvision.models import inception_v3, Inception3
15
from torchvision.datasets import ImageFolder
16
import pathlib
17
from torch.utils.data import Dataset
18
from PIL import Image
19
from pytorch_gan_metrics import get_inception_score, get_fid
20
21
# Process-wide environment configuration, applied at import time.
# CUDA reads CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES when it initialises, so
# these must be set before any CUDA call. With PCI bus ordering, this pins
# the run to physical GPU index 1.
# TF_XLA_FLAGS is a TensorFlow flag; nothing visible in this file imports
# TensorFlow directly — NOTE(review): possibly a leftover, confirm before removing.
os.environ.update(
    {
        'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
        'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
        'CUDA_VISIBLE_DEVICES': '1',
    }
)
24
# if __name__ == '__main__':
25
#   path = natsorted(glob('experiments/inception/294/*'))
26
#   for sub_path in path:
27
#       cls_name  = os.path.split(sub_path)[-1]
28
#       imgs_path = natsorted(glob(sub_path+'/*'))
29
#       inception = InceptionScore(splits=1)
30
#       for img in tqdm(imgs_path):
31
#           img = cv2.imread(img, 1)
32
#           img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
33
#           img = torch.permute(torch.Tensor(img).to(torch.uint8), (2, 0, 1))
34
#           inception.update(torch.unsqueeze(img, axis=0))
35
#           mean, std = inception.compute()
36
#       print('Inception score for class {}: {}'.format(cls_name, mean, std))
37
#       break
38
class MyDataset(Dataset):
    """Map-style dataset over a flat list of image file paths.

    Each item is the image at ``image_paths[index]``, converted to 3-channel
    RGB and optionally passed through ``transform``. No label is returned, so
    a ``DataLoader`` over this dataset yields bare image batches.

    Args:
        image_paths: Sequence of paths to image files readable by PIL.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``
            (e.g. ``transforms.ToTensor()``). When ``None``, the PIL image
            itself is returned.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        # ``Image.open`` is lazy and holds a file handle. The context manager
        # guarantees the handle is released, including for multi-frame formats
        # and when no transform forces a load.
        # ``convert('RGB')`` loads the pixel data before the file is closed.
        # It also normalises grayscale / RGBA / palette images to 3 channels,
        # which batch stacking and Inception-v3 both need.
        with Image.open(image_path) as image:
            x = image.convert('RGB')
        if self.transform is not None:
            x = self.transform(x)
        return x

    def __len__(self):
        return len(self.image_paths)
54
55
56
@torch.no_grad()
def extract_features(loader, inception, device):
    """Feed every batch from ``loader`` into ``inception`` and return its result.

    Despite the name, no features are accumulated here. Each batch is passed
    to ``inception.update`` and the final ``inception.compute()`` value is
    returned. The commented-out caller in this file unpacks it as ``mean, std``.

    Args:
        loader: Iterable of image batches. A batch may be a tensor or a
            ``(images, ...)`` list/tuple (as yielded by ``ImageFolder``), in
            which case only the first element is used. Floating-point batches
            are assumed to hold values in ``[0, 1]`` (``transforms.ToTensor``
            output). Integer batches are assumed to already be in ``[0, 255]``.
        inception: Metric object exposing ``update(batch)`` and ``compute()``.
            Assumed to expect ``uint8`` images, as torchmetrics'
            ``InceptionScore`` does with its default ``normalize=False`` —
            TODO confirm for other metric objects.
        device: Device each batch is moved to before ``update``.

    Returns:
        Whatever ``inception.compute()`` returns.
    """
    for batch in tqdm(loader):
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        img = batch.to(device)
        if img.is_floating_point():
            # Casting a [0, 1] float tensor straight to uint8 truncates almost
            # every pixel to 0. Rescale to [0, 255] first; rounding exactly
            # inverts ToTensor's division by 255, and clamp guards the range.
            img = img.mul(255).round_().clamp_(0, 255)
        inception.update(img.to(torch.uint8))
    return inception.compute()
73
74
75
if __name__ == '__main__':
    # One entry per evaluated epoch, e.g. experiments/inception/<epoch>/<images>.
    # natsorted keeps numeric names in numeric order (2 before 10).
    paths = natsorted(glob('experiments/inception/*'))

    # The transform does not depend on the epoch, so build it once.
    # ToTensor yields float tensors in [0, 1], the range pytorch_gan_metrics'
    # get_inception_score is documented to expect — TODO confirm for the
    # installed version.
    transform = transforms.Compose([transforms.ToTensor()])

    for path in paths:
        epoch = os.path.basename(path)
        image_paths = natsorted(glob(os.path.join(path, '*')))
        if not image_paths:
            # Stray files or empty epoch directories would otherwise hand an
            # empty loader to get_inception_score.
            print('Skipping {}: no images found'.format(path))
            continue

        dset = MyDataset(image_paths, transform)
        loader = DataLoader(dset, batch_size=256, num_workers=8, shuffle=False)

        mean, std = get_inception_score(loader)

        line = 'Inception score for epoch {}: ({}, {})'.format(epoch, mean, std)
        print(line)
        # Append and close after every epoch, so results for finished epochs
        # are already on disk if a later epoch fails.
        with open('experiments/inceptionscore_torch.txt', 'a', encoding='utf-8') as file:
            file.write(line + '\n')