|
a |
|
b/torchfile.py |
|
|
1 |
""" |
|
|
2 |
Copyright (c) 2016, Brendan Shillingford |
|
|
3 |
All rights reserved. |
|
|
4 |
|
|
|
5 |
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the |
|
|
6 |
following conditions are met: |
|
|
7 |
|
|
|
8 |
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following |
|
|
9 |
disclaimer. |
|
|
10 |
|
|
|
11 |
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the |
|
|
12 |
following disclaimer in the documentation and/or other materials provided with the distribution. |
|
|
13 |
|
|
|
14 |
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote |
|
|
15 |
products derived from this software without specific prior written permission. |
|
|
16 |
|
|
|
17 |
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, |
|
|
18 |
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
|
|
19 |
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
|
|
20 |
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
|
|
21 |
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, |
|
|
22 |
WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
|
|
23 |
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
|
24 |
---------------------------------------------------------------------------------------------------------------------- |
|
|
25 |
The file was taken from https://github.com/bshillingford/python-torchfile and slightly modified |
|
|
26 |
---------------------------------------------------------------------------------------------------------------------- |
|
|
27 |
|
|
|
28 |
Mostly direct port of the Lua and C serialization implementation to |
|
|
29 |
Python, depending only on `struct`, `array`, and numpy. |
|
|
30 |
|
|
|
31 |
Supported types: |
|
|
32 |
* `nil` to Python `None` |
|
|
33 |
* numbers to Python floats, or by default a heuristic changes them to ints or |
|
|
34 |
longs if they are integral |
|
|
35 |
* booleans |
|
|
36 |
* strings: read as byte strings (Python 3) or normal strings (Python 2), like |
|
|
37 |
lua strings which don't support unicode, and that can contain null chars |
|
|
38 |
* tables converted to a special dict (*); if they are list-like (i.e. have |
|
|
39 |
numeric keys from 1 through n) they become a python list by default |
|
|
40 |
* Torch classes: supports Tensors and Storages, and most classes such as |
|
|
41 |
modules. Trivially extensible much like the Torch serialization code. |
|
|
42 |
Trivial torch classes like most `nn.Module` subclasses become |
|
|
43 |
`TorchObject`s. The `torch_readers` dict contains the mapping from class |
|
|
44 |
names to reading functions. |
|
|
45 |
* functions: loaded into the `LuaFunction` `namedtuple`, |
|
|
46 |
which simply wraps the raw serialized data, i.e. upvalues and code. |
|
|
47 |
These are mostly useless, but exist so you can deserialize anything. |
|
|
48 |
|
|
|
49 |
(*) Since Lua allows you to index a table with a table but Python does not, we |
|
|
50 |
replace dicts with a subclass that is hashable, and change its |
|
|
51 |
equality comparison behaviour to compare by reference. |
|
|
52 |
See `hashable_uniq_dict`. |
|
|
53 |
|
|
|
54 |
Currently, the implementation assumes the system-dependent binary Torch |
|
|
55 |
format, but minor refactoring can give support for the ascii format as well. |
|
|
56 |
""" |
|
|
57 |
|
|
|
58 |
# Type tags: `T7Reader.read_obj` reads one of these ints before every object.
(TYPE_NIL,
 TYPE_NUMBER,
 TYPE_STRING,
 TYPE_TABLE,
 TYPE_TORCH,
 TYPE_BOOLEAN,
 TYPE_FUNCTION) = range(7)
