Diff of /torchfile.py [000000] .. [968c76]

Switch to unified view

a b/torchfile.py
1
"""
2
Copyright (c) 2016, Brendan Shillingford
3
All rights reserved.
4
5
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the 
6
following conditions are met: 
7
8
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following 
9
disclaimer. 
10
11
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the 
12
following disclaimer in the documentation and/or other materials provided with the distribution. 
13
14
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote 
15
products derived from this software without specific prior written permission. 
16
17
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, 
18
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 
19
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 
20
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
21
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 
22
WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
23
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24
----------------------------------------------------------------------------------------------------------------------
25
The file was taken from https://github.com/bshillingford/python-torchfile and slightly modified
26
----------------------------------------------------------------------------------------------------------------------
27
28
Mostly direct port of the Lua and C serialization implementation to 
29
Python, depending only on `struct`, `array`, and numpy.
30
31
Supported types:
32
 * `nil` to Python `None`
33
 * numbers to Python floats, or by default a heuristic changes them to ints or
34
   longs if they are integral
35
 * booleans
36
 * strings: read as byte strings (Python 3) or normal strings (Python 2), like
37
   lua strings which don't support unicode, and that can contain null chars
38
 * tables converted to a special dict (*); if they are list-like (i.e. have
39
   numeric keys from 1 through n) they become a python list by default
40
 * Torch classes: supports Tensors and Storages, and most classes such as 
41
   modules. Trivially extensible much like the Torch serialization code.
42
   Trivial torch classes like most `nn.Module` subclasses become 
43
   `TorchObject`s. The `torch_readers` dict contains the mapping from class
44
   names to reading functions.
45
 * functions: loaded into the `LuaFunction` `namedtuple`,
46
   which simply wraps the raw serialized data, i.e. upvalues and code.
47
   These are mostly useless, but exist so you can deserialize anything.
48
49
(*) Since Lua allows you to index a table with a table but Python does not, we 
50
    replace dicts with a subclass that is hashable, and change its
51
    equality comparison behaviour to compare by reference.
52
    See `hashable_uniq_dict`.
53
54
Currently, the implementation assumes the system-dependent binary Torch 
55
format, but minor refactoring can give support for the ascii format as well.
56
"""
57
58
# Integer tags read at the start of every serialized object; `read_obj`
# dispatches on them.  Listed in numeric order.
TYPE_NIL = 0
TYPE_NUMBER = 1
TYPE_STRING = 2
TYPE_TABLE = 3
TYPE_TORCH = 4
TYPE_BOOLEAN = 5
TYPE_FUNCTION = 6
# Two tags are accepted for recursive functions: 7 is the legacy value.
LEGACY_TYPE_RECUR_FUNCTION = 7
TYPE_RECUR_FUNCTION = 8
67
68
import struct
69
from array import array
70
import numpy as np
71
import sys
72
from collections import namedtuple
73
74
# Plain record for a serialized Lua function: the byte length of the dump,
# the raw dumped bytes, and the deserialized upvalues object.
LuaFunction = namedtuple('LuaFunction', 'size dumped upvalues')
76
77
78
class hashable_uniq_dict(dict):
    """
    Subclass of dict whose equality and hashing are purely by identity,
    to match the behaviour of Lua tables.

    Two instances are equal only when they are the same object, and the
    hash is `id(self)`, so an instance can be used as a key of another dict
    even though it is mutable.

    Lua-style dot indexing is supported for reads: `d.name` returns
    `d.get('name')`.  Because string keys are byte strings on Python 3, a
    text attribute name that is not found is retried as its UTF-8 encoded
    bytes key.  A missing key yields `None` (like Lua `nil`), except for
    dunder names, which raise `AttributeError` so that `hasattr`, `copy`
    and `pickle` protocol probes behave normally.
    """

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # dict provides its own content-based __ne__; without this override
        # two distinct tables with equal contents would be neither `==`
        # nor `!=`.
        return self is not other

    def __getattr__(self, key):
        # Only reached when regular attribute lookup has already failed.
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        value = self.get(key)
        if value is None and not isinstance(key, bytes):
            # Python 3: attribute names are text, table keys are bytes.
            value = self.get(key.encode('utf-8'))
        return value

    # TODO: on Python 2, dict's ordering comparisons (__lt__ etc.) still
    # compare by content.
