Finish tensor size checks for #36.

This commit is contained in:
Brandon Amos 2016-01-11 16:44:50 -05:00
parent d0b40686a2
commit 0479d0e345
3 changed files with 5 additions and 3 deletions

View File

@ -25,8 +25,8 @@ if not os.execute('cd ' .. opt.data) then
error(("could not chdir to '%s'"):format(opt.data))
end
local loadSize = {3, imgDim, imgDim}
local sampleSize = {3, imgDim, imgDim}
local loadSize = {3, opt.imgDim, opt.imgDim}
local sampleSize = {3, opt.imgDim, opt.imgDim}
-- function to load the image, jitter it appropriately (random crops etc.)
local trainHook = function(self, path)

View File

@ -82,6 +82,8 @@ if opt.retrain ~= 'none' then
model = torch.load(opt.retrain)
else
paths.dofile(opt.modelDef)
assert(imgDim, "Model definition must set global variable 'imgDim'")
assert(imgDim == opt.imgDim, "Model definiton's imgDim must match imgDim option.")
model = createModel()
end

View File

@ -45,7 +45,7 @@ function M.parse(arg)
---------- Model options ----------------------------------
cmd:option('-retrain', 'none', 'provide path to model to retrain with')
cmd:option('-modelDef', '../models/openface/nn4.def.lua', 'path to model definiton')
-- cmd:option('-imgDim', 96, 'Image dimension. nn2=224, nn4=96') Provided by model def.
cmd:option('-imgDim', 96, 'Image dimension. nn2=224, nn4=96')
cmd:option('-embSize', 128, 'size of embedding from model')
cmd:option('-alpha', 0.2, 'margin in TripletLoss')
cmd:text()