NN Training: Add LFW validation for #100.

This commit is contained in:
Brandon Amos 2016-03-04 18:30:39 -05:00
parent e4c4aef3e4
commit f5766412d2
4 changed files with 68 additions and 2 deletions

View File

@ -28,6 +28,7 @@ torch.manualSeed(opt.manualSeed)
paths.dofile('data.lua')
paths.dofile('model.lua')
paths.dofile('train.lua')
paths.dofile('test.lua')
paths.dofile('util.lua')
if opt.peoplePerBatch > nClasses then
@ -41,5 +42,6 @@ epoch = opt.epochNumber
for _=1,opt.nEpochs do
train()
test()
epoch = epoch + 1
end

View File

@ -40,6 +40,7 @@ function M.parse(arg)
-- GPU memory usage depends on peoplePerBatch and imagesPerPerson.
cmd:option('-peoplePerBatch', 15, 'Number of people to sample in each mini-batch.')
cmd:option('-imagesPerPerson', 20, 'Number of images to sample per person in each mini-batch.')
cmd:option('-testBatchSize', 800, 'Batch size for testing.')
---------- Model options ----------------------------------
cmd:option('-retrain', 'none', 'provide path to model to retrain with')

64
training/test.lua Normal file
View File

@ -0,0 +1,64 @@
-- Copyright 2016 Carnegie Mellon University
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
require 'io'
require 'string'
require 'sys'
local lfwDir = '../data/lfw/aligned'
local batchRepresentBin = "../batch-represent/main.lua"
local lfwEvalBin = "../evaluation/lfw.py"
local testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
local function getLfwAcc(fName)
local f = io.open(fName, 'r')
io.input(f)
local lastLine = nil
while true do
local line = io.read("*line")
if line == nil then break end
lastLine = line
end
io.close()
return tonumber(string.sub(lastLine, 6, 11))
end
function test()
if opt.cuda then
model = model:float()
end
local latestModelFile = paths.concat(opt.save, 'model_' .. epoch .. '.t7')
local outDir = paths.concat(opt.save, 'lfw-' .. epoch)
print(latestModelFile)
print(outDir)
local cmd = batchRepresentBin
if opt.cuda then
cmd = cmd .. ' -cuda '
end
cmd = cmd .. ' -batchSize ' .. opt.testBatchSize ..
' -model ' .. latestModelFile ..
' -data ' .. lfwDir ..
' -outDir ' .. outDir
os.execute(cmd)
cmd = lfwEvalBin .. ' Epoch' .. epoch .. ' ' .. outDir
os.execute(cmd)
lfwAcc = getLfwAcc(paths.concat(outDir, "accuracies.txt"))
testLogger:add{
['lfwAcc'] = lfwAcc
}
end

View File

@ -102,10 +102,9 @@ function train()
cutorch.synchronize()
end
-- set the dropouts to training mode
model:training()
if opt.cuda then
model:cuda() -- get it back on the right GPUs.
model:cuda()
end
local tm = torch.Timer()