# Both function tags are accepted by the reader and decoded identically.
LEGACY_TYPE_RECUR_FUNCTION = 7
TYPE_RECUR_FUNCTION = 8
|
|
67 |
|
|
|
68 |
import struct |
|
|
69 |
from array import array |
|
|
70 |
import numpy as np |
|
|
71 |
import sys |
|
|
72 |
from collections import namedtuple |
|
|
73 |
|
|
|
74 |
# Undecoded Lua function as stored in the stream: `size` is the byte length
# of `dumped` (the raw dumped chunk) and `upvalues` is the deserialized
# upvalue object that follows it.
LuaFunction = namedtuple('LuaFunction', 'size dumped upvalues')
|
|
76 |
|
|
|
77 |
|
|
|
78 |
class hashable_uniq_dict(dict):
    """
    Subclass of dict with equality and hashing semantics changed:
    equality and hashing is purely by reference/instance, to match
    the behaviour of lua tables.

    Supports lua-style dot indexing: ``d.name`` returns ``d['name']`` or,
    because lua strings are read as byte strings on Python 3,
    ``d[b'name']``; it returns None when neither key is present.

    This way, dicts can be keys of other dicts.
    """

    def __hash__(self):
        return id(self)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails.
        if key in self:
            return self[key]
        # Dunder names are protocol probes (copy, pickle, hasattr, ...),
        # never lua fields; answering None for them makes callers try to
        # call/use None instead of falling back to their defaults.
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        # Table keys read from a t7 stream are byte strings on Python 3, so
        # retry with the UTF-8 encoded name (on Python 2 str is bytes).
        if not isinstance(key, bytes):
            return self.get(key.encode('utf8'))
        return None

    def __eq__(self, other):
        return id(self) == id(other)

    def __ne__(self, other):
        # dict defines its own content-comparing __ne__, which would
        # otherwise shadow the identity semantics of __eq__ above.
        return id(self) != id(other)

    # TODO: dict's __lt__ etc. still exist (Python 2)
|
|
98 |
|
|
|
99 |
# Maps torch class names (byte strings, as read from the stream) to reader
# functions called as ``reader_fn(reader, version)``.
torch_readers = {}


def add_tensor_reader(typename, dtype):
    """
    Register a reader for the tensor class `typename` with numpy `dtype`.

    The reader returns a strided numpy view into the tensor's storage (the
    object returned by ``reader.read_obj()``), so tensors that share a
    storage object share memory. Raises T7ReaderException if the serialized
    size/stride/offset would address elements outside that storage.
    """
    def read_tensor_generic(reader, version):
        # source:
        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
        ndim = reader.read_int()

        # read size:
        size = reader.read_long_array(ndim)
        # read stride (in elements):
        stride = reader.read_long_array(ndim)
        # storage offset (serialized 1-based, converted to 0-based):
        storage_offset = reader.read_long() - 1
        # read storage:
        storage = reader.read_obj()

        if storage is None or ndim == 0 or len(size) == 0 or len(stride) == 0:
            # empty torch tensor
            return np.empty((0), dtype=dtype)

        # as_strided performs no bounds checking, so a malformed header would
        # read arbitrary memory; validate the addressed element range first.
        # (Tensors with a zero-sized dimension address no element at all.)
        if all(s > 0 for s in size):
            lowest = highest = storage_offset
            for s, st in zip(size, stride):
                if st >= 0:
                    highest += (s - 1) * st
                else:
                    lowest += (s - 1) * st
            if lowest < 0 or highest >= len(storage):
                raise T7ReaderException(
                    '%s addresses elements outside its storage' % typename)

        # convert stride to numpy style (i.e. in bytes)
        stride = [storage.dtype.itemsize * x for x in stride]

        # create numpy array that indexes into the storage:
        return np.lib.stride_tricks.as_strided(
            storage[storage_offset:],
            shape=size,
            strides=stride)
    torch_readers[typename] = read_tensor_generic


add_tensor_reader(b'torch.ByteTensor', dtype=np.uint8)
add_tensor_reader(b'torch.CharTensor', dtype=np.int8)
add_tensor_reader(b'torch.ShortTensor', dtype=np.int16)
add_tensor_reader(b'torch.IntTensor', dtype=np.int32)
add_tensor_reader(b'torch.LongTensor', dtype=np.int64)
add_tensor_reader(b'torch.FloatTensor', dtype=np.float32)
add_tensor_reader(b'torch.DoubleTensor', dtype=np.float64)
add_tensor_reader(b'torch.CudaTensor', dtype=np.float32)  # float
add_tensor_reader(b'torch.CudaByteTensor', dtype=np.uint8)
add_tensor_reader(b'torch.CudaCharTensor', dtype=np.int8)
add_tensor_reader(b'torch.CudaShortTensor', dtype=np.int16)
add_tensor_reader(b'torch.CudaIntTensor', dtype=np.int32)
add_tensor_reader(b'torch.CudaDoubleTensor', dtype=np.float64)


