diff --git a/training/main.lua b/training/main.lua index bf75bc4..48129c7 100755 --- a/training/main.lua +++ b/training/main.lua @@ -28,6 +28,7 @@ torch.manualSeed(opt.manualSeed) paths.dofile('data.lua') paths.dofile('model.lua') paths.dofile('train.lua') +paths.dofile('test.lua') paths.dofile('util.lua') if opt.peoplePerBatch > nClasses then @@ -41,5 +42,6 @@ epoch = opt.epochNumber for _=1,opt.nEpochs do train() + test() epoch = epoch + 1 end diff --git a/training/opts.lua b/training/opts.lua index e6bdc62..b06a418 100644 --- a/training/opts.lua +++ b/training/opts.lua @@ -40,6 +40,7 @@ function M.parse(arg) -- GPU memory usage depends on peoplePerBatch and imagesPerPerson. cmd:option('-peoplePerBatch', 15, 'Number of people to sample in each mini-batch.') cmd:option('-imagesPerPerson', 20, 'Number of images to sample per person in each mini-batch.') + cmd:option('-testBatchSize', 800, 'Batch size for testing.') ---------- Model options ---------------------------------- cmd:option('-retrain', 'none', 'provide path to model to retrain with') diff --git a/training/test.lua b/training/test.lua new file mode 100644 index 0000000..a4dc6ad --- /dev/null +++ b/training/test.lua @@ -0,0 +1,64 @@ +-- Copyright 2016 Carnegie Mellon University +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
require 'io'
require 'string'
require 'sys'

-- Directory of aligned LFW images and the external tools used to
-- compute embeddings and evaluate them.
local lfwDir = '../data/lfw/aligned'

local batchRepresentBin = "../batch-represent/main.lua"
local lfwEvalBin = "../evaluation/lfw.py"

-- Appends one LFW accuracy entry per epoch to <opt.save>/test.log.
local testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))

-- Parse the LFW accuracy from the last line of the evaluation output file.
-- The numeric field is assumed to occupy characters 6..11 of the final
-- line (format written by the lfw.py evaluation script) -- NOTE(review):
-- confirm against the script's output format if it changes.
-- Returns the accuracy as a number, or nil if the file is empty or the
-- field does not parse.
local function getLfwAcc(fName)
   local f = assert(io.open(fName, 'r'))
   local lastLine = nil
   for line in f:lines() do
      lastLine = line
   end
   -- Close the handle itself. The original called io.close() with no
   -- argument, which closes the *default output* file and leaks `f`.
   f:close()
   if lastLine == nil then
      return nil
   end
   return tonumber(string.sub(lastLine, 6, 11))
end

-- Evaluate the model saved for the current `epoch` on LFW:
--   1. Run batch-represent to write embeddings for the aligned LFW images.
--   2. Run the Python LFW evaluation script on those embeddings.
--   3. Parse the resulting accuracy and log it to test.log.
-- Relies on globals set up by the training driver: `opt`, `model`, `epoch`.
-- NOTE(review): the shell commands are built from opt.save/epoch by string
-- concatenation; keep these paths free of shell metacharacters.
function test()
   if opt.cuda then
      -- Move the model off the GPU; train() moves it back with model:cuda().
      model = model:float()
   end
   local latestModelFile = paths.concat(opt.save, 'model_' .. epoch .. '.t7')
   local outDir = paths.concat(opt.save, 'lfw-' .. epoch)
   print(latestModelFile)
   print(outDir)

   local cmd = batchRepresentBin
   if opt.cuda then
      cmd = cmd .. ' -cuda '
   end
   cmd = cmd .. ' -batchSize ' .. opt.testBatchSize ..
      ' -model ' .. latestModelFile ..
      ' -data ' .. lfwDir ..
      ' -outDir ' .. outDir
   os.execute(cmd)

   cmd = lfwEvalBin .. ' Epoch' .. epoch .. ' ' .. outDir
   os.execute(cmd)

   -- `local` added: the original leaked lfwAcc into the global table.
   local lfwAcc = getLfwAcc(paths.concat(outDir, "accuracies.txt"))
   testLogger:add{
      ['lfwAcc'] = lfwAcc
   }
end