
jz-char-rnn-tensorflow/utils.py
import os
import collections
from six.moves import cPickle
import numpy as np

class TextLoader():
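    """Loads a character-level corpus and serves (input, target) minibatches."""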
    def __init__(self, data_dir, batch_size, seq_length):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seq_length = seq_length

        input_file = os.path.join(data_dir, "input.txt")
        vocab_file = os.path.join(data_dir, "vocab.pkl")
        tensor_file = os.path.join(data_dir, "data.npy")
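
        # Reuse the cached vocabulary and tensor when both exist; otherwise
        # rebuild them from the raw text.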
        if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
            print("reading text file")
            self.preprocess(input_file, vocab_file, tensor_file)
        else:
            print("loading preprocessed files")
            self.load_preprocessed(vocab_file, tensor_file)
        self.create_batches()
        self.reset_batch_pointer()

    def preprocess(self, input_file, vocab_file, tensor_file):
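        """Reads input.txt, builds the character vocabulary, and caches
        the vocabulary (vocab.pkl) and the encoded corpus (data.npy)."""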
        with open(input_file, "r") as f:
            data = f.read()
        # Map each character to an integer id, most frequent character first.
        counter = collections.Counter(data)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        self.chars, _ = zip(*count_pairs)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
        # Encode the whole corpus as a flat array of character ids.
        self.tensor = np.array(list(map(self.vocab.get, data)))
        np.save(tensor_file, self.tensor)

    def load_preprocessed(self, vocab_file, tensor_file):
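        """Restores the vocabulary and the encoded corpus from the cache."""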
        with open(vocab_file, 'rb') as f:
            self.chars = cPickle.load(f)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.tensor = np.load(tensor_file)
        self.num_batches = self.tensor.size // (self.batch_size * self.seq_length)

    def create_batches(self):
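        """Slices the corpus into num_batches pairs of (input, target) arrays,
        each of shape (batch_size, seq_length)."""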
        self.num_batches = self.tensor.size // (self.batch_size * self.seq_length)

        # When the data (tensor) is too small, give a clearer error message.
        if self.num_batches == 0:
            raise ValueError("Not enough data. Reduce seq_length and/or batch_size.")

        # Drop the tail that does not fill a whole batch.
        self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
        xdata = self.tensor
        ydata = np.copy(self.tensor)
        # Targets are the inputs shifted left by one character; the final
        # target wraps around to the first character.
        ydata[:-1] = xdata[1:]
        ydata[-1] = xdata[0]
        self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
        self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)

    def next_batch(self):
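        """Returns the next (x, y) batch and advances the pointer."""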
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return x, y

    def reset_batch_pointer(self):
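        """Rewinds to the first batch (call at the start of each epoch)."""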
        self.pointer = 0
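
# A minimal usage sketch, not part of the original file. It assumes a corpus
# at data/tinyshakespeare/input.txt (the default layout in the upstream
# char-rnn-tensorflow repo); the batch shapes follow from create_batches().
if __name__ == "__main__":
    loader = TextLoader("data/tinyshakespeare", batch_size=50, seq_length=50)
    print("vocab size: %d" % loader.vocab_size)
    print("batches per epoch: %d" % loader.num_batches)
    x, y = loader.next_batch()
    print("batch shapes: %s %s" % (x.shape, y.shape))  # both (50, 50)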