99 lines
2.9 KiB
Lua
99 lines
2.9 KiB
Lua
|
-- Source: https://github.com/facebook/fbcunn/blob/master/examples/imagenet/donkey.lua
|
||
|
--
|
||
|
-- Copyright (c) 2014, Facebook, Inc.
|
||
|
-- All rights reserved.
|
||
|
--
|
||
|
-- This source code is licensed under the BSD-style license found in the
|
||
|
-- LICENSE file in the root directory of this source tree. An additional grant
|
||
|
-- of patent rights can be found in the PATENTS file in the same directory.
|
||
|
--
|
||
|
local gm = assert(require 'graphicsmagick')
|
||
|
paths.dofile('dataset.lua')
|
||
|
paths.dofile('util.lua')
|
||
|
ffi=require 'ffi'
|
||
|
|
||
|
-- This file contains the data-loading logic and details.
|
||
|
-- It is run by each data-loader thread.
|
||
|
------------------------------------------
|
||
|
|
||
|
-- Cache files (under opt.cache) holding the serialized train/test metadata.
-- If a file does not exist yet, it is written with torch.save the first
-- time the corresponding loader is built.
local cacheDir = opt.cache
local trainCache = paths.concat(cacheDir, 'trainCache.t7')
local testCache  = paths.concat(cacheDir, 'testCache.t7')
|
||
|
|
||
|
-- Fail fast unless the dataset root opt.data is an existing directory.
--
-- This replaces the former `os.execute('cd ' .. opt.data)` check, which had
-- two problems:
--  * Under Lua 5.1/LuaJIT, os.execute returns a numeric exit status (0 on
--    success, non-zero on failure). Every number is truthy in Lua, so
--    `not os.execute(...)` was always false and this error never fired.
--  * It passed an unquoted path to the shell, which breaks on spaces and
--    allows shell injection, and it spawned a subprocess just to test a path.
-- paths.dirp tests the path directly, without a shell.
if not paths.dirp(opt.data) then
   error(("could not chdir to '%s'"):format(opt.data))
end
|
||
|
|
||
|
-- Decode size hint and final sample size. Both are {3, opt.imgDim, opt.imgDim}:
-- presumably {channels, height, width} (TODO confirm against dataset.lua).
-- They are kept as two distinct tables rather than one shared table.
local imgDim = opt.imgDim
local loadSize = {3, imgDim, imgDim}
local sampleSize = {3, imgDim, imgDim}
|
||
|
|
||
|
-- Per-sample hook for training images.
-- Steps:
--  1. Decodes the image at `path` with GraphicsMagick, passing
--     self.loadSize[3] / self.loadSize[2] as size hints.
--  2. Sizes it to self.sampleSize[3] x self.sampleSize[2].
--  3. Mirrors it horizontally (flop) with probability 0.5.
--  4. Returns it as a float tensor in RGB, DHW layout.
-- NOTE(review): no random cropping is performed; the only augmentation is
-- the horizontal flip.
-- @param self  loader object providing loadSize and sampleSize tables
-- @param path  path of the image file to load
-- @return      float tensor produced by gm's toTensor('float','RGB','DHW')
local trainHook = function(self, path)
   local img = gm.Image():load(path, self.loadSize[3], self.loadSize[2])
   img:size(self.sampleSize[3], self.sampleSize[2])

   -- torch.uniform() is drawn on every call, flip or not.
   local mirror = torch.uniform() > 0.5
   if mirror then
      img:flop()
   end

   return img:toTensor('float', 'RGB', 'DHW')
end
|
||
|
|
||
|
-- Build the global trainLoader over <opt.data>/train, or restore it from
-- trainCache when that file exists.
-- trainHook is attached only after torch.save, so the cached object is
-- written without the hook function. It is re-attached on every load.
if not paths.filep(trainCache) then
   print('Creating train metadata')
   trainLoader = dataLoader{
      paths = {paths.concat(opt.data, 'train')},
      loadSize = loadSize,
      sampleSize = sampleSize,
      split = 100,
      verbose = true
   }
   torch.save(trainCache, trainLoader)
else
   print('Loading train metadata from cache')
   trainLoader = torch.load(trainCache)
end
trainLoader.sampleHookTrain = trainHook
collectgarbage()
|
||
|
|
||
|
-- Sanity check on trainLoader: every label in imageClass must be a valid
-- 1-based index into the trainLoader.classes list.
do
   local labels = trainLoader.imageClass
   local numClasses = #trainLoader.classes
   local maxLabel = labels:max()
   assert(maxLabel <= numClasses, "class logic has error")
   local minLabel = labels:min()
   assert(minLabel >= 1, "class logic has error")
end
|
||
|
|
||
|
-- End of train loader section
|
||
|
--------------------------------------------------------------------------------
|
||
|
|
||
|
--[[ Section 2: Create the test data loader (testLoader). ]]--
|
||
|
|
||
|
-- Build the global testLoader over <opt.data>/val, or restore it from
-- testCache when that file exists.
-- forceClasses reuses trainLoader.classes so that class indices are
-- consistent between trainLoader and testLoader.
-- NOTE(review): unlike trainLoader, no sample hook is attached here; the
-- dataLoader's own default is presumably used. Verify in dataset.lua.
if not paths.filep(testCache) then
   print('Creating test metadata')
   testLoader = dataLoader{
      paths = {paths.concat(opt.data, 'val')},
      loadSize = loadSize,
      sampleSize = sampleSize,
      split = 100, -- a commented-out alternative in the original was split = 0
      verbose = true,
      forceClasses = trainLoader.classes
   }
   torch.save(testCache, testLoader)
else
   print('Loading test metadata from cache')
   testLoader = torch.load(testCache)
end
collectgarbage()
|
||
|
-- End of test loader section
|