def add_storage_reader(typename, dtype):
    """
    Register a reader for the storage class `typename` with numpy `dtype`.

    The reader reads an element count followed by the raw element bytes and
    returns them as a writable 1-D numpy array. Bytes are read with the file
    object's ``readinto``, so any binary file-like object works (not only a
    real on-disk file). A negative count or a truncated payload raises
    T7ReaderException instead of yielding a silently shortened array.
    """
    def read_storage(reader, version):
        # source:
        # https://github.com/torch/torch7/blob/master/generic/Storage.c#L244
        size = reader.read_long()
        if size < 0:
            raise T7ReaderException(
                'invalid size %d for %s' % (size, typename))
        storage = np.empty(size, dtype=dtype)
        # Read straight into the array's memory: no zero-fill, no extra copy.
        raw = memoryview(storage.view(np.uint8))
        nbytes = len(raw)
        filled = 0
        while filled < nbytes:
            n = reader.f.readinto(raw[filled:])
            if not n:
                break
            filled += n
        if filled != nbytes:
            raise T7ReaderException(
                'truncated %s: expected %d bytes, got %d'
                % (typename, nbytes, filled))
        return storage
    torch_readers[typename] = read_storage


add_storage_reader(b'torch.ByteStorage', dtype=np.uint8)
add_storage_reader(b'torch.CharStorage', dtype=np.int8)
add_storage_reader(b'torch.ShortStorage', dtype=np.int16)
add_storage_reader(b'torch.IntStorage', dtype=np.int32)
add_storage_reader(b'torch.LongStorage', dtype=np.int64)
add_storage_reader(b'torch.FloatStorage', dtype=np.float32)
add_storage_reader(b'torch.DoubleStorage', dtype=np.float64)
add_storage_reader(b'torch.CudaStorage', dtype=np.float32)  # float
add_storage_reader(b'torch.CudaByteStorage', dtype=np.uint8)
add_storage_reader(b'torch.CudaCharStorage', dtype=np.int8)
add_storage_reader(b'torch.CudaShortStorage', dtype=np.int16)
add_storage_reader(b'torch.CudaIntStorage', dtype=np.int32)
add_storage_reader(b'torch.CudaDoubleStorage', dtype=np.float64)
|
|
165 |
|
|
|
166 |
|
|
|
167 |
class TorchObject(object):
    """
    Simple torch object, used by `add_trivial_class_reader`.
    Supports both forms of lua-style indexing, i.e. getattr and getitem.
    Use the `torch_typename` method to get the object's torch class name.

    Field names may be given as str even though lua strings are read as
    byte strings on Python 3: a str key that is not found is retried
    UTF-8 encoded.

    Equality is by reference, as usual for lua (and the default for Python
    objects).
    """

    def __init__(self, typename, obj):
        self._typename = typename
        self._obj = obj

    def __getattr__(self, k):
        # Only reached when normal lookup fails. Internal fields and dunder
        # names are never lua fields: forwarding '_obj' would recurse forever
        # on instances created without __init__ (copy/pickle do this), and
        # answering None for probes such as '__deepcopy__' breaks callers.
        if k in ('_obj', '_typename') or (k.startswith('__') and k.endswith('__')):
            raise AttributeError(k)
        return self[k]

    def __getitem__(self, k):
        # As before, a non-dict payload raises AttributeError on .get().
        obj = self._obj
        value = obj.get(k)
        if value is None and isinstance(k, str) and not isinstance(k, bytes):
            value = obj.get(k.encode('utf8'))
        return value

    def torch_typename(self):
        return self._typename

    def __repr__(self):
        return "TorchObject(%s, %s)" % (self._typename, repr(self._obj))

    def __str__(self):
        return repr(self)

    def __dir__(self):
        # dir() sorts the result, so only str names may be returned: byte
        # keys are decoded and non-string keys (e.g. numbers) are skipped.
        keys = []
        if isinstance(self._obj, dict):
            for key in self._obj:
                if isinstance(key, bytes) and not isinstance(key, str):
                    key = key.decode('utf8', 'replace')
                if isinstance(key, str):
                    keys.append(key)
        keys.append('torch_typename')
        return keys
|
|
200 |
|
|
|
201 |
|
|
|
202 |
def add_trivial_class_reader(typename):
    """
    Register `typename` as a class whose payload is read with a single
    ``read_obj()`` call and wrapped, unchanged, in a `TorchObject`.
    """
    def read_trivial(reader, version):
        return TorchObject(typename, reader.read_obj())

    torch_readers[typename] = read_trivial


