|
a |
|
b/benchmark.py |
|
|
1 |
import torch |
|
|
2 |
from torchmetrics.image.inception import InceptionScore |
|
|
3 |
from glob import glob |
|
|
4 |
from natsort import natsorted |
|
|
5 |
from tqdm import tqdm |
|
|
6 |
import os |
|
|
7 |
import cv2 |
|
|
8 |
import numpy as np |
|
|
9 |
import torch |
|
|
10 |
from torch import nn |
|
|
11 |
from torch.nn import functional as F |
|
|
12 |
from torch.utils.data import DataLoader |
|
|
13 |
from torchvision import transforms |
|
|
14 |
from torchvision.models import inception_v3, Inception3 |
|
|
15 |
from torchvision.datasets import ImageFolder |
|
|
16 |
import pathlib |
|
|
17 |
from torch.utils.data import Dataset |
|
|
18 |
from PIL import Image |
|
|
19 |
from pytorch_gan_metrics import get_inception_score, get_fid |
|
|
20 |
|
|
|
21 |
# Process-wide environment configuration. CUDA reads CUDA_DEVICE_ORDER and
# CUDA_VISIBLE_DEVICES when it initialises, so these only take effect if they
# are set before the first CUDA call in this process (importing torch by
# itself does not initialise CUDA).
# NOTE(review): TF_XLA_FLAGS is a TensorFlow setting, and nothing visible here
# imports TensorFlow, so it may be a leftover. Confirm before removing it.
os.environ.update(
    {
        'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
        'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',  # number GPUs by PCI bus ID
        'CUDA_VISIBLE_DEVICES': '1',  # expose only physical GPU 1
    }
)
|
|
24 |
# if __name__ == '__main__': |
|
|
25 |
# path = natsorted(glob('experiments/inception/294/*')) |
|
|
26 |
# for sub_path in path: |
|
|
27 |
# cls_name = os.path.split(sub_path)[-1] |
|
|
28 |
# imgs_path = natsorted(glob(sub_path+'/*')) |
|
|
29 |
# inception = InceptionScore(splits=1) |
|
|
30 |
# for img in tqdm(imgs_path): |
|
|
31 |
# img = cv2.imread(img, 1) |
|
|
32 |
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
|
33 |
# img = torch.permute(torch.Tensor(img).to(torch.uint8), (2, 0, 1)) |
|
|
34 |
# inception.update(torch.unsqueeze(img, axis=0)) |
|
|
35 |
# mean, std = inception.compute() |
|
|
36 |
# print('Inception score for class {}: {}'.format(cls_name, mean, std)) |
|
|
37 |
# break |
|
|
38 |
class MyDataset(Dataset):
    """Label-free dataset that loads one image per file path.

    Each item is a single image with no label. If ``transform`` is given,
    the image is passed through it before being returned.
    """

    def __init__(self, image_paths, transform=None, mode='RGB'):
        """
        Args:
            image_paths: Indexable sequence of image file paths.
            transform: Optional callable applied to each loaded PIL image
                (e.g. ``transforms.ToTensor()``).
            mode: PIL mode that every image is converted to before
                ``transform`` runs. The default ``'RGB'`` makes grayscale,
                RGBA and palette files still yield 3-channel images, which
                Inception-v3 based metrics require. Pass ``None`` to keep
                each file's native mode.
        """
        self.image_paths = image_paths
        self.transform = transform
        self.mode = mode

    def __getitem__(self, index):
        """Load, mode-convert and optionally transform the image at ``index``."""
        image_path = self.image_paths[index]
        # Open the file under ``with`` so its handle is released right away.
        # ``convert`` and ``copy`` both force a full decode into memory, so
        # the returned image stays usable after the file has been closed.
        with Image.open(image_path) as img:
            x = img.convert(self.mode) if self.mode is not None else img.copy()
        if self.transform is not None:
            x = self.transform(x)
        return x

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)
|
|
54 |
|
|
|
55 |
|
|
|
56 |
@torch.no_grad()
def extract_features(loader, inception, device):
    """Feed every batch from ``loader`` into ``inception`` and return its result.

    Args:
        loader: Iterable of image-batch tensors. Floating-point batches are
            assumed to be in ``[0, 1]``, as ``transforms.ToTensor`` produces.
            Integer batches are assumed to already be in ``[0, 255]``.
        inception: Metric object exposing ``update(imgs)`` and ``compute()``.
            Presumably a torchmetrics ``InceptionScore`` that already lives
            on ``device``; confirm at the call site.
        device: Device each batch is moved to before ``update``.

    Returns:
        Whatever ``inception.compute()`` returns. For torchmetrics'
        ``InceptionScore`` this is a ``(mean, std)`` pair.
    """
    for img in tqdm(loader):
        img = img.to(device)
        if img.is_floating_point():
            # A bare ``.to(torch.uint8)`` truncates [0, 1] floats to 0, which
            # would score blank images. Rescale to [0, 255] and round first.
            img = (img * 255).round().clamp(0, 255)
        inception.update(img.to(torch.uint8))
    return inception.compute()
|
|
73 |
|
|
|
74 |
|
|
|
75 |
if __name__ == '__main__':
    # Each directory under experiments/inception/ holds the generated images
    # for one epoch. The directory name is reported as the epoch label.
    # Plain files are skipped because globbing inside them would find nothing.
    paths = natsorted(
        [p for p in glob('experiments/inception/*') if os.path.isdir(p)]
    )

    # ToTensor yields float tensors in [0, 1], the input range that
    # pytorch_gan_metrics' get_inception_score expects. The transform is
    # loop-invariant, so it is built once.
    transform = transforms.Compose([transforms.ToTensor()])

    for path in paths:
        epoch = os.path.basename(path)
        image_paths = natsorted(glob(os.path.join(path, '*')))
        if not image_paths:
            # An empty loader would make the score computation fail.
            print('Skipping epoch {}: no images found in {}'.format(epoch, path))
            continue

        dset = MyDataset(image_paths, transform)
        loader = DataLoader(dset, batch_size=256, num_workers=8, shuffle=False)
        mean, std = get_inception_score(loader)

        message = 'Inception score for epoch {}: ({}, {})'.format(epoch, mean, std)
        print(message)
        # Append after every epoch so completed results survive a later crash.
        with open('experiments/inceptionscore_torch.txt', 'a') as file:
            file.write(message + '\n')