require 'torch';
require 'sys';
require 'image';
local dir = require 'pl.dir';
local ffi = require 'ffi';
-- Quote `s` so a POSIX shell treats it as one literal word: wrap it in single
-- quotes and spell every embedded single quote as '\''.
local function shellQuote(s)
  local escaped = s:gsub("'", "'\\''")
  return "'" .. escaped .. "'"
end

-- Run one `find` per class directory and store its output in the matching
-- entry of `listFiles`. All commands go into a single bash script so that only
-- one shell is spawned, however many classes there are.
local function writeImageLists(classPaths, listFiles)
  -- GNU find is installed under the name `gfind` on OSX
  local find = (ffi.os == 'OSX') and 'gfind' or 'find'

  -- -iname matches case-insensitively, so one entry per extension is enough
  local extensionList = {'jpg', 'png', 'jpeg', 'ppm', 'bmp'}
  local tests = {}
  for i, ext in ipairs(extensionList) do
    tests[i] = '-iname "*.' .. ext .. '"'
  end
  local findOptions = table.concat(tests, ' -o ')

  local script = os.tmpname()
  local handle, err = io.open(script, 'w')
  if not handle then
    os.remove(script)
    error(err, 0)
  end
  for i, classPath in ipairs(classPaths) do
    handle:write(find, ' ', shellQuote(classPath), ' ', findOptions,
                 ' >>', shellQuote(listFiles[i]), '\n')
  end
  handle:close()
  os.execute('bash ' .. shellQuote(script))
  os.remove(script)
end

-- Read the per-class path lists and build
--   * classList : classList[i] is a LongTensor with the row numbers (into
--                 imagePaths) of the images that belong to class i
--   * imagePaths: CharTensor with one NUL-padded path per row, rows ordered
--                 class by class
local function indexImageLists(classes, listFiles)
  -- pass 1: count the images of every class and find the longest path in BYTES
  -- (ffi.copy below writes #line + 1 bytes, so the row width must be a byte
  -- count, not a display width)
  local counts = {}
  local total, longest = 0, 0
  for i = 1, #classes do
    local n = 0
    for line in io.lines(listFiles[i]) do
      n = n + 1
      if #line > longest then
        longest = #line
      end
    end
    if n == 0 then
      -- an index range cannot be built for a class without images; fail with
      -- a message that names the class instead of failing later on
      error('no image files found for class "' .. classes[i] .. '"', 0)
    end
    counts[i] = n
    total = total + n
  end

  -- pass 2: copy the paths into the tensor and record each class's row range
  local rowWidth = longest + 1 -- one extra byte for the terminating NUL
  local imagePaths = torch.CharTensor()
  imagePaths:resize(total, rowWidth):fill(0)
  local cursor = imagePaths:data()
  local classList = {}
  local runningIndex = 0
  for i = 1, #classes do
    for line in io.lines(listFiles[i]) do
      ffi.copy(cursor, line) -- copies the string plus its trailing NUL
      cursor = cursor + rowWidth
    end
    classList[i] = torch.linspace(runningIndex + 1, runningIndex + counts[i], counts[i]):long()
    runningIndex = runningIndex + counts[i]
  end
  return classList, imagePaths
end

