Diff of /utils/array_tool.py [000000] .. [bd7f9c]

Switch to unified view

a b/utils/array_tool.py
1
import torch
2
from torch.autograd import Variable
3
import numpy as np
4
"""
5
tools to convert specified type
6
"""
7
def tonumpy(data):
8
    if data is None:
9
        return None
10
    if isinstance(data, np.ndarray):
11
        return data
12
    if isinstance(data, torch._TensorBase):
13
        return data.cpu().numpy()
14
    if isinstance(data, torch.autograd.Variable):
15
        return tonumpy(data.data)
16
    if isinstance(data, np.int32):
17
        return np.array(data)
18
    if isinstance(data, list):
19
        return np.array(data)
20
21
22
def totensor(data, cuda=True):
23
    if isinstance(data, np.ndarray):
24
        tensor = torch.from_numpy(data)
25
    if isinstance(data, torch._TensorBase):
26
        tensor = data
27
    if isinstance(data, torch.autograd.Variable):
28
        tensor = data.data
29
    if cuda:
30
        tensor = tensor.cuda()
31
    return tensor
32
33
34
def tovariable(data):
35
    if isinstance(data, np.ndarray):
36
        return tovariable(totensor(data))
37
    if isinstance(data, torch._TensorBase):
38
        return torch.autograd.Variable(data)
39
    if isinstance(data, torch.autograd.Variable):
40
        return data
41
    else:
42
        raise ValueError("UnKnow data type: %s, input should be {np.ndarray,Tensor,Variable}" %type(data))
43
44
45
def scalar(data):
46
    if isinstance(data, np.ndarray):
47
        return data.reshape(1)[0]
48
    if isinstance(data, torch._TensorBase):
49
        return data.view(1)[0]
50
    if isinstance(data, torch.autograd.Variable):
51
        return data.data.view(1)[0]
52
53
54
55
# Test
56
if __name__ == '__main__':
57
    x = torch.randn(3, 3)
58
    y = torch.randn(9)
59
    z = Variable(x)
60
    print(type(x))
61
    print(x.type())
62
    print(z.type())
63
64
    if isinstance(z, torch.Tensor):
65
        print('yes')