jz-char-rnn-tensorflow/utils.py

import os
import collections
from six.moves import cPickle
import numpy as np

class TextLoader():
    """Reads input.txt from data_dir, builds a character-level vocabulary,
    and serves (input, target) batches of shape (batch_size, seq_length)."""

    def __init__(self, data_dir, batch_size, seq_length):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seq_length = seq_length

        input_file = os.path.join(data_dir, "input.txt")
        vocab_file = os.path.join(data_dir, "vocab.pkl")
        tensor_file = os.path.join(data_dir, "data.npy")

        # Preprocess the raw text the first time; reuse the cached
        # vocabulary and tensor files on later runs.
        if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
            print("reading text file")
            self.preprocess(input_file, vocab_file, tensor_file)
        else:
            print("loading preprocessed files")
            self.load_preprocessed(vocab_file, tensor_file)
        self.create_batches()
        self.reset_batch_pointer()

    def preprocess(self, input_file, vocab_file, tensor_file):
        with open(input_file, "r") as f:
            data = f.read()
        # Vocabulary: all distinct characters, most frequent first.
        counter = collections.Counter(data)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        self.chars, _ = zip(*count_pairs)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
        # Encode the whole corpus as an array of character ids and cache it.
        self.tensor = np.array(list(map(self.vocab.get, data)))
        np.save(tensor_file, self.tensor)
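        # For example, on the (hypothetical) corpus "aab" this yields
        # self.chars == ('a', 'b'), self.vocab == {'a': 0, 'b': 1}, and
        # self.tensor == array([0, 0, 1]).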

    def load_preprocessed(self, vocab_file, tensor_file):
        with open(vocab_file, 'rb') as f:
            self.chars = cPickle.load(f)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.tensor = np.load(tensor_file)
        self.num_batches = self.tensor.size // (self.batch_size *
                                                self.seq_length)

    def create_batches(self):
        self.num_batches = self.tensor.size // (self.batch_size *
                                                self.seq_length)

        # When the data (tensor) is too small, give a clearer error message.
        if self.num_batches == 0:
            assert False, "Not enough data. Make seq_length and batch_size smaller."

        # Clip the tensor so it divides evenly into num_batches batches.
        self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
        xdata = self.tensor
        ydata = np.copy(self.tensor)
        # Targets are the inputs shifted left by one character, wrapping around.
        ydata[:-1] = xdata[1:]
        ydata[-1] = xdata[0]
        self.x_batches = np.split(xdata.reshape(self.batch_size, -1),
                                  self.num_batches, 1)
        self.y_batches = np.split(ydata.reshape(self.batch_size, -1),
                                  self.num_batches, 1)
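        # A tiny (hypothetical) illustration of the shift: if xdata were
        # array([0, 1, 2, 3]), ydata would be array([1, 2, 3, 0]), so each
        # target is the character that follows its input.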

    def next_batch(self):
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return x, y

    def reset_batch_pointer(self):
        self.pointer = 0
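

if __name__ == "__main__":
    # A minimal usage sketch (an illustration, not part of the original file;
    # the directory name and sizes here are assumptions): point TextLoader at
    # a directory containing input.txt and walk through one epoch of batches.
    loader = TextLoader("data/tiny_shakespeare", batch_size=50, seq_length=50)
    for _ in range(loader.num_batches):
        x, y = loader.next_batch()  # int arrays of shape (batch_size, seq_length)
    print("vocab size:", loader.vocab_size)
    print("num batches:", loader.num_batches)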