2015-09-24 07:49:45 +08:00
|
|
|
require 'nn'
|
2015-12-21 11:23:15 +08:00
|
|
|
|
2015-09-24 07:49:45 +08:00
|
|
|
require 'dpnn'
|
2015-12-21 11:23:15 +08:00
|
|
|
|
2015-09-24 07:49:45 +08:00
|
|
|
require 'optim'
|
|
|
|
|
2016-01-12 04:36:58 +08:00
|
|
|
-- GPU back ends are optional dependencies: each one is loaded only when the
-- corresponding option is switched on, so a CPU-only run never requires them.
if opt.cuda then require 'cunn' end

-- cudnn is only considered when CUDA itself is enabled.
if opt.cuda and opt.cudnn then
   require 'cudnn'
   -- Package-wide cudnn switches: `benchmark` follows the user's option,
   -- the other two are fixed for every run.
   cudnn.verbose = false
   cudnn.fastest = true
   cudnn.benchmark = opt.cudnn_bench
end
|
2016-01-07 05:04:10 +08:00
|
|
|
|
2015-09-24 07:49:45 +08:00
|
|
|
paths.dofile('torch-TripletEmbedding/TripletEmbedding.lua')
|
|
|
|
|
|
|
|
|
2016-06-29 17:48:03 +08:00
|
|
|
local M = {}
|
2015-09-24 07:49:45 +08:00
|
|
|
|
2016-06-29 17:48:03 +08:00
|
|
|
--- Prepare the network and its training criterion for the configured device.
--
-- The network is taken from the first source that applies:
--   1. `continue`, when supplied, is used as the network directly;
--   2. otherwise, when `opt.retrain` is not 'none', it is loaded from that
--      file with `torch.load`;
--   3. otherwise the script at `opt.modelDef` is executed and the
--      `createModel()` it is expected to define is called.
--
-- NOTE: `model` and `criterion` are assigned without `local` here, so unless
-- they are declared elsewhere they are written as globals in addition to
-- being returned. Code outside this file may depend on that.
--
-- @param continue (optional) an already-built network to reuse
-- @return model      the network, converted for the selected device and
--                    wrapped by `makeDataParallel` when `opt.nGPU > 1` on CUDA
-- @return criterion  an `nn.TripletEmbeddingCriterion` built from `opt.alpha`
function M.modelSetup(continue)
   if continue then
      model = continue
   elseif opt.retrain == 'none' then
      -- Fresh network: the definition script must publish `imgDim` and
      -- `createModel` as globals.
      paths.dofile(opt.modelDef)
      assert(imgDim, "Model definition must set global variable 'imgDim'")
      assert(imgDim == opt.imgDim, "Model definiton's imgDim must match imgDim option.")
      model = createModel()
   else
      -- Resume from a serialized network on disk.
      assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
      print('Loading model from file: ' .. opt.retrain)
      model = torch.load(opt.retrain)
      print("Using imgDim = ", opt.imgDim)
   end

   -- Strip a DataParallelTable wrapper, if present, so that the conversions
   -- below operate on the wrapped network itself; the wrapper is re-created
   -- at the end when multiple GPUs are requested.
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end

   criterion = nn.TripletEmbeddingCriterion(opt.alpha)

   -- Move both the network and the criterion to the selected tensor type.
   if not opt.cuda then
      model:float()
      criterion:float()
   else
      model = model:cuda()
      if opt.cudnn then
         -- Convert the network's modules to the cudnn back end.
         cudnn.convert(model, cudnn)
      end
      criterion:cuda()
   end

   -- `optimizeNet` is defined outside this file; it receives the converted
   -- network together with the expected input image dimension.
   optimizeNet(model, opt.imgDim)

   -- Multi-GPU wrapping only applies to CUDA runs.
   if opt.cuda then
      if opt.nGPU > 1 then
         model = makeDataParallel(model, opt.nGPU)
      end
   end

   collectgarbage()
   return model, criterion
end
|
2015-09-24 07:49:45 +08:00
|
|
|
|
2016-06-29 17:48:03 +08:00
|
|
|
return M
|
2015-09-24 07:49:45 +08:00
|
|
|
|
|
|
|
|