-- openface/training/model.lua

require 'nn'
require 'cunn'
require 'dpnn'
require 'fbnn'
require 'fbcunn'
require 'optim'
require 'cudnn'
cudnn.benchmark = false
cudnn.fastest = true
cudnn.verbose = false
paths.dofile('torch-TripletEmbedding/TripletEmbedding.lua')
--- Swap every module of class `orig_class_name` inside `net` for a new one.
-- Adapted from https://groups.google.com/d/msg/torch7/i8sJYlgQPeA/wiHlPSa5-HYJ
--
-- `net:findModules` yields the matching modules together with the container
-- that holds each of them. Every slot of that container's `modules` list that
-- still refers to the match is overwritten with `replacer(match)`; a module
-- referenced from several slots gets a separate replacer call per slot.
-- The tree is modified in place and nothing is returned.
-- @param net              root module to search (modified in place)
-- @param orig_class_name  type name handed to `net:findModules`, e.g. 'nn.ReLU'
-- @param replacer         function(old_module) returning the replacement module
function replaceModules(net, orig_class_name, replacer)
   local matches, parents = net:findModules(orig_class_name)
   for idx = 1, #matches do
      local target = matches[idx]
      local children = parents[idx].modules
      for slot = 1, #children do
         if children[slot] == target then
            children[slot] = replacer(children[slot])
         end
      end
   end
end
--- Build a replacer (for replaceModules) that turns an nn pooling module into
-- an instance of `cudnn_class` with the same kernel, stride and padding.
-- nn pooling modules record a `:ceil()` call in `ceil_mode`; it is replayed on
-- the cudnn module so the converted layer rounds its output size the same way
-- (otherwise a ceil-mode pool would silently become floor-mode).
local function cudnnPoolingReplacer(cudnn_class)
   return function(nn_mod)
      local cudnn_mod = cudnn_class(
         nn_mod.kW, nn_mod.kH,
         nn_mod.dW, nn_mod.dH,
         nn_mod.padW, nn_mod.padH
      )
      if nn_mod.ceil_mode then
         cudnn_mod:ceil()
      end
      return cudnn_mod
   end
end

--- Return a cuDNN-backed copy of `net`; `net` itself is not modified.
-- The network is cloned and cast to float, then each supported module type is
-- replaced in place by its cudnn counterpart:
--   nn.SpatialConvolutionMM  -> cudnn.SpatialConvolution (weight/bias copied)
--   nn.SpatialAveragePooling -> cudnn.SpatialAveragePooling
--   nn.SpatialMaxPooling     -> cudnn.SpatialMaxPooling
--   nn.ReLU                  -> cudnn.ReLU (not in-place)
--   nn.SpatialCrossMapLRN    -> cudnn.SpatialCrossMapLRN
-- Modules of any other type are left untouched.
-- @param net  nn network to convert
-- @return the converted clone, still float-typed (move it to the GPU afterwards)
function nn_to_cudnn(net)
   local net_cudnn = net:clone():float()

   replaceModules(net_cudnn, 'nn.SpatialConvolutionMM', function(nn_mod)
      local cudnn_mod = cudnn.SpatialConvolution(
         nn_mod.nInputPlane, nn_mod.nOutputPlane,
         nn_mod.kW, nn_mod.kH,
         nn_mod.dW, nn_mod.dH,
         nn_mod.padW, nn_mod.padH
      )
      -- NOTE(review): weights are transferred with copy(), which relies on the
      -- two modules holding the same number of elements — confirm if either
      -- module's weight layout changes.
      cudnn_mod.weight:copy(nn_mod.weight)
      if nn_mod.bias then
         cudnn_mod.bias:copy(nn_mod.bias)
      else
         -- Bias-less nn convolution (e.g. after :noBias()): mirror it instead
         -- of crashing on copy(nil).
         cudnn_mod:noBias()
      end
      return cudnn_mod
   end)

   replaceModules(net_cudnn, 'nn.SpatialAveragePooling',
                  cudnnPoolingReplacer(cudnn.SpatialAveragePooling))

   replaceModules(net_cudnn, 'nn.SpatialMaxPooling',
                  cudnnPoolingReplacer(cudnn.SpatialMaxPooling))

   replaceModules(net_cudnn, 'nn.ReLU', function() return cudnn.ReLU() end)

   replaceModules(net_cudnn, 'nn.SpatialCrossMapLRN', function(nn_mod)
      return cudnn.SpatialCrossMapLRN(nn_mod.size, nn_mod.alpha,
                                      nn_mod.beta, nn_mod.k)
   end)

   return net_cudnn
end
-- Build or restore the network to train.
-- When opt.retrain is a path (anything other than 'none') the serialized model
-- is loaded with torch.load; otherwise opt.modelDef is executed and its
-- createModel() builds a fresh network.
-- NOTE(review): `model` and `criterion` are globals, presumably read by other
-- training scripts — keep them global.
if opt.retrain ~= 'none' then
   assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
   print('Loading model from file: ' .. opt.retrain)
   model = torch.load(opt.retrain)
else
   paths.dofile(opt.modelDef)
   model = createModel()
end

-- Optionally swap the supported nn layers for their cuDNN equivalents.
if opt.cudnn then
   model = nn_to_cudnn(model)
end

-- opt.alpha is the criterion's constructor argument (presumably the triplet
-- margin — confirm against TripletEmbedding.lua).
criterion = nn.TripletEmbeddingCriterion(opt.alpha)

-- Move both the network and the loss to the GPU.
model = model:cuda()
criterion:cuda()

print('=> Model')
print(model)
print('=> Criterion')
print(criterion)

collectgarbage()