# Classes whose serialized form needs no special decoding.
for mod in (
        b"nn.ConcatTable", b"nn.SpatialAveragePooling",
        b"nn.TemporalConvolutionFB", b"nn.BCECriterion", b"nn.Reshape", b"nn.gModule",
        b"nn.SparseLinear", b"nn.WeightedLookupTable", b"nn.CAddTable",
        b"nn.TemporalConvolution", b"nn.PairwiseDistance", b"nn.WeightedMSECriterion",
        b"nn.SmoothL1Criterion", b"nn.TemporalSubSampling", b"nn.TanhShrink",
        b"nn.MixtureTable", b"nn.Mul", b"nn.LogSoftMax", b"nn.Min", b"nn.Exp", b"nn.Add",
        b"nn.BatchNormalization", b"nn.AbsCriterion", b"nn.MultiCriterion",
        b"nn.LookupTableGPU", b"nn.Max", b"nn.MulConstant", b"nn.NarrowTable", b"nn.View",
        b"nn.ClassNLLCriterionWithUNK", b"nn.VolumetricConvolution",
        b"nn.SpatialSubSampling", b"nn.HardTanh", b"nn.DistKLDivCriterion",
        b"nn.SplitTable", b"nn.DotProduct", b"nn.HingeEmbeddingCriterion",
        b"nn.SpatialBatchNormalization", b"nn.DepthConcat", b"nn.Sigmoid",
        b"nn.SpatialAdaptiveMaxPooling", b"nn.Parallel", b"nn.SoftShrink",
        b"nn.SpatialSubtractiveNormalization", b"nn.TrueNLLCriterion", b"nn.Log",
        b"nn.SpatialDropout", b"nn.LeakyReLU", b"nn.VolumetricMaxPooling",
        b"nn.KMaxPooling", b"nn.Linear", b"nn.Euclidean", b"nn.CriterionTable",
        b"nn.SpatialMaxPooling", b"nn.TemporalKMaxPooling", b"nn.MultiMarginCriterion",
        b"nn.ELU", b"nn.CSubTable", b"nn.MultiLabelMarginCriterion", b"nn.Copy",
        b"nn.CuBLASWrapper", b"nn.L1HingeEmbeddingCriterion",
        b"nn.VolumetricAveragePooling", b"nn.StochasticGradient",
        b"nn.SpatialContrastiveNormalization", b"nn.CosineEmbeddingCriterion",
        b"nn.CachingLookupTable", b"nn.FeatureLPPooling", b"nn.Padding", b"nn.Container",
        b"nn.MarginRankingCriterion", b"nn.Module", b"nn.ParallelCriterion",
        b"nn.DataParallelTable", b"nn.Concat", b"nn.CrossEntropyCriterion",
        b"nn.LookupTable", b"nn.SpatialSoftMax", b"nn.HardShrink", b"nn.Abs", b"nn.SoftMin",
        b"nn.WeightedEuclidean", b"nn.Replicate", b"nn.DataParallel",
        b"nn.OneBitQuantization", b"nn.OneBitDataParallel", b"nn.AddConstant", b"nn.L1Cost",
        b"nn.HSM", b"nn.PReLU", b"nn.JoinTable", b"nn.ClassNLLCriterion", b"nn.CMul",
        b"nn.CosineDistance", b"nn.Index", b"nn.Mean", b"nn.FFTWrapper", b"nn.Dropout",
        b"nn.SpatialConvolutionCuFFT", b"nn.SoftPlus", b"nn.AbstractParallel",
        b"nn.SequentialCriterion", b"nn.LocallyConnected",
        b"nn.SpatialDivisiveNormalization", b"nn.L1Penalty", b"nn.Threshold", b"nn.Power",
        b"nn.Sqrt", b"nn.MM", b"nn.GroupKMaxPooling", b"nn.CrossMapNormalization",
        b"nn.ReLU", b"nn.ClassHierarchicalNLLCriterion", b"nn.Optim", b"nn.SoftMax",
        b"nn.SpatialConvolutionMM", b"nn.Cosine", b"nn.Clamp", b"nn.CMulTable",
        b"nn.LogSigmoid", b"nn.LinearNB", b"nn.TemporalMaxPooling", b"nn.MSECriterion",
        b"nn.Sum", b"nn.SoftSign", b"nn.Normalize", b"nn.ParallelTable", b"nn.FlattenTable",
        b"nn.CDivTable", b"nn.Tanh", b"nn.ModuleFromCriterion", b"nn.Square", b"nn.Select",
        b"nn.GradientReversal", b"nn.SpatialFullConvolutionMap", b"nn.SpatialConvolution",
        b"nn.Criterion", b"nn.SpatialConvolutionMap", b"nn.SpatialLPPooling",
        b"nn.Sequential", b"nn.Transpose", b"nn.SpatialUpSamplingNearest",
        b"nn.SpatialFullConvolution", b"nn.ModelParallel", b"nn.RReLU",
        b"nn.SpatialZeroPadding", b"nn.Identity", b"nn.Narrow", b"nn.MarginCriterion",
        b"nn.SelectTable", b"nn.VolumetricFullConvolution",
        b"nn.SpatialFractionalMaxPooling", b"fbnn.ProjectiveGradientNormalization",
        b"fbnn.Probe", b"fbnn.SparseLinear", b"cudnn._Pooling3D",
        b"cudnn.VolumetricMaxPooling", b"cudnn.SpatialCrossEntropyCriterion",
        b"cudnn.VolumetricConvolution", b"cudnn.SpatialAveragePooling", b"cudnn.Tanh",
        b"cudnn.LogSoftMax", b"cudnn.SpatialConvolution", b"cudnn._Pooling",
        b"cudnn.SpatialMaxPooling", b"cudnn.ReLU", b"cudnn.SpatialCrossMapLRN",
        b"cudnn.SoftMax", b"cudnn._Pointwise", b"cudnn.SpatialSoftMax", b"cudnn.Sigmoid",
        b"cudnn.SpatialLogSoftMax", b"cudnn.VolumetricAveragePooling", b"nngraph.Node",
        b"nngraph.JustTable", b"graph.Edge", b"graph.Node", b"graph.Graph",
):
    add_trivial_class_reader(mod)
