Diff of /code/data.lua [000000] .. [b758a2]

Switch to unified view

a b/code/data.lua
1
require 'torch';
2
require 'sys';
3
require 'image';
4
local dir = require 'pl.dir';
5
local ffi = require 'ffi';
6
7
function getImagePaths(folder)
8
    -- obtain list of image files
9
    local classes = {}
10
    local classPaths = {}
11
    local dirs = dir.getdirectories(folder);
12
    for k,dirpath in ipairs(dirs) do
13
        local class = paths.basename(dirpath)
14
        table.insert(classes, class)
15
        table.insert(classPaths, dirpath)
16
    end
17
18
    -- define command-line tools, try your best to maintain OSX compatibility
19
    local wc = 'wc'
20
    local cut = 'cut'
21
    local find = 'find'
22
    if ffi.os == 'OSX' then
23
        wc = 'gwc'
24
        cut = 'gcut'
25
        find = 'gfind'
26
    end
27
28
    -- options for the GNU find command
29
    local extensionList  = {'jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
30
    local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
31
    for i=2,#extensionList do
32
        findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
33
    end
34
35
    -- find the image path names
36
    local imagePaths = torch.CharTensor()   -- path to each image in dataset
37
    local imageClass = torch.LongTensor()   -- class index of each image (class index in self.classes)
38
    local classList = {}            -- index of imageList to each image of a particular class
39
40
    -- create file listing the paths to every image
41
    local classFindFiles = {}
42
    for i=1,#classes do
43
        classFindFiles[i] = os.tmpname()
44
    end
45
    local combinedFindList = os.tmpname()
46
47
    local tmpfile = os.tmpname()
48
    local tmphandle = assert(io.open(tmpfile, 'w'))
49
    for i,class in ipairs(classes) do
50
        local command = find .. ' "' .. classPaths[i] .. '" ' .. findOptions .. ' >>"' .. classFindFiles[i] .. '" \n'
51
        tmphandle:write(command)
52
    end
53
    io.close(tmphandle)
54
    os.execute('bash ' .. tmpfile)
55
    os.execute('rm -f ' .. tmpfile)
56
57
    local tmpfile = os.tmpname()
58
    local tmphandle = assert(io.open(tmpfile, 'w'))
59
    -- concat all finds to a single large file in the order of self.classes
60
    for i=1,#classes do
61
        local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n'
62
        tmphandle:write(command)
63
    end
64
    io.close(tmphandle)
65
    os.execute('bash ' .. tmpfile)
66
    os.execute('rm -f ' .. tmpfile)
67
68
    local maxPathLength = tonumber(sys.fexecute(wc .. " -L '" .. combinedFindList .. "' |" .. cut .. " -f1 -d' '")) + 1
69
    local length = tonumber(sys.fexecute(wc .. " -l '" .. combinedFindList .. "' |" .. cut .. " -f1 -d' '"))
70
71
    imagePaths:resize(length, maxPathLength):fill(0)
72
    local s_data = imagePaths:data()
73
    for line in io.lines(combinedFindList) do
74
        ffi.copy(s_data, line)
75
        s_data = s_data + maxPathLength
76
    end
77
    numSamples = imagePaths:size(1)
78
    print(numSamples ..  ' samples found.')
79
80
    imageClass:resize(numSamples)
81
    local runningIndex = 0
82
    for i=1,#classes do
83
        local length = tonumber(sys.fexecute(wc .. " -l '" .. classFindFiles[i] .. "' |" .. cut .. " -f1 -d' '"))
84
        classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long()
85
        imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i)
86
        runningIndex = runningIndex + length
87
    end
88
89
    local tmpfilelistall = ''
90
    for i=1,#(classFindFiles) do
91
        tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"'
92
        if i % 1000 == 0 then
93
            os.execute('rm -f ' .. tmpfilelistall)
94
            tmpfilelistall = ''
95
        end
96
    end
97
    os.execute('rm -f '  .. tmpfilelistall)
98
    os.execute('rm -f "' .. combinedFindList .. '"')
99
100
    return classes, classList, imagePaths
101
end
102
103
function getSample(classes, sampleList, imagePaths)
104
    dataTable = {}
105
    scalarTable = {}
106
    N = 0
107
    for i=1,#classes do
108
        for j=1,sampleList[i]:nElement() do
109
            local imgpath = ffi.string(torch.data(imagePaths[sampleList[i][j]]))
110
            out = image.load(imgpath, 3, 'float')
111
            table.insert(dataTable, out)
112
            table.insert(scalarTable, i)
113
            N = N + 1
114
        end