98
99
# Maps a torch class name (bytes) to a function `(reader, version) -> object`
# that deserializes an instance of that class.
torch_readers = {}


def add_tensor_reader(typename, dtype):
    """
    Register in `torch_readers` a reader for the tensor class `typename`.

    The registered function reads, in order: the number of dimensions, the
    size array, the stride array, the 1-based storage offset and the storage
    object.  It returns a numpy array that is a strided view into the
    storage (no copy).  An empty tensor (no storage or no dimensions) gives
    `np.empty((0), dtype=dtype)`.

    Raises:
        ValueError: if the sizes, strides and offset read from the file
            would address elements outside the storage.  Building the view
            anyway would expose memory that does not belong to the array.
    """
    def read_tensor_generic(reader, version):
        # source:
        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
        ndim = reader.read_int()
        size = reader.read_long_array(ndim)
        stride = reader.read_long_array(ndim)
        # the file stores a 1-based offset; numpy needs a 0-based one
        storage_offset = reader.read_long() - 1
        storage = reader.read_obj()

        if storage is None or ndim == 0 or len(size) == 0 or len(stride) == 0:
            # empty torch tensor
            return np.empty((0), dtype=dtype)

        # A tensor with a zero-sized dimension addresses no element, so only
        # check the reachable range when every dimension is non-empty.
        if all(extent > 0 for extent in size):
            lowest = highest = storage_offset
            for extent, step in zip(size, stride):
                reach = (extent - 1) * step
                if reach < 0:
                    lowest += reach
                else:
                    highest += reach
            if lowest < 0 or highest >= len(storage):
                raise ValueError(
                    'tensor %r addresses elements %d..%d outside its '
                    'storage of %d elements'
                    % (typename, lowest, highest, len(storage)))

        # numpy strides are in bytes, torch strides are in elements
        itemsize = storage.dtype.itemsize
        byte_strides = [itemsize * step for step in stride]

        # create numpy array that indexes into the storage:
        return np.lib.stride_tricks.as_strided(
            storage[storage_offset:],
            shape=size,
            strides=byte_strides)
    torch_readers[typename] = read_tensor_generic


add_tensor_reader(b'torch.ByteTensor', dtype=np.uint8)
add_tensor_reader(b'torch.CharTensor', dtype=np.int8)
add_tensor_reader(b'torch.ShortTensor', dtype=np.int16)
add_tensor_reader(b'torch.IntTensor', dtype=np.int32)
add_tensor_reader(b'torch.LongTensor', dtype=np.int64)
add_tensor_reader(b'torch.FloatTensor', dtype=np.float32)
add_tensor_reader(b'torch.DoubleTensor', dtype=np.float64)
add_tensor_reader(b'torch.CudaTensor', np.float32)  # float
add_tensor_reader(b'torch.CudaByteTensor', dtype=np.uint8)
add_tensor_reader(b'torch.CudaCharTensor', dtype=np.int8)
add_tensor_reader(b'torch.CudaShortTensor', dtype=np.int16)
add_tensor_reader(b'torch.CudaIntTensor', dtype=np.int32)
add_tensor_reader(b'torch.CudaDoubleTensor', dtype=np.float64)
143
144
145
def add_storage_reader(typename, dtype):
    """
    Register in `torch_readers` a reader for the storage class `typename`.

    The registered function reads the element count and then that many
    items of `dtype` directly from the reader's file object (`reader.f`)
    with `np.fromfile`, returning them as a 1-D numpy array.

    Raises:
        ValueError: if the stored element count is negative (numpy would
            interpret it as "read everything that is left") or if the file
            ends before the announced number of elements has been read.
    """
    def read_storage(reader, version):
        # source:
        # https://github.com/torch/torch7/blob/master/generic/Storage.c#L244
        size = reader.read_long()
        if size < 0:
            raise ValueError(
                'storage %r has a negative element count: %d'
                % (typename, size))
        data = np.fromfile(reader.f, dtype=dtype, count=size)
        if len(data) != size:
            # A short storage would let tensor views built on top of it
            # point past the end of the buffer.
            raise ValueError(
                'storage %r is truncated: expected %d elements, got %d'
                % (typename, size, len(data)))
        return data
    torch_readers[typename] = read_storage