|
|
262 |
|
|
|
263 |
|
|
|
264 |
class T7ReaderException(Exception):
    """
    Raised by `T7Reader` for input it cannot deserialize, e.g. an unknown
    type tag, or an unregistered torch class when `force_deserialize_classes`
    is disabled.
    """
|
|
266 |
|
|
|
267 |
|
|
|
268 |
class T7Reader:
    """Reader for the binary Torch7 serialization format."""

    def __init__(self,
                 fileobj,
                 use_list_heuristic=True,
                 use_int_heuristic=True,
                 force_deserialize_classes=True,
                 force_8bytes_long=True):
        """
        Params:
        * `fileobj` binary file object to read from; some readers (e.g.
                    `array.fromfile` when `force_8bytes_long` is False, or
                    numpy-based storage readers) may need an actual file
                    object
        * `use_list_heuristic`: automatically turn tables with only consecutive
                                positive integral indices into lists
                                (default True)
        * `use_int_heuristic`: cast all whole floats into ints (default True)
        * `force_deserialize_classes`: deserialize all classes, not just the
                                       whitelisted ones (default True)
        * `force_8bytes_long`: read longs as 8-byte integers ('q') instead of
                               the platform's native C long ('l')
                               (default True)
        """
        self.f = fileobj
        self.objects = {}  # stream index -> object read so far

        self.use_list_heuristic = use_list_heuristic
        self.use_int_heuristic = use_int_heuristic
        self.force_deserialize_classes = force_deserialize_classes
        self.force_8bytes_long = force_8bytes_long

    def _read(self, fmt):
        """Read and unpack one struct `fmt` from the stream.

        At end of stream this returns ``(0,)`` instead of raising, so a read
        past the end decodes as 0 (as a type tag, 0 is TYPE_NIL).
        """
        b = self.f.read(struct.calcsize(fmt))
        if b == b'':
            return (0,)
        return struct.unpack(fmt, b)

    def read_boolean(self):
        return self.read_int() == 1

    def read_int(self):
        return self._read('i')[0]

    def read_long(self):
        return self._read('q' if self.force_8bytes_long else 'l')[0]

    def read_long_array(self, n):
        if self.force_8bytes_long:
            return [self.read_long() for _ in range(n)]
        arr = array('l')
        arr.fromfile(self.f, n)
        return arr.tolist()

    def read_float(self):
        return self._read('f')[0]

    def read_double(self):
        return self._read('d')[0]

    def read_string(self):
        size = self.read_int()
        return self.f.read(size)

    def read_obj(self):
        """
        Read the next serialized object and return its Python equivalent.

        Raises T7ReaderException on an unknown type tag, or on an
        unregistered torch class when `force_deserialize_classes` is False.
        """
        typeidx = self.read_int()
        if typeidx == TYPE_NIL:
            return None
        elif typeidx == TYPE_NUMBER:
            x = self.read_double()
            # Extra checking for integral numbers:
            if self.use_int_heuristic and x.is_integer():
                return int(x)
            return x
        elif typeidx == TYPE_BOOLEAN:
            return self.read_boolean()
        elif typeidx == TYPE_STRING:
            return self.read_string()
        elif typeidx in (TYPE_TABLE, TYPE_TORCH, TYPE_FUNCTION,
                         TYPE_RECUR_FUNCTION, LEGACY_TYPE_RECUR_FUNCTION):
            # Reference types carry an index; an index seen before is
            # returned as-is without reading anything further.
            index = self.read_int()
            if index in self.objects:
                return self.objects[index]
            if typeidx == TYPE_TORCH:
                return self._read_torch_object(index)
            if typeidx == TYPE_TABLE:
                return self._read_table(index)
            return self._read_function(index)
        else:
            raise T7ReaderException("unknown object")

    def _read_function(self, index):
        """Read a function body as raw dumped bytes plus its upvalues."""
        size = self.read_int()
        dumped = self.f.read(size)
        upvalues = self.read_obj()
        obj = LuaFunction(size, dumped, upvalues)
        self.objects[index] = obj
        return obj

    def _read_torch_object(self, index):
        """Read a torch class instance: version header, class name, payload."""
        version = self.read_string()
        if version.startswith(b'V '):
            className = self.read_string()
        else:
            className = version  # created before existence of versioning
        if className in torch_readers:
            # Registered readers receive the raw version header.
            # NOTE(review): these objects are only registered after they are
            # fully read, so a cycle back to them is not resolved here.
            obj = torch_readers[className](self, version)
        elif not self.force_deserialize_classes:
            raise T7ReaderException(
                'unsupported torch class: <%s>' % className)
        else:
            # Register before reading the payload so that back-references
            # from inside it resolve to this object.
            obj = TorchObject(className, None)
            self.objects[index] = obj
            obj._obj = self.read_obj()
        self.objects[index] = obj
        return obj

    def _read_table(self, index):
        """Read a table into a hashable_uniq_dict, or a list (see heuristic)."""
        size = self.read_int()
        obj = hashable_uniq_dict()  # custom hashable dict, can be a key
        # Register before reading entries: a table may (indirectly) contain
        # itself, which the stream encodes as just its index. Such inner
        # references keep pointing at this dict even if it becomes a list.
        self.objects[index] = obj
        key_sum = 0          # for checking if keys are consecutive
        keys_natural = True  # and also natural numbers 1..n.
        for _ in range(size):
            k = self.read_obj()
            v = self.read_obj()
            obj[k] = v
            if self.use_list_heuristic:
                # bool subclasses int, but a `true` key is not an index.
                if isinstance(k, bool) or not isinstance(k, int) or k <= 0:
                    keys_natural = False
                else:
                    key_sum += k
        if self.use_list_heuristic:
            # n distinct positive ints sum to n(n+1)/2 iff they are 1..n.
            n = len(obj)
            if keys_natural and n * (n + 1) == 2 * key_sum:
                # Convert to a list with 0-based indices.
                obj = [obj[i + 1] for i in range(n)]
                self.objects[index] = obj
        return obj
|
|
419 |
|
|
|
420 |
|
|
|
421 |
def load(filename, **kwargs):
    """
    Deserialize and return the first object stored in the t7 file `filename`.

    The file is opened in binary mode and closed before returning; any
    keyword arguments are passed unchanged to `T7Reader`
    (e.g. ``use_list_heuristic=False``).
    """
    with open(filename, 'rb') as stream:
        t7 = T7Reader(stream, **kwargs)
        result = t7.read_obj()
    return result