|
a |
|
b/utils/array_tool.py |
|
|
1 |
import torch |
|
|
2 |
from torch.autograd import Variable |
|
|
3 |
import numpy as np |
|
|
4 |
""" |
|
|
5 |
tools to convert specified type |
|
|
6 |
""" |
|
|
7 |
def tonumpy(data): |
|
|
8 |
if data is None: |
|
|
9 |
return None |
|
|
10 |
if isinstance(data, np.ndarray): |
|
|
11 |
return data |
|
|
12 |
if isinstance(data, torch._TensorBase): |
|
|
13 |
return data.cpu().numpy() |
|
|
14 |
if isinstance(data, torch.autograd.Variable): |
|
|
15 |
return tonumpy(data.data) |
|
|
16 |
if isinstance(data, np.int32): |
|
|
17 |
return np.array(data) |
|
|
18 |
if isinstance(data, list): |
|
|
19 |
return np.array(data) |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def totensor(data, cuda=True): |
|
|
23 |
if isinstance(data, np.ndarray): |
|
|
24 |
tensor = torch.from_numpy(data) |
|
|
25 |
if isinstance(data, torch._TensorBase): |
|
|
26 |
tensor = data |
|
|
27 |
if isinstance(data, torch.autograd.Variable): |
|
|
28 |
tensor = data.data |
|
|
29 |
if cuda: |
|
|
30 |
tensor = tensor.cuda() |
|
|
31 |
return tensor |
|
|
32 |
|
|
|
33 |
|
|
|
34 |
def tovariable(data): |
|
|
35 |
if isinstance(data, np.ndarray): |
|
|
36 |
return tovariable(totensor(data)) |
|
|
37 |
if isinstance(data, torch._TensorBase): |
|
|
38 |
return torch.autograd.Variable(data) |
|
|
39 |
if isinstance(data, torch.autograd.Variable): |
|
|
40 |
return data |
|
|
41 |
else: |
|
|
42 |
raise ValueError("UnKnow data type: %s, input should be {np.ndarray,Tensor,Variable}" %type(data)) |
|
|
43 |
|
|
|
44 |
|
|
|
45 |
def scalar(data): |
|
|
46 |
if isinstance(data, np.ndarray): |
|
|
47 |
return data.reshape(1)[0] |
|
|
48 |
if isinstance(data, torch._TensorBase): |
|
|
49 |
return data.view(1)[0] |
|
|
50 |
if isinstance(data, torch.autograd.Variable): |
|
|
51 |
return data.data.view(1)[0] |
|
|
52 |
|
|
|
53 |
|
|
|
54 |
|
|
|
55 |
# Test |
|
|
56 |
if __name__ == '__main__': |
|
|
57 |
x = torch.randn(3, 3) |
|
|
58 |
y = torch.randn(9) |
|
|
59 |
z = Variable(x) |
|
|
60 |
print(type(x)) |
|
|
61 |
print(x.type()) |
|
|
62 |
print(z.type()) |
|
|
63 |
|
|
|
64 |
if isinstance(z, torch.Tensor): |
|
|
65 |
print('yes') |