add_storage_reader(b'torch.ByteStorage', dtype=np.uint8)
add_storage_reader(b'torch.CharStorage', dtype=np.int8)
add_storage_reader(b'torch.ShortStorage', dtype=np.int16)
add_storage_reader(b'torch.IntStorage', dtype=np.int32)
add_storage_reader(b'torch.LongStorage', dtype=np.int64)
add_storage_reader(b'torch.FloatStorage', dtype=np.float32)
add_storage_reader(b'torch.DoubleStorage', dtype=np.float64)
add_storage_reader(b'torch.CudaStorage', dtype=np.float32)  # float
add_storage_reader(b'torch.CudaByteStorage', dtype=np.uint8)
add_storage_reader(b'torch.CudaCharStorage', dtype=np.int8)
add_storage_reader(b'torch.CudaShortStorage', dtype=np.int16)
add_storage_reader(b'torch.CudaIntStorage', dtype=np.int32)
add_storage_reader(b'torch.CudaDoubleStorage', dtype=np.float64)
165
166
167
class TorchObject(object):
    """
    Simple torch object, used by `add_trivial_class_reader`.

    Wraps the deserialized field table of a torch class instance together
    with the class name.  Fields can be read with either form of Lua-style
    indexing, `obj.field` or `obj['field']`; a missing field gives `None`.
    Because string keys are byte strings on Python 3, a text key that is
    not found is retried as its UTF-8 encoded bytes key.

    Use the `torch_typename` method to get the object's torch class name.

    Equality is by reference, as usual for Lua (and the default for Python
    objects).
    """

    def __init__(self, typename, obj):
        self._typename = typename
        self._obj = obj

    def __getattr__(self, k):
        # Only reached when regular attribute lookup has already failed.
        # `_obj`/`_typename` are missing on an instance created without
        # __init__ (copy, pickle); looking them up via the table would
        # recurse forever.  Dunder probes must fail so protocol checks such
        # as hasattr(obj, '__setstate__') give the right answer.
        if k in ('_obj', '_typename') or (
                k.startswith('__') and k.endswith('__')):
            raise AttributeError(k)
        return self[k]

    def __getitem__(self, k):
        fields = self._obj
        value = fields.get(k)
        if (value is None and isinstance(k, str)
                and not isinstance(k, bytes)):
            # Python 3: text key given, table keys are bytes.
            value = fields.get(k.encode('utf-8'))
        return value

    def torch_typename(self):
        return self._typename

    def __repr__(self):
        return "TorchObject(%s, %s)" % (self._typename, repr(self._obj))

    def __str__(self):
        return repr(self)

    def __dir__(self):
        # dir() sorts the result, so every entry must be text: decode bytes
        # keys and leave out keys that are not strings at all.
        names = []
        for key in self._obj.keys():
            if isinstance(key, bytes) and not isinstance(key, str):
                key = key.decode('utf-8', 'replace')
            if isinstance(key, str):
                names.append(key)
        names.append('torch_typename')
        return names
200
201
202
def add_trivial_class_reader(typename):
    """
    Register in `torch_readers` a reader for the class `typename` that
    reads one object from the stream and wraps it in a `TorchObject`.
    """
    def read_trivial_class(reader, version):
        fields = reader.read_obj()
        return TorchObject(typename, fields)
    torch_readers[typename] = read_trivial_class