-- Scan `folder` for images, treating every immediate sub-directory as a class.
--
-- folder: path of the dataset root directory
--
-- Returns three values:
--   classes    - list of class names (base name of each sub-directory); class
--                indices follow the order in which pl.dir lists the directories
--   classList  - classList[i] is a LongTensor of row indices into imagePaths
--                for the images of class i
--   imagePaths - CharTensor, one NUL-padded image path per row
--
-- Raises an error when `folder` has no sub-directories or when one of them
-- contains no image file. Temporary files are removed in every case.
-- Side effects: prints the number of samples found and sets the global
-- `numSamples`.
function getImagePaths(folder)
  local classes, classPaths = {}, {}
  for _, dirpath in ipairs(dir.getdirectories(folder)) do
    classes[#classes + 1] = paths.basename(dirpath)
    classPaths[#classPaths + 1] = dirpath
  end
  if #classes == 0 then
    error('no class directories found in "' .. tostring(folder) .. '"', 2)
  end

  -- one temporary file per class receives the output of `find`
  local listFiles = {}
  for i = 1, #classes do
    listFiles[i] = os.tmpname()
  end

  -- protected call so that the temporary files are removed even on failure
  local ok, classList, imagePaths = pcall(function()
    writeImageLists(classPaths, listFiles)
    return indexImageLists(classes, listFiles)
  end)
  for i = 1, #listFiles do
    os.remove(listFiles[i])
  end
  if not ok then
    error(classList, 0) -- on failure the second pcall result is the error value
  end

  -- NOTE(review): `numSamples` has always been written as a global here; it is
  -- kept that way in case calling code reads it. Confirm, then make it local.
  numSamples = imagePaths:size(1)
  print(numSamples .. ' samples found.')
  return classes, classList, imagePaths
end
-- Turn a list of loaded images and their class indices into a normalized
-- dataset. Shared by getSample and getRandomSample.
--
-- images: list of image tensors; each one is copied into a 3x101x101 slot
--         (presumably copy() rejects images with another element count -
--         confirm against the Torch version in use)
-- labels: list of class indices, labels[k] belongs to images[k]
--
-- Returns a table with
--   data  - Tensor (#images x 3 x 101 x 101), each channel shifted to zero
--           mean and scaled to unit standard deviation over the whole set
--   label - LongTensor (#images) of class indices
--   size  - method returning the number of samples
-- Indexing the table with a sample number i yields {data[i], label[i]}.
local function buildDataset(images, labels)
  local count = #images
  local data = torch.Tensor(count, 3, 101, 101)
  local scalarLabels = torch.LongTensor(count):fill(-1111)
  for k = 1, count do
    data[k]:copy(images[k])
    scalarLabels[k] = labels[k]
  end

  local dataset = {data = data, label = scalarLabels}
  setmetatable(dataset, {
    __index = function(t, i)
      return {t.data[i], t.label[i]}
    end
  })
  function dataset:size()
    return self.data:size(1)
  end

  -- data normalization
  -- NOTE(review): `mean` and `stdv` were globals before and are not part of
  -- the return value, so they are deliberately left global in case calling
  -- code reads them after a call. Confirm, then return them explicitly.
  mean = {}
  stdv = {}
  for c = 1, 3 do
    -- view on channel c; the in-place operations below modify `data` itself
    local channel = data[{ {}, {c}, {}, {} }]
    mean[c] = channel:mean()
    channel:add(-mean[c])
    stdv[c] = channel:std()
    if stdv[c] ~= 0 then -- a constant channel cannot be rescaled
      channel:div(stdv[c])
    end
  end
  return dataset
end

-- Load every image listed in `sampleList` and return it as a normalized dataset.
--
-- classes   : list of class names (only its length is used)
-- sampleList: sampleList[i] is a tensor of row indices into imagePaths for
--             the samples of class i
-- imagePaths: CharTensor holding one NUL-terminated image path per row
--
-- Returns the dataset table described at buildDataset; samples are ordered
-- class by class. Also updates the globals `mean` and `stdv`.
function getSample(classes, sampleList, imagePaths)
  local images, labels = {}, {}
  for i = 1, #classes do
    local indices = sampleList[i]
    for j = 1, indices:nElement() do
      local imgpath = ffi.string(torch.data(imagePaths[indices[j]]))
      images[#images + 1] = image.load(imgpath, 3, 'float')
      labels[#labels + 1] = i
    end
  end
  return buildDataset(images, labels)
end

-- Draw `batchSize` images per class uniformly at random (with replacement)
-- and return them as a normalized dataset.
--
-- classes   : list of class names (only its length is used)
-- batchSize : number of images drawn from every class
-- classList : classList[i] is a tensor of row indices into imagePaths for
--             the images of class i
-- imagePaths: CharTensor holding one NUL-terminated image path per row
--
-- Returns the dataset table described at buildDataset; samples are ordered
-- class by class. Also updates the globals `mean` and `stdv`.
function getRandomSample(classes, batchSize, classList, imagePaths)
  local images, labels = {}, {}
  for i = 1, #classes do
    local indices = classList[i]
    for _ = 1, batchSize do
      -- the max() guards against index 0 when uniform() returns exactly 0
      local index = math.max(1, math.ceil(torch.uniform() * indices:nElement()))
      local imgpath = ffi.string(torch.data(imagePaths[indices[index]]))
      images[#images + 1] = image.load(imgpath, 3, 'float')
      labels[#labels + 1] = i
    end
  end
  return buildDataset(images, labels)
end
-- Decide how many samples of every class go into each batch so that each
-- class is spread as evenly as possible over all batches.
--
-- classes  : list of class names (only its length is used)
-- classList: classList[i]:nElement() is the number of samples of class i
-- batchSize: desired total number of samples per batch
--
-- Returns three values:
--   batchSizes - batchSizes[i][j] is the number of class-i samples in batch j
--   numBatches - ceil(numSamples / batchSize)
--   numSamples - total number of samples over all classes
function getBatchSizes(classes, classList, batchSize)
  local total = 0
  for c = 1, #classes do
    total = total + classList[c]:nElement()
  end
  local batchCount = math.ceil(total / batchSize)

  local sizes = {}
  for c = 1, #classes do
    local classCount = classList[c]:nElement()
    local low = math.floor(classCount / batchCount)
    local high = math.ceil(classCount / batchCount)

    local row = {}
    local assigned = 0
    local roundUp = false -- the first batch always gets the rounded-down share
    for b = 1, batchCount - 1 do
      row[b] = roundUp and high or low
      assigned = assigned + row[b]
      -- round up next time whenever we are behind the ideal running total
      roundUp = b * classCount / batchCount > assigned
    end
    -- the last batch takes whatever is left of this class
    row[batchCount] = classCount - assigned
    sizes[c] = row
  end
  return sizes, batchCount, total
end
-- Return a shuffled copy of every class's index list.
--
-- classList: classList[i] is a 1-D tensor of sample indices for class i
-- classes  : list of class names (only its length is used)
--
-- Returns a table whose i-th entry is a new LongTensor holding the elements
-- of classList[i] in the order given by torch.randperm. The input tensors are
-- not modified.
function shuffleImages(classList, classes)
  local shuffled = {}
  for c = 1, #classes do
    local source = classList[c]
    local count = source:size(1)
    local order = torch.randperm(count)
    local target = torch.LongTensor(count)
    for k = 1, count do
      target[k] = source[order[k]]
    end
    shuffled[c] = target
  end
  return shuffled
end