115
    end
116
    data = torch.Tensor(N, 3, 101, 101)
117
    scalarLabels = torch.LongTensor(N):fill(-1111)
118
    for i=1,#dataTable do
119
        data[i]:copy(dataTable[i])
120
        scalarLabels[i] = scalarTable[i]
121
    end
122
    dataset = {}
123
    dataset.data = data
124
    dataset.label = scalarLabels
125
126
    setmetatable(dataset,
127
        {__index = function(t, i) 
128
                        return {t.data[i], t.label[i]}
129
                    end}
130
    );
131
132
    function dataset:size() 
133
        return self.data:size(1) 
134
    end
135
136
    -- data normalization
137
    mean = {}
138
    stdv  = {}
139
    for i=1,3 do
140
        mean[i] = dataset.data[{ {}, {i}, {}, {}  }]:mean()
141
        dataset.data[{ {}, {i}, {}, {}  }]:add(-mean[i])
142
        
143
        stdv[i] = dataset.data[{ {}, {i}, {}, {}  }]:std()
144
        if stdv[i] ~= 0 then
145
            dataset.data[{ {}, {i}, {}, {}  }]:div(stdv[i])
146
        end
147
    end
148
149
    return dataset
150
end
151
152
function getRandomSample(classes, batchSize, classList, imagePaths)
153
    dataTable = {}
154
    scalarTable = {}
155
    N = 0
156
    for i=1,#classes do
157
        for j=1,batchSize do
158
            local index = math.max(1, math.ceil(torch.uniform() * classList[i]:nElement()))
159
            local imgpath = ffi.string(torch.data(imagePaths[classList[i][index]]))
160
            out = image.load(imgpath, 3, 'float')
161
            table.insert(dataTable, out)
162
            table.insert(scalarTable, i)
163
            N = N + 1
164
        end
165
    end
166
    data = torch.Tensor(N, 3, 101, 101)
167
    scalarLabels = torch.LongTensor(N):fill(-1111)
168
    for i=1,#dataTable do
169
        data[i]:copy(dataTable[i])
170
        scalarLabels[i] = scalarTable[i]
171
    end
172
    dataset = {}
173
    dataset.data = data
174
    dataset.label = scalarLabels
175
176
    setmetatable(dataset,
177
        {__index = function(t, i) 
178
                        return {t.data[i], t.label[i]}
179
                    end}
180
    );
181
182
    function dataset:size() 
183
        return self.data:size(1) 
184
    end
185
186
    -- data normalization
187
    mean = {}
188
    stdv  = {}
189
    for i=1,3 do
190
        mean[i] = dataset.data[{ {}, {i}, {}, {}  }]:mean()
191
        dataset.data[{ {}, {i}, {}, {}  }]:add(-mean[i])
192
        
193
        stdv[i] = dataset.data[{ {}, {i}, {}, {}  }]:std()
194
        if stdv[i] ~= 0 then
195
            dataset.data[{ {}, {i}, {}, {}  }]:div(stdv[i])
196
        end
197
    end
198
199
    return dataset
200
end
201
202
function getBatchSizes(classes, classList, batchSize)
203
    local numSamples = 0
204
    for i=1,#classes do
205
        numSamples = numSamples + classList[i]:nElement()
206
    end
207
    local numBatches = math.ceil(numSamples/batchSize)
208
209
    local batchSizes = {}
210
    for i=1,#classes do
211
        local roundFlag = 0
212
        local batchSum = 0
213
        batchSizes[i] = {}
214
        for j=1,numBatches-1 do
215
            if roundFlag == 0 then
216
                batchSizes[i][j] = math.floor(classList[i]:nElement()/numBatches)
217
            else
218
                batchSizes[i][j] = math.ceil(classList[i]:nElement()/numBatches)
219
            end
220
            
221
            batchSum = batchSum + batchSizes[i][j]
222
            
223
            if j*classList[i]:nElement()/numBatches > batchSum then
224
                roundFlag = 1
225
            else
226
                roundFlag = 0
227
            end
228
        end
229
        batchSizes[i][numBatches] = classList[i]:nElement() - batchSum
230
    end
231
    
232
    return batchSizes, numBatches, numSamples
233
end
234
235
function shuffleImages(classList, classes)
236
    local temp = {}
237
    for i=1,#classes do
238
        local perm = torch.randperm(classList[i]:size(1))
239
        temp[i] = torch.LongTensor(classList[i]:size(1))
240
        for j=1,classList[i]:size(1) do
241
            temp[i][j] = classList[i][perm[j]]
242
        end
243
    end
244
    return temp
245
end