# Classes whose serialized form is a single object (their field table).
for mod in (
        b"nn.ConcatTable", b"nn.SpatialAveragePooling",
        b"nn.TemporalConvolutionFB", b"nn.BCECriterion", b"nn.Reshape", b"nn.gModule",
        b"nn.SparseLinear", b"nn.WeightedLookupTable", b"nn.CAddTable",
        b"nn.TemporalConvolution", b"nn.PairwiseDistance", b"nn.WeightedMSECriterion",
        b"nn.SmoothL1Criterion", b"nn.TemporalSubSampling", b"nn.TanhShrink",
        b"nn.MixtureTable", b"nn.Mul", b"nn.LogSoftMax", b"nn.Min", b"nn.Exp", b"nn.Add",
        b"nn.BatchNormalization", b"nn.AbsCriterion", b"nn.MultiCriterion",
        b"nn.LookupTableGPU", b"nn.Max", b"nn.MulConstant", b"nn.NarrowTable", b"nn.View",
        b"nn.ClassNLLCriterionWithUNK", b"nn.VolumetricConvolution",
        b"nn.SpatialSubSampling", b"nn.HardTanh", b"nn.DistKLDivCriterion",
        b"nn.SplitTable", b"nn.DotProduct", b"nn.HingeEmbeddingCriterion",
        b"nn.SpatialBatchNormalization", b"nn.DepthConcat", b"nn.Sigmoid",
        b"nn.SpatialAdaptiveMaxPooling", b"nn.Parallel", b"nn.SoftShrink",
        b"nn.SpatialSubtractiveNormalization", b"nn.TrueNLLCriterion", b"nn.Log",
        b"nn.SpatialDropout", b"nn.LeakyReLU", b"nn.VolumetricMaxPooling",
        b"nn.KMaxPooling", b"nn.Linear", b"nn.Euclidean", b"nn.CriterionTable",
        b"nn.SpatialMaxPooling", b"nn.TemporalKMaxPooling", b"nn.MultiMarginCriterion",
        b"nn.ELU", b"nn.CSubTable", b"nn.MultiLabelMarginCriterion", b"nn.Copy",
        b"nn.CuBLASWrapper", b"nn.L1HingeEmbeddingCriterion",
        b"nn.VolumetricAveragePooling", b"nn.StochasticGradient",
        b"nn.SpatialContrastiveNormalization", b"nn.CosineEmbeddingCriterion",
        b"nn.CachingLookupTable", b"nn.FeatureLPPooling", b"nn.Padding", b"nn.Container",
        b"nn.MarginRankingCriterion", b"nn.Module", b"nn.ParallelCriterion",
        b"nn.DataParallelTable", b"nn.Concat", b"nn.CrossEntropyCriterion",
        b"nn.LookupTable", b"nn.SpatialSoftMax", b"nn.HardShrink", b"nn.Abs", b"nn.SoftMin",
        b"nn.WeightedEuclidean", b"nn.Replicate", b"nn.DataParallel",
        b"nn.OneBitQuantization", b"nn.OneBitDataParallel", b"nn.AddConstant", b"nn.L1Cost",
        b"nn.HSM", b"nn.PReLU", b"nn.JoinTable", b"nn.ClassNLLCriterion", b"nn.CMul",
        b"nn.CosineDistance", b"nn.Index", b"nn.Mean", b"nn.FFTWrapper", b"nn.Dropout",
        b"nn.SpatialConvolutionCuFFT", b"nn.SoftPlus", b"nn.AbstractParallel",
        b"nn.SequentialCriterion", b"nn.LocallyConnected",
        b"nn.SpatialDivisiveNormalization", b"nn.L1Penalty", b"nn.Threshold", b"nn.Power",
        b"nn.Sqrt", b"nn.MM", b"nn.GroupKMaxPooling", b"nn.CrossMapNormalization",
        b"nn.ReLU", b"nn.ClassHierarchicalNLLCriterion", b"nn.Optim", b"nn.SoftMax",
        b"nn.SpatialConvolutionMM", b"nn.Cosine", b"nn.Clamp", b"nn.CMulTable",
        b"nn.LogSigmoid", b"nn.LinearNB", b"nn.TemporalMaxPooling", b"nn.MSECriterion",
        b"nn.Sum", b"nn.SoftSign", b"nn.Normalize", b"nn.ParallelTable", b"nn.FlattenTable",
        b"nn.CDivTable", b"nn.Tanh", b"nn.ModuleFromCriterion", b"nn.Square", b"nn.Select",
        b"nn.GradientReversal", b"nn.SpatialFullConvolutionMap", b"nn.SpatialConvolution",
        b"nn.Criterion", b"nn.SpatialConvolutionMap", b"nn.SpatialLPPooling",
        b"nn.Sequential", b"nn.Transpose", b"nn.SpatialUpSamplingNearest",
        b"nn.SpatialFullConvolution", b"nn.ModelParallel", b"nn.RReLU",
        b"nn.SpatialZeroPadding", b"nn.Identity", b"nn.Narrow", b"nn.MarginCriterion",
        b"nn.SelectTable", b"nn.VolumetricFullConvolution",
        b"nn.SpatialFractionalMaxPooling", b"fbnn.ProjectiveGradientNormalization",
        b"fbnn.Probe", b"fbnn.SparseLinear", b"cudnn._Pooling3D",
        b"cudnn.VolumetricMaxPooling", b"cudnn.SpatialCrossEntropyCriterion",
        b"cudnn.VolumetricConvolution", b"cudnn.SpatialAveragePooling", b"cudnn.Tanh",
        b"cudnn.LogSoftMax", b"cudnn.SpatialConvolution", b"cudnn._Pooling",
        b"cudnn.SpatialMaxPooling", b"cudnn.ReLU", b"cudnn.SpatialCrossMapLRN",
        b"cudnn.SoftMax", b"cudnn._Pointwise", b"cudnn.SpatialSoftMax", b"cudnn.Sigmoid",
        b"cudnn.SpatialLogSoftMax", b"cudnn.VolumetricAveragePooling", b"nngraph.Node",
        b"nngraph.JustTable", b"graph.Edge", b"graph.Node", b"graph.Graph",
):
    add_trivial_class_reader(mod)
