Diff of /code/data.lua [000000] .. [b758a2]

Switch to side-by-side view

--- a
+++ b/code/data.lua
@@ -0,0 +1,245 @@
+require 'torch';
+require 'sys';
+require 'image';
+local dir = require 'pl.dir';
+local ffi = require 'ffi';
+
+--- Index every image below `folder`, treating each immediate sub-directory
+-- as one class.
+--
+-- The listing is done with the external tools `find`, `wc`, `cut` and `bash`
+-- (the g-prefixed GNU variants when ffi.os reports OSX).
+--
+-- NOTE(review): directory names are embedded in double-quoted shell strings,
+-- so names containing `"`, `$` or back-ticks would break the generated
+-- commands.
+--
+-- @param folder  root directory containing one sub-directory per class
+-- @return classes     array of class names (basename of each sub-directory)
+-- @return classList   classList[i] is a LongTensor holding the row indices in
+--                     `imagePaths` that belong to class i (empty tensor when
+--                     the class directory contains no image)
+-- @return imagePaths  CharTensor with one row per image; each row holds the
+--                     image path followed by zero padding
+function getImagePaths(folder)
+	-- every immediate sub-directory is one class
+	local classes = {}
+	local classPaths = {}
+	for _, dirpath in ipairs(dir.getdirectories(folder)) do
+		local class = paths.basename(dirpath)
+		table.insert(classes, class)
+		table.insert(classPaths, dirpath)
+	end
+
+	-- define command-line tools, try your best to maintain OSX compatibility
+	local wc = 'wc'
+	local cut = 'cut'
+	local find = 'find'
+	if ffi.os == 'OSX' then
+		wc = 'gwc'
+		cut = 'gcut'
+		find = 'gfind'
+	end
+
+	-- options for the GNU find command: one -iname test per extension,
+	-- joined with -o
+	local extensionList = {'jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
+	local nameTests = {}
+	for i, extension in ipairs(extensionList) do
+		nameTests[i] = '-iname "*.' .. extension .. '"'
+	end
+	local findOptions = ' ' .. table.concat(nameTests, ' -o ')
+
+	-- Runs `wc <flag>` on a file and returns the leading number of its output.
+	local function wcCount(flag, file)
+		local output = sys.fexecute(wc .. " " .. flag .. " '" .. file .. "' |" .. cut .. " -f1 -d' '")
+		return assert(tonumber(output), 'getImagePaths: "' .. wc .. ' ' .. flag .. '" produced no number for ' .. file)
+	end
+
+	-- one listing file per class, plus one file with all listings concatenated
+	-- in class order
+	local classFindFiles = {}
+	for i = 1, #classes do
+		classFindFiles[i] = os.tmpname()
+	end
+	local combinedFindList = os.tmpname()
+
+	-- Emit a single bash script: list each class, then append that listing to
+	-- the combined file. Appending class by class keeps the combined file in
+	-- the order of `classes`.
+	local scriptLines = {}
+	for i = 1, #classes do
+		scriptLines[#scriptLines + 1] = find .. ' "' .. classPaths[i] .. '" ' .. findOptions .. ' >>"' .. classFindFiles[i] .. '" \n'
+		scriptLines[#scriptLines + 1] = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n'
+	end
+	local scriptPath = os.tmpname()
+	local scriptHandle = assert(io.open(scriptPath, 'w'))
+	scriptHandle:write(table.concat(scriptLines))
+	scriptHandle:close()
+	os.execute('bash ' .. scriptPath)
+	os.remove(scriptPath)
+
+	-- row width is the longest path plus one byte for the string terminator
+	local maxPathLength = wcCount('-L', combinedFindList) + 1
+	local length = wcCount('-l', combinedFindList)
+
+	-- copy each path into its own zero-filled row
+	local imagePaths = torch.CharTensor()
+	imagePaths:resize(length, maxPathLength):fill(0)
+	local rowPointer = imagePaths:data()
+	for line in io.lines(combinedFindList) do
+		ffi.copy(rowPointer, line)
+		rowPointer = rowPointer + maxPathLength
+	end
+	local numSamples = length
+	print(numSamples ..  ' samples found.')
+
+	-- consecutive index ranges: class i owns the rows that came from its listing
+	local classList = {}
+	local runningIndex = 0
+	for i = 1, #classes do
+		local classLength = wcCount('-l', classFindFiles[i])
+		if classLength > 0 then
+			classList[i] = torch.linspace(runningIndex + 1, runningIndex + classLength, classLength):long()
+		else
+			-- linspace rejects a zero point count; an empty class gets an empty tensor
+			classList[i] = torch.LongTensor()
+		end
+		runningIndex = runningIndex + classLength
+	end
+
+	-- remove the temporary listings (os.remove reports failure by return
+	-- value, it does not raise)
+	for i = 1, #classFindFiles do
+		os.remove(classFindFiles[i])
+	end
+	os.remove(combinedFindList)
+
+	return classes, classList, imagePaths
+end
+
+-- Height and width, in pixels, of the tensor slot allocated for every image.
+local IMAGE_SIZE = 101
+
+--- Pack loaded images and their labels into a normalized dataset table.
+--
+-- Shared by getSample and getRandomSample.
+--
+-- @param images  array of image tensors as returned by image.load(path, 3, 'float')
+-- @param labels  array of class indices, labels[k] belonging to images[k]
+-- @return dataset table with fields `data` (N x 3 x 101 x 101 tensor) and
+--         `label` (LongTensor of N class indices), a `size()` method, and
+--         integer indexing: dataset[k] yields {data[k], label[k]}
+local function buildDataset(images, labels)
+	local count = #images
+	local data = torch.Tensor(count, 3, IMAGE_SIZE, IMAGE_SIZE)
+	local scalarLabels = torch.LongTensor(count):fill(-1111)
+	for k = 1, count do
+		data[k]:copy(images[k])
+		scalarLabels[k] = labels[k]
+	end
+
+	local dataset = {}
+	dataset.data = data
+	dataset.label = scalarLabels
+
+	-- keys that are not stored in the table fall through to this function,
+	-- which is what makes dataset[k] return a {sample, label} pair
+	setmetatable(dataset,
+		{__index = function(t, i)
+						return {t.data[i], t.label[i]}
+					end}
+	);
+
+	function dataset:size()
+		return self.data:size(1)
+	end
+
+	-- data normalization: per channel, subtract the mean and divide by the
+	-- standard deviation (skipped when it is zero)
+	local mean = {}
+	local stdv = {}
+	for c = 1, 3 do
+		local channel = data[{ {}, {c}, {}, {}  }]
+		mean[c] = channel:mean()
+		channel:add(-mean[c])
+
+		stdv[c] = channel:std()
+		if stdv[c] ~= 0 then
+			channel:div(stdv[c])
+		end
+	end
+
+	-- NOTE(review): the statistics were historically left in the globals
+	-- `mean` and `stdv`; they are still published, now explicitly, in case a
+	-- caller reads them -- confirm and remove if nobody does.
+	_G.mean = mean
+	_G.stdv = stdv
+
+	return dataset
+end
+
+--- Load the images selected by `sampleList` and return them as a dataset.
+--
+-- @param classes     array of class names; only its length is used
+-- @param sampleList  sampleList[i] is a tensor of row indices into
+--                    `imagePaths` for class i
+-- @param imagePaths  CharTensor of paths, one per row
+-- @return dataset table (fields `data` and `label`, method `size()`)
+function getSample(classes, sampleList, imagePaths)
+	local images = {}
+	local labels = {}
+	for i = 1, #classes do
+		local indices = sampleList[i]
+		for j = 1, indices:nElement() do
+			local imgpath = ffi.string(torch.data(imagePaths[indices[j]]))
+			images[#images + 1] = image.load(imgpath, 3, 'float')
+			labels[#labels + 1] = i
+		end
+	end
+	return buildDataset(images, labels)
+end
+
+--- Draw `batchSize` images per class, uniformly at random with replacement,
+-- and return them as a dataset.
+--
+-- @param classes     array of class names
+-- @param batchSize   number of images drawn from every class
+-- @param classList   classList[i] is a tensor of row indices into
+--                    `imagePaths` for class i
+-- @param imagePaths  CharTensor of paths, one per row
+-- @return dataset table (fields `data` and `label`, method `size()`)
+function getRandomSample(classes, batchSize, classList, imagePaths)
+	local images = {}
+	local labels = {}
+	for i = 1, #classes do
+		local candidates = classList[i]
+		local candidateCount = candidates:nElement()
+		for j = 1, batchSize do
+			if candidateCount == 0 then
+				error('getRandomSample: class "' .. tostring(classes[i]) .. '" has no images to draw from')
+			end
+			-- torch.uniform() may return exactly 0, hence the clamp to 1
+			local index = math.max(1, math.ceil(torch.uniform() * candidateCount))
+			local imgpath = ffi.string(torch.data(imagePaths[candidates[index]]))
+			images[#images + 1] = image.load(imgpath, 3, 'float')
+			labels[#labels + 1] = i
+		end
+	end
+	return buildDataset(images, labels)
+end
+
+--- Split every class into per-batch sample counts.
+--
+-- The number of batches is ceil(total samples / batchSize). Within a class
+-- the count for each batch is either the floor or the ceiling of
+-- (class size / number of batches); the ceiling is used whenever the running
+-- sum has fallen behind the exact proportional share. The last batch takes
+-- whatever remains.
+--
+-- @param classes    array of class names; only its length is used
+-- @param classList  classList[i]:nElement() gives the size of class i
+-- @param batchSize  target number of samples per batch
+-- @return batchSizes  batchSizes[i][j] = samples of class i in batch j
+-- @return numBatches  number of batches
+-- @return numSamples  total number of samples over all classes
+function getBatchSizes(classes, classList, batchSize)
+	local classCount = #classes
+
+	local total = 0
+	for c = 1, classCount do
+		total = total + classList[c]:nElement()
+	end
+	local numBatches = math.ceil(total/batchSize)
+
+	local batchSizes = {}
+	for c = 1, classCount do
+		local classSize = classList[c]:nElement()
+		local lowShare = math.floor(classSize/numBatches)
+		local highShare = math.ceil(classSize/numBatches)
+
+		local shares = {}
+		local assigned = 0
+		local behind = false
+		for b = 1, numBatches - 1 do
+			local share = lowShare
+			if behind then
+				share = highShare
+			end
+			shares[b] = share
+			assigned = assigned + share
+
+			-- behind schedule when the exact proportional share exceeds what
+			-- has been handed out so far
+			behind = b*classSize/numBatches > assigned
+		end
+		shares[numBatches] = classSize - assigned
+		batchSizes[c] = shares
+	end
+
+	return batchSizes, numBatches, total
+end
+
+--- Return a randomly permuted copy of every class's index list.
+--
+-- The input tensors are left untouched; one torch.randperm call is made per
+-- class.
+--
+-- @param classList  classList[i] is a 1-D tensor of indices for class i
+-- @param classes    array of class names; only its length is used
+-- @return table whose i-th entry is a LongTensor with the elements of
+--         classList[i] in permuted order
+function shuffleImages(classList, classes)
+	local shuffled = {}
+	for c = 1, #classes do
+		local source = classList[c]
+		local count = source:size(1)
+		local order = torch.randperm(count)
+		local mixed = torch.LongTensor(count)
+		for pos = 1, count do
+			mixed[pos] = source[order[pos]]
+		end
+		shuffled[c] = mixed
+	end
+	return shuffled
+end