NN Training: Add LFW validation for #100.
This commit is contained in:
parent
e4c4aef3e4
commit
f5766412d2
|
@ -28,6 +28,7 @@ torch.manualSeed(opt.manualSeed)
|
|||
paths.dofile('data.lua')
|
||||
paths.dofile('model.lua')
|
||||
paths.dofile('train.lua')
|
||||
paths.dofile('test.lua')
|
||||
paths.dofile('util.lua')
|
||||
|
||||
if opt.peoplePerBatch > nClasses then
|
||||
|
@ -41,5 +42,6 @@ epoch = opt.epochNumber
|
|||
|
||||
for _=1,opt.nEpochs do
|
||||
train()
|
||||
test()
|
||||
epoch = epoch + 1
|
||||
end
|
||||
|
|
|
@ -40,6 +40,7 @@ function M.parse(arg)
|
|||
-- GPU memory usage depends on peoplePerBatch and imagesPerPerson.
|
||||
cmd:option('-peoplePerBatch', 15, 'Number of people to sample in each mini-batch.')
|
||||
cmd:option('-imagesPerPerson', 20, 'Number of images to sample per person in each mini-batch.')
|
||||
cmd:option('-testBatchSize', 800, 'Batch size for testing.')
|
||||
|
||||
---------- Model options ----------------------------------
|
||||
cmd:option('-retrain', 'none', 'provide path to model to retrain with')
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
-- Copyright 2016 Carnegie Mellon University
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
require 'io'
|
||||
require 'string'
|
||||
require 'sys'
|
||||
|
||||
local lfwDir = '../data/lfw/aligned'
|
||||
|
||||
local batchRepresentBin = "../batch-represent/main.lua"
|
||||
local lfwEvalBin = "../evaluation/lfw.py"
|
||||
|
||||
local testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
|
||||
|
||||
local function getLfwAcc(fName)
|
||||
local f = io.open(fName, 'r')
|
||||
io.input(f)
|
||||
local lastLine = nil
|
||||
while true do
|
||||
local line = io.read("*line")
|
||||
if line == nil then break end
|
||||
lastLine = line
|
||||
end
|
||||
io.close()
|
||||
return tonumber(string.sub(lastLine, 6, 11))
|
||||
end
|
||||
|
||||
function test()
|
||||
if opt.cuda then
|
||||
model = model:float()
|
||||
end
|
||||
local latestModelFile = paths.concat(opt.save, 'model_' .. epoch .. '.t7')
|
||||
local outDir = paths.concat(opt.save, 'lfw-' .. epoch)
|
||||
print(latestModelFile)
|
||||
print(outDir)
|
||||
local cmd = batchRepresentBin
|
||||
if opt.cuda then
|
||||
cmd = cmd .. ' -cuda '
|
||||
end
|
||||
cmd = cmd .. ' -batchSize ' .. opt.testBatchSize ..
|
||||
' -model ' .. latestModelFile ..
|
||||
' -data ' .. lfwDir ..
|
||||
' -outDir ' .. outDir
|
||||
os.execute(cmd)
|
||||
|
||||
cmd = lfwEvalBin .. ' Epoch' .. epoch .. ' ' .. outDir
|
||||
os.execute(cmd)
|
||||
|
||||
lfwAcc = getLfwAcc(paths.concat(outDir, "accuracies.txt"))
|
||||
testLogger:add{
|
||||
['lfwAcc'] = lfwAcc
|
||||
}
|
||||
end
|
|
@ -102,10 +102,9 @@ function train()
|
|||
cutorch.synchronize()
|
||||
end
|
||||
|
||||
-- set the dropouts to training mode
|
||||
model:training()
|
||||
if opt.cuda then
|
||||
model:cuda() -- get it back on the right GPUs.
|
||||
model:cuda()
|
||||
end
|
||||
|
||||
local tm = torch.Timer()
|
||||
|
|
Loading…
Reference in New Issue