openface/training/model.lua

68 lines
1.4 KiB
Lua
Raw Normal View History

require 'nn'
require 'dpnn'
require 'optim'
2016-01-12 04:36:58 +08:00
-- GPU backends are only required when CUDA is requested, so CPU-only runs
-- never touch cunn/cudnn. cudnn is an optional extra on top of cunn.
if opt.cuda then
   require 'cunn'
end

if opt.cuda and opt.cudnn then
   require 'cudnn'
   -- Global cudnn settings: benchmark mode comes from the options table,
   -- the other two flags are fixed.
   cudnn.benchmark = opt.cudnn_bench
   cudnn.fastest = true
   cudnn.verbose = false
end
paths.dofile('torch-TripletEmbedding/TripletEmbedding.lua')
2016-06-29 17:48:03 +08:00
local M = {}

--- Build or restore the training network and its triplet criterion.
--
-- The network is taken from the first available source:
--   1. `continue`, a model object handed back in by the caller;
--   2. the file at `opt.retrain`, unless that option is the string 'none';
--   3. `createModel()` from the definition script `opt.modelDef`.
--
-- `model` and `criterion` are assigned as globals, not locals, exactly as
-- before. NOTE(review): other training scripts may read these globals;
-- confirm before making them local.
--
-- @param continue optional model object to keep training (may be nil)
-- @return model the network moved to the configured backend; passed through
--   makeDataParallel when cuda is enabled and opt.nGPU > 1
-- @return criterion nn.TripletEmbeddingCriterion constructed from opt.alpha
function M.modelSetup(continue)
   if continue then
      model = continue
   elseif opt.retrain ~= 'none' then
      assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
      print('Loading model from file: ' .. opt.retrain)
      model = torch.load(opt.retrain)
      print("Using imgDim = ", opt.imgDim)
   else
      -- The definition script must leave the globals `imgDim` and
      -- `createModel` behind; `imgDim` is validated before building.
      paths.dofile(opt.modelDef)
      assert(imgDim, "Model definition must set global variable 'imgDim'")
      assert(imgDim == opt.imgDim,
             "Model definition's imgDim must match imgDim option.")
      model = createModel()
   end

   -- `continue` or a loaded checkpoint may already be an nn.DataParallelTable
   -- (presumably what makeDataParallel below produces). Take its first module
   -- so the backend conversion and optimizeNet act on the bare network; it is
   -- wrapped again at the end when several GPUs are configured.
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end

   criterion = nn.TripletEmbeddingCriterion(opt.alpha)

   if opt.cuda then
      model = model:cuda()
      if opt.cudnn then
         -- Convert supported nn modules to their cudnn implementations.
         cudnn.convert(model, cudnn)
      end
      criterion:cuda()
   else
      model:float()
      criterion:float()
   end

   -- optimizeNet is defined outside this file; it is given the network and
   -- the configured input image size.
   optimizeNet(model, opt.imgDim)

   if opt.cuda and opt.nGPU > 1 then
      model = makeDataParallel(model, opt.nGPU)
   end

   collectgarbage()
   return model, criterion
end

return M