Diff of /torchfile.py [000000] .. [968c76]

Switch to side-by-side view

--- a
+++ b/torchfile.py
@@ -0,0 +1,428 @@
+"""
+Copyright (c) 2016, Brendan Shillingford
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the 
+following conditions are met: 
+
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following 
+disclaimer. 
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the 
+following disclaimer in the documentation and/or other materials provided with the distribution. 
+
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote 
+products derived from this software without specific prior written permission. 
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, 
+INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 
+WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+----------------------------------------------------------------------------------------------------------------------
+The file was taken from https://github.com/bshillingford/python-torchfile and slightly modified
+----------------------------------------------------------------------------------------------------------------------
+
+Mostly direct port of the Lua and C serialization implementation to 
+Python, depending only on `struct`, `array`, and numpy.
+
+Supported types:
+ * `nil` to Python `None`
+ * numbers to Python floats, or by default a heuristic changes them to ints or
+   longs if they are integral
+ * booleans
+ * strings: read as byte strings (Python 3) or normal strings (Python 2), like
+   lua strings which don't support unicode, and that can contain null chars
+ * tables converted to a special dict (*); if they are list-like (i.e. have
+   numeric keys from 1 through n) they become a python list by default
+ * Torch classes: supports Tensors and Storages, and most classes such as 
+   modules. Trivially extensible much like the Torch serialization code.
+   Trivial torch classes like most `nn.Module` subclasses become 
+   `TorchObject`s. The `torch_readers` dict contains the mapping from class
+   names to reading functions.
+ * functions: loaded into the `LuaFunction` `namedtuple`,
+   which simply wraps the raw serialized data, i.e. upvalues and code.
+   These are mostly useless, but exist so you can deserialize anything.
+
+(*) Since Lua allows you to index a table with a table but Python does not, we 
+    replace dicts with a subclass that is hashable, and change its
+    equality comparison behaviour to compare by reference.
+    See `hashable_uniq_dict`.
+
+Currently, the implementation assumes the system-dependent binary Torch 
+format, but minor refactoring can give support for the ascii format as well.
+"""
+
+# Integer type tags. `T7Reader.read_obj` reads one of these as an int in front
+# of every serialized value and dispatches on it. Listed in numeric order
+# (0 through 8).
+(TYPE_NIL,
+ TYPE_NUMBER,
+ TYPE_STRING,
+ TYPE_TABLE,
+ TYPE_TORCH,
+ TYPE_BOOLEAN,
+ TYPE_FUNCTION,
+ LEGACY_TYPE_RECUR_FUNCTION,
+ TYPE_RECUR_FUNCTION) = range(9)
+
+import struct
+from array import array
+import numpy as np
+import sys
+from collections import namedtuple
+
+# Record built by `T7Reader.read_obj` for a serialized Lua function: the byte
+# count read from the file, the raw dumped bytes, and the deserialized upvalues.
+LuaFunction = namedtuple('LuaFunction', 'size dumped upvalues')
+
+
+class hashable_uniq_dict(dict):
+    """
+    Subclass of dict with equality and hashing semantics changed:
+    equality and hashing is purely by reference/instance, to match
+    the behaviour of lua tables.
+
+    Supports lua-style dot indexing: `d.name` returns `d['name']`, or, when
+    that key is absent, `d[b'name']` (strings are deserialized as bytes on
+    Python 3, so the bytes form is the one normally present). A missing key
+    yields `None`, mirroring lua's `nil`.
+
+    This way, dicts can be keys of other dicts.
+    """
+
+    def __hash__(self):
+        return id(self)
+
+    def __getattr__(self, key):
+        # Only reached when normal attribute lookup fails. Protocol probes
+        # (copy, pickle, numpy, ...) ask for dunder names and expect an
+        # AttributeError when unsupported, not a `None` "attribute".
+        if key.startswith('__') and key.endswith('__'):
+            raise AttributeError(key)
+        value = self.get(key)
+        # On Python 3 attribute names are text but table keys are bytes.
+        # (On Python 2 `str` is `bytes`, so no second lookup is needed.)
+        if value is None and isinstance(key, str) and not isinstance(key, bytes):
+            value = self.get(key.encode('utf8'))
+        return value
+
+    def __eq__(self, other):
+        return self is other
+
+    def __ne__(self, other):
+        # Python 2 does not derive `!=` from `__eq__`; keep both by identity.
+        return self is not other
+    # TODO: dict's __lt__ etc. still exist
+
+# Registry mapping a torch class name (bytes) to a callable
+# `f(reader, version)` that deserializes an instance of that class.
+torch_readers = {}
+
+
+def add_tensor_reader(typename, dtype):
+    """
+    Register in `torch_readers` a reader for the tensor class `typename`.
+
+    The registered reader returns a numpy array that is a strided view into
+    the tensor's storage array, or an empty 1-d array of `dtype` when the
+    tensor has no dimensions or no storage.
+
+    Raises (from the registered reader):
+    * `ValueError` if the serialized size/stride/offset address elements
+      outside the storage (corrupt or truncated file).
+    """
+    def read_tensor_generic(reader, version):
+        # source:
+        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
+        ndim = reader.read_int()
+
+        # read size:
+        size = reader.read_long_array(ndim)
+        # read stride (in elements):
+        stride = reader.read_long_array(ndim)
+        # storage offset, converted to a 0-based element offset:
+        storage_offset = reader.read_long() - 1
+        # read storage:
+        storage = reader.read_obj()
+
+        if storage is None or ndim == 0 or len(size) == 0 or len(stride) == 0:
+            # empty torch tensor
+            return np.empty((0), dtype=dtype)
+
+        # `as_strided` performs no bounds checking, so verify that the last
+        # addressed element lies inside the storage before building the view.
+        # A tensor with a zero-length dimension addresses nothing.
+        if all(n > 0 for n in size):
+            last = storage_offset + sum(
+                (n - 1) * s for n, s in zip(size, stride))
+            if storage_offset < 0 or min(stride) < 0 or last >= len(storage):
+                raise ValueError(
+                    'tensor of size %r, stride %r and offset %d does not fit '
+                    'in a storage of %d elements'
+                    % (list(size), list(stride), storage_offset, len(storage)))
+
+        # convert stride to numpy style (i.e. in bytes)
+        itemsize = storage.dtype.itemsize
+        byte_strides = [itemsize * x for x in stride]
+
+        # create numpy array that indexes into the storage:
+        return np.lib.stride_tricks.as_strided(
+            storage[storage_offset:],
+            shape=size,
+            strides=byte_strides)
+    torch_readers[typename] = read_tensor_generic
+add_tensor_reader(b'torch.ByteTensor', dtype=np.uint8)
+add_tensor_reader(b'torch.CharTensor', dtype=np.int8)
+add_tensor_reader(b'torch.ShortTensor', dtype=np.int16)
+add_tensor_reader(b'torch.IntTensor', dtype=np.int32)
+add_tensor_reader(b'torch.LongTensor', dtype=np.int64)
+add_tensor_reader(b'torch.FloatTensor', dtype=np.float32)
+add_tensor_reader(b'torch.DoubleTensor', dtype=np.float64)
+add_tensor_reader(b'torch.CudaTensor', dtype=np.float32)  # float
+add_tensor_reader(b'torch.CudaByteTensor', dtype=np.uint8)
+add_tensor_reader(b'torch.CudaCharTensor', dtype=np.int8)
+add_tensor_reader(b'torch.CudaShortTensor', dtype=np.int16)
+add_tensor_reader(b'torch.CudaIntTensor', dtype=np.int32)
+add_tensor_reader(b'torch.CudaDoubleTensor', dtype=np.float64)
+
+
+def add_storage_reader(typename, dtype):
+    """
+    Register in `torch_readers` a reader for the storage class `typename`.
+
+    The registered reader reads an element count with `reader.read_long()`
+    and then loads that many `dtype` items from `reader.f` via
+    `np.fromfile`, returning them as a 1-d numpy array.
+    """
+    def read_storage(reader, version):
+        # source:
+        # https://github.com/torch/torch7/blob/master/generic/Storage.c#L244
+        n_items = reader.read_long()
+        data = np.fromfile(reader.f, dtype=dtype, count=n_items)
+        return data
+    torch_readers[typename] = read_storage
+
+
+# CPU storages
+add_storage_reader(b'torch.ByteStorage', dtype=np.uint8)
+add_storage_reader(b'torch.CharStorage', dtype=np.int8)
+add_storage_reader(b'torch.ShortStorage', dtype=np.int16)
+add_storage_reader(b'torch.IntStorage', dtype=np.int32)
+add_storage_reader(b'torch.LongStorage', dtype=np.int64)
+add_storage_reader(b'torch.FloatStorage', dtype=np.float32)
+add_storage_reader(b'torch.DoubleStorage', dtype=np.float64)
+# Cuda storages (the unqualified `CudaStorage` is read as float32)
+add_storage_reader(b'torch.CudaStorage', dtype=np.float32)  # float
+add_storage_reader(b'torch.CudaByteStorage', dtype=np.uint8)
+add_storage_reader(b'torch.CudaCharStorage', dtype=np.int8)
+add_storage_reader(b'torch.CudaShortStorage', dtype=np.int16)
+add_storage_reader(b'torch.CudaIntStorage', dtype=np.int32)
+add_storage_reader(b'torch.CudaDoubleStorage', dtype=np.float64)
+
+
+def _torch_object_get(table, key):
+    """Return `table.get(key)`, retrying a text key as utf-8 bytes if absent."""
+    value = table.get(key)
+    # On Python 3 deserialized string keys are bytes while attribute names and
+    # most literals are text. (On Python 2 `str` is `bytes`: nothing to retry.)
+    if value is None and isinstance(key, str) and not isinstance(key, bytes):
+        value = table.get(key.encode('utf8'))
+    return value
+
+
+class TorchObject(object):
+    """
+    Simple torch object, used by `add_trivial_class_reader`.
+    Supports both forms of lua-style indexing, i.e. getattr and getitem;
+    a text key that is absent is retried as utf-8 bytes, and a missing key
+    yields `None`.
+    Use the `torch_typename` method to get the object's torch class name.
+
+    Equality is by reference, as usual for lua (and the default for Python
+    objects).
+    """
+
+    def __init__(self, typename, obj):
+        self._typename = typename
+        self._obj = obj
+
+    def __getattr__(self, k):
+        # Only reached when normal lookup fails. If `_obj`/`_typename` are
+        # missing the instance is not initialised yet (copy/pickle create it
+        # without calling __init__); looking them up through `self._obj`
+        # would recurse forever. Dunder probes must get AttributeError too.
+        if k in ('_obj', '_typename') or (k.startswith('__') and k.endswith('__')):
+            raise AttributeError(k)
+        return _torch_object_get(self._obj, k)
+
+    def __getitem__(self, k):
+        return _torch_object_get(self._obj, k)
+
+    def torch_typename(self):
+        return self._typename
+
+    def __repr__(self):
+        return "TorchObject(%s, %s)" % (self._typename, repr(self._obj))
+
+    def __str__(self):
+        return repr(self)
+
+    def __dir__(self):
+        # dir() sorts the returned names, so they must all be text: decode
+        # bytes keys and leave out keys that are not strings at all.
+        names = ['torch_typename']
+        keys = self._obj.keys() if hasattr(self._obj, 'keys') else ()
+        for key in keys:
+            if isinstance(key, bytes) and not isinstance(key, str):
+                key = key.decode('utf8', 'replace')
+            if isinstance(key, str):
+                names.append(key)
+        return names
+
+
+def add_trivial_class_reader(typename):
+    """
+    Register in `torch_readers` a reader for the class `typename` whose
+    payload is a single serialized object; the reader wraps that object,
+    unmodified, in a `TorchObject` tagged with `typename`.
+    """
+    def read_trivial_object(reader, version):
+        payload = reader.read_obj()
+        return TorchObject(typename, payload)
+    torch_readers[typename] = read_trivial_object
+# Register a trivial reader (see `add_trivial_class_reader`) for each of the
+# class names below. Names are bytes because `T7Reader.read_string` returns
+# the raw bytes read from the file, and those are what `torch_readers` is
+# looked up with.
+for mod in [b"nn.ConcatTable", b"nn.SpatialAveragePooling",
+            b"nn.TemporalConvolutionFB", b"nn.BCECriterion", b"nn.Reshape", b"nn.gModule",
+            b"nn.SparseLinear", b"nn.WeightedLookupTable", b"nn.CAddTable",
+            b"nn.TemporalConvolution", b"nn.PairwiseDistance", b"nn.WeightedMSECriterion",
+            b"nn.SmoothL1Criterion", b"nn.TemporalSubSampling", b"nn.TanhShrink",
+            b"nn.MixtureTable", b"nn.Mul", b"nn.LogSoftMax", b"nn.Min", b"nn.Exp", b"nn.Add",
+            b"nn.BatchNormalization", b"nn.AbsCriterion", b"nn.MultiCriterion",
+            b"nn.LookupTableGPU", b"nn.Max", b"nn.MulConstant", b"nn.NarrowTable", b"nn.View",
+            b"nn.ClassNLLCriterionWithUNK", b"nn.VolumetricConvolution",
+            b"nn.SpatialSubSampling", b"nn.HardTanh", b"nn.DistKLDivCriterion",
+            b"nn.SplitTable", b"nn.DotProduct", b"nn.HingeEmbeddingCriterion",
+            b"nn.SpatialBatchNormalization", b"nn.DepthConcat", b"nn.Sigmoid",
+            b"nn.SpatialAdaptiveMaxPooling", b"nn.Parallel", b"nn.SoftShrink",
+            b"nn.SpatialSubtractiveNormalization", b"nn.TrueNLLCriterion", b"nn.Log",
+            b"nn.SpatialDropout", b"nn.LeakyReLU", b"nn.VolumetricMaxPooling",
+            b"nn.KMaxPooling", b"nn.Linear", b"nn.Euclidean", b"nn.CriterionTable",
+            b"nn.SpatialMaxPooling", b"nn.TemporalKMaxPooling", b"nn.MultiMarginCriterion",
+            b"nn.ELU", b"nn.CSubTable", b"nn.MultiLabelMarginCriterion", b"nn.Copy",
+            b"nn.CuBLASWrapper", b"nn.L1HingeEmbeddingCriterion",
+            b"nn.VolumetricAveragePooling", b"nn.StochasticGradient",
+            b"nn.SpatialContrastiveNormalization", b"nn.CosineEmbeddingCriterion",
+            b"nn.CachingLookupTable", b"nn.FeatureLPPooling", b"nn.Padding", b"nn.Container",
+            b"nn.MarginRankingCriterion", b"nn.Module", b"nn.ParallelCriterion",
+            b"nn.DataParallelTable", b"nn.Concat", b"nn.CrossEntropyCriterion",
+            b"nn.LookupTable", b"nn.SpatialSoftMax", b"nn.HardShrink", b"nn.Abs", b"nn.SoftMin",
+            b"nn.WeightedEuclidean", b"nn.Replicate", b"nn.DataParallel",
+            b"nn.OneBitQuantization", b"nn.OneBitDataParallel", b"nn.AddConstant", b"nn.L1Cost",
+            b"nn.HSM", b"nn.PReLU", b"nn.JoinTable", b"nn.ClassNLLCriterion", b"nn.CMul",
+            b"nn.CosineDistance", b"nn.Index", b"nn.Mean", b"nn.FFTWrapper", b"nn.Dropout",
+            b"nn.SpatialConvolutionCuFFT", b"nn.SoftPlus", b"nn.AbstractParallel",
+            b"nn.SequentialCriterion", b"nn.LocallyConnected",
+            b"nn.SpatialDivisiveNormalization", b"nn.L1Penalty", b"nn.Threshold", b"nn.Power",
+            b"nn.Sqrt", b"nn.MM", b"nn.GroupKMaxPooling", b"nn.CrossMapNormalization",
+            b"nn.ReLU", b"nn.ClassHierarchicalNLLCriterion", b"nn.Optim", b"nn.SoftMax",
+            b"nn.SpatialConvolutionMM", b"nn.Cosine", b"nn.Clamp", b"nn.CMulTable",
+            b"nn.LogSigmoid", b"nn.LinearNB", b"nn.TemporalMaxPooling", b"nn.MSECriterion",
+            b"nn.Sum", b"nn.SoftSign", b"nn.Normalize", b"nn.ParallelTable", b"nn.FlattenTable",
+            b"nn.CDivTable", b"nn.Tanh", b"nn.ModuleFromCriterion", b"nn.Square", b"nn.Select",
+            b"nn.GradientReversal", b"nn.SpatialFullConvolutionMap", b"nn.SpatialConvolution",
+            b"nn.Criterion", b"nn.SpatialConvolutionMap", b"nn.SpatialLPPooling",
+            b"nn.Sequential", b"nn.Transpose", b"nn.SpatialUpSamplingNearest",
+            b"nn.SpatialFullConvolution", b"nn.ModelParallel", b"nn.RReLU",
+            b"nn.SpatialZeroPadding", b"nn.Identity", b"nn.Narrow", b"nn.MarginCriterion",
+            b"nn.SelectTable", b"nn.VolumetricFullConvolution",
+            b"nn.SpatialFractionalMaxPooling", b"fbnn.ProjectiveGradientNormalization",
+            b"fbnn.Probe", b"fbnn.SparseLinear", b"cudnn._Pooling3D",
+            b"cudnn.VolumetricMaxPooling", b"cudnn.SpatialCrossEntropyCriterion",
+            b"cudnn.VolumetricConvolution", b"cudnn.SpatialAveragePooling", b"cudnn.Tanh",
+            b"cudnn.LogSoftMax", b"cudnn.SpatialConvolution", b"cudnn._Pooling",
+            b"cudnn.SpatialMaxPooling", b"cudnn.ReLU", b"cudnn.SpatialCrossMapLRN",
+            b"cudnn.SoftMax", b"cudnn._Pointwise", b"cudnn.SpatialSoftMax", b"cudnn.Sigmoid",
+            b"cudnn.SpatialLogSoftMax", b"cudnn.VolumetricAveragePooling", b"nngraph.Node",
+            b"nngraph.JustTable", b"graph.Edge", b"graph.Node", b"graph.Graph"]:
+
+    add_trivial_class_reader(mod)
+
+
+class T7ReaderException(Exception):
+    """Raised by `T7Reader` for an unknown type tag or an unsupported class."""
+
+
+class T7Reader:
+    """
+    Reads serialized Torch objects from a binary file object.
+
+    Each call to `read_obj` consumes one value from the file. Tables, torch
+    objects and functions carry an index in the file; they are memoized in
+    `self.objects` under that index so a repeated reference returns the same
+    Python object instead of being parsed again.
+    """
+
+    def __init__(self,
+                 fileobj,
+                 use_list_heuristic=True,
+                 use_int_heuristic=True,
+                 force_deserialize_classes=True,
+                 force_8bytes_long=True):
+        """
+        Params:
+        * `fileobj` file object to read from, must be actual file object
+                    as it must support array, struct, and numpy
+        * `use_list_heuristic`: automatically turn tables with only consecutive
+                                positive integral indices into lists
+                                (default True)
+        * `use_int_heuristic`: cast all whole floats into ints (default True)
+        * `force_deserialize_classes`: deserialize all classes, not just the
+                                       whitelisted ones (default True)
+        * `force_8bytes_long`: read every `long` with struct format `q`
+                               (8 bytes) instead of the platform's native
+                               `l` (default True)
+        """
+        self.f = fileobj
+        self.objects = {}  # read objects so far
+
+        self.use_list_heuristic = use_list_heuristic
+        self.use_int_heuristic = use_int_heuristic
+        self.force_deserialize_classes = force_deserialize_classes
+        self.force_8bytes_long = force_8bytes_long
+
+    def _read(self, fmt):
+        """
+        Read `struct.calcsize(fmt)` bytes and unpack them with `fmt`.
+
+        Returns the unpacked tuple. When the file yields no bytes at all
+        (end of file), `(0,)` is returned instead of raising, so that the
+        typed readers produce 0 at EOF.
+        """
+        raw = self.f.read(struct.calcsize(fmt))
+        if raw == b'':
+            return (0,)
+        return struct.unpack(fmt, raw)
+
+    def read_boolean(self):
+        """Read an int and return True iff it equals 1."""
+        return self.read_int() == 1
+
+    def read_int(self):
+        """Read one native `int` (struct format `i`)."""
+        return self._read('i')[0]
+
+    def read_long(self):
+        """Read one `long`: format `q` if `force_8bytes_long`, else `l`."""
+        if self.force_8bytes_long:
+            return self._read('q')[0]
+        return self._read('l')[0]
+
+    def read_long_array(self, n):
+        """Read `n` consecutive longs and return them as a list."""
+        if self.force_8bytes_long:
+            return [self.read_long() for _ in range(n)]
+        arr = array('l')
+        arr.fromfile(self.f, n)
+        return arr.tolist()
+
+    def read_float(self):
+        """Read one native `float` (struct format `f`)."""
+        return self._read('f')[0]
+
+    def read_double(self):
+        """Read one native `double` (struct format `d`)."""
+        return self._read('d')[0]
+
+    def read_string(self):
+        """Read an int length followed by that many raw bytes."""
+        size = self.read_int()
+        return self.f.read(size)
+
+    def read_obj(self):
+        """
+        Read the next serialized value and return its Python counterpart.
+
+        Returns `None`, a number, a bool, a byte string, a
+        `hashable_uniq_dict` or list (tables), a `LuaFunction`, or whatever
+        the registered class reader returns (`TorchObject` as fallback).
+
+        Raises:
+        * `T7ReaderException` if the type tag is unknown, or if a torch class
+          has no reader and `force_deserialize_classes` is False.
+        """
+        typeidx = self.read_int()
+        if typeidx == TYPE_NIL:
+            return None
+        if typeidx == TYPE_NUMBER:
+            x = self.read_double()
+            # Extra checking for integral numbers. `float(x)` because `_read`
+            # yields the int 0 at EOF, and ints only gained `is_integer` in
+            # recent Python versions.
+            if self.use_int_heuristic and float(x).is_integer():
+                return int(x)
+            return x
+        if typeidx == TYPE_BOOLEAN:
+            return self.read_boolean()
+        if typeidx == TYPE_STRING:
+            return self.read_string()
+        if typeidx not in (TYPE_TABLE, TYPE_TORCH, TYPE_FUNCTION,
+                           TYPE_RECUR_FUNCTION, LEGACY_TYPE_RECUR_FUNCTION):
+            raise T7ReaderException("unknown object")
+
+        # All remaining kinds are referenced by index and memoized.
+        index = self.read_int()
+        if index in self.objects:
+            return self.objects[index]
+
+        if typeidx == TYPE_TABLE:
+            # registers itself in `self.objects` (before reading its contents)
+            return self._read_table(index)
+        if typeidx == TYPE_TORCH:
+            obj = self._read_torch_object()
+        else:
+            obj = self._read_function()
+        self.objects[index] = obj
+        return obj
+
+    def _read_function(self):
+        """Read a dumped function body and its upvalues into a `LuaFunction`."""
+        size = self.read_int()
+        dumped = self.f.read(size)
+        upvalues = self.read_obj()
+        return LuaFunction(size, dumped, upvalues)
+
+    def _read_torch_object(self):
+        """Read a torch class instance using its reader from `torch_readers`."""
+        version = self.read_string()
+        if version.startswith(b'V '):
+            version_number = int(version.partition(b' ')[2])
+            class_name = self.read_string()
+        else:
+            # created before existence of versioning: the string just read
+            # is the class name itself
+            class_name = version
+            version_number = 0
+
+        class_reader = torch_readers.get(class_name)
+        if class_reader is not None:
+            # Pass the parsed number; the raw string is the class name when
+            # the file has no version header.
+            return class_reader(self, version_number)
+        if not self.force_deserialize_classes:
+            raise T7ReaderException(
+                'unsupported torch class: <%s>' % class_name)
+        return TorchObject(class_name, self.read_obj())
+
+    def _read_table(self, index):
+        """
+        Read a table into a `hashable_uniq_dict`, or into a list when
+        `use_list_heuristic` is set and the keys are exactly 1..n.
+        The result is stored in `self.objects[index]`.
+        """
+        size = self.read_int()
+        obj = hashable_uniq_dict()  # custom hashable dict, can be a key
+        # Register before reading the contents, so that a reference back to
+        # this table from inside it resolves to this object instead of being
+        # parsed as a new one.
+        self.objects[index] = obj
+
+        key_sum = 0          # for checking if keys are consecutive
+        keys_natural = True  # and also natural numbers 1..n.
+        for _ in range(size):
+            k = self.read_obj()
+            v = self.read_obj()
+            obj[k] = v
+
+            if self.use_list_heuristic and keys_natural:
+                # bool is a subclass of int but is not a list index
+                if isinstance(k, int) and not isinstance(k, bool) and k > 0:
+                    key_sum += k
+                else:
+                    keys_natural = False
+
+        if self.use_list_heuristic and keys_natural:
+            # n(n+1)/2 = sum <=> consecutive and natural numbers
+            n = len(obj)
+            if n * (n + 1) == 2 * key_sum:
+                lst = []
+                for i in range(1, n + 1):
+                    elem = obj[i]
+                    # a table that contains itself becomes a list that
+                    # contains itself
+                    if elem is obj:
+                        elem = lst
+                    lst.append(elem)
+                self.objects[index] = obj = lst
+        return obj
+
+
+def load(filename, **kwargs):
+    """
+    Open `filename` in binary mode and return the first object serialized in
+    it. Any keyword arguments are passed through to the `T7Reader`
+    constructor; without them the reader's defaults apply.
+    """
+    with open(filename, 'rb') as t7_file:
+        return T7Reader(t7_file, **kwargs).read_obj()