262
263
264
class T7ReaderException(Exception):
    """Raised by `T7Reader` for an unknown type tag or an unsupported torch class."""
266
267
268
class T7Reader:
    """
    Reader for objects serialized in the binary Torch7 format.

    Objects that are written by reference (tables, torch objects and
    functions) carry an index; each one is remembered in `self.objects` so
    that later references to the same index return the same Python object.
    """

    def __init__(self,
                 fileobj,
                 use_list_heuristic=True,
                 use_int_heuristic=True,
                 force_deserialize_classes=True,
                 force_8bytes_long=True):
        """
        Params:
        * `fileobj` file object to read from, must be actual file object
                    as it must support array, struct, and numpy
        * `use_list_heuristic`: automatically turn tables with only consecutive
                                positive integral indices into lists
                                (default True)
        * `use_int_heuristic`: cast all whole floats into ints (default True)
        * `force_deserialize_classes`: deserialize all classes, not just the
                                       whitelisted ones (default True)
        * `force_8bytes_long`: read every `long` as 8 bytes (struct `q`)
                               instead of the platform's native C long
                               (struct `l`) (default True)
        """
        self.f = fileobj
        self.objects = {}  # read objects so far

        self.use_list_heuristic = use_list_heuristic
        self.use_int_heuristic = use_int_heuristic
        self.force_deserialize_classes = force_deserialize_classes
        self.force_8bytes_long = force_8bytes_long

    def _read(self, fmt):
        """
        Read and unpack one `struct` record of format `fmt`.

        At end of file a single zero is returned instead of raising, so a
        read past the end yields `0` (and therefore `nil` for a type tag).
        A partial record still raises `struct.error`.
        """
        raw = self.f.read(struct.calcsize(fmt))
        if raw == b'':
            return (0,)
        return struct.unpack(fmt, raw)

    def read_boolean(self):
        return self.read_int() == 1

    def read_int(self):
        return self._read('i')[0]

    def read_long(self):
        return self._read('q' if self.force_8bytes_long else 'l')[0]

    def read_long_array(self, n):
        """Read `n` longs and return them as a list."""
        if self.force_8bytes_long:
            return [self.read_long() for _ in range(n)]
        arr = array('l')
        arr.fromfile(self.f, n)
        return arr.tolist()

    def read_float(self):
        return self._read('f')[0]

    def read_double(self):
        return self._read('d')[0]

    def read_string(self):
        """
        Read a length-prefixed byte string.

        Raises `T7ReaderException` on a negative length, which would
        otherwise make `read` consume the rest of the file.
        """
        size = self.read_int()
        if size < 0:
            raise T7ReaderException('negative string length: %d' % size)
        return self.f.read(size)

    def read_obj(self):
        """
        Read the next object from the stream and return its Python form.

        Raises `T7ReaderException` for an unknown type tag, or for a torch
        class without a registered reader when `force_deserialize_classes`
        is false.
        """
        typeidx = self.read_int()
        if typeidx == TYPE_NIL:
            return None
        if typeidx == TYPE_NUMBER:
            x = self.read_double()
            # Extra checking for integral numbers:
            if self.use_int_heuristic and x.is_integer():
                return int(x)
            return x
        if typeidx == TYPE_BOOLEAN:
            return self.read_boolean()
        if typeidx == TYPE_STRING:
            return self.read_string()
        if typeidx not in (TYPE_TABLE, TYPE_TORCH, TYPE_FUNCTION,
                           TYPE_RECUR_FUNCTION, LEGACY_TYPE_RECUR_FUNCTION):
            raise T7ReaderException("unknown object")

        # Reference types start with their index; an index seen before
        # means the object body is not repeated in the stream.
        index = self.read_int()
        if index in self.objects:
            return self.objects[index]

        if typeidx == TYPE_TABLE:
            return self._read_table(index)
        if typeidx == TYPE_TORCH:
            return self._read_torch(index)
        return self._read_function(index)

    def _read_function(self, index):
        """Read a dumped Lua function and its upvalues into a `LuaFunction`."""
        size = self.read_int()
        if size < 0:
            raise T7ReaderException(
                'negative function dump length: %d' % size)
        dumped = self.f.read(size)
        # NOTE: the function is only registered after its upvalues are read
        # because `LuaFunction` is immutable; upvalues that refer back to
        # the function itself are therefore not resolved.
        upvalues = self.read_obj()
        obj = LuaFunction(size, dumped, upvalues)
        self.objects[index] = obj
        return obj

    def _read_torch(self, index):
        """Read a torch class instance via `torch_readers` or as `TorchObject`."""
        version = self.read_string()
        if version.startswith(b'V '):
            versionNumber = int(version.partition(b' ')[2])
            className = self.read_string()
        else:
            className = version
            versionNumber = 0  # created before existence of versioning

        if className in torch_readers:
            # Hand the parsed version number to the reader; the raw header
            # string is the class name itself for unversioned objects.
            obj = torch_readers[className](self, versionNumber)
            self.objects[index] = obj
            return obj

        if not self.force_deserialize_classes:
            raise T7ReaderException(
                'unsupported torch class: <%s>' % className)
        # Register the wrapper before reading its fields so that a field
        # referring back to this object resolves to it instead of being
        # parsed a second time.
        obj = TorchObject(className, None)
        self.objects[index] = obj
        obj._obj = self.read_obj()
        return obj

    def _read_table(self, index):
        """
        Read a Lua table: returns a `hashable_uniq_dict`, or a list when the
        list heuristic is on and the keys are exactly 1..n.
        """
        size = self.read_int()
        table = hashable_uniq_dict()  # custom hashable dict, can be a key
        # The table must be registered before its keys and values are read:
        # a self-referential table stores only its index for the inner
        # reference.
        self.objects[index] = table

        key_sum = 0          # for checking if keys are consecutive
        keys_natural = True  # and also natural numbers 1..n.
        for _ in range(size):
            k = self.read_obj()
            v = self.read_obj()
            table[k] = v
            if self.use_list_heuristic and keys_natural:
                # bool is a subclass of int but is not a list index
                if (isinstance(k, int) and not isinstance(k, bool)
                        and k > 0):
                    key_sum += k
                else:
                    keys_natural = False

        if self.use_list_heuristic and keys_natural:
            # n(n+1)/2 = sum <=> consecutive and natural numbers
            n = len(table)
            if n * (n + 1) == 2 * key_sum:
                lst = []
                for i in range(1, n + 1):
                    elem = table[i]
                    # an element that is the table itself must point at
                    # the list that replaces it
                    if elem is table:
                        elem = lst
                    lst.append(elem)
                self.objects[index] = lst
                return lst
        return table
419
420
421
def load(filename, **kwargs):
    """
    Open `filename` in binary mode and return the first object read from
    it; any keyword arguments are passed on to the `T7Reader` constructor.
    """
    with open(filename, 'rb') as stream:
        return T7Reader(stream, **kwargs).read_obj()