Replace nn_to_cudnn by cudnn.convert #106

This commit is contained in:
melgor 2016-03-11 11:09:23 +01:00
parent 3a66bc2345
commit f63acb637e
2 changed files with 9 additions and 126 deletions

View File

@ -16,65 +16,6 @@ end
paths.dofile('torch-TripletEmbedding/TripletEmbedding.lua')
-- From https://groups.google.com/d/msg/torch7/i8sJYlgQPeA/wiHlPSa5-HYJ
function replaceModules(net, orig_class_name, replacer)
local nodes, container_nodes = net:findModules(orig_class_name)
for i = 1, #nodes do
for j = 1, #(container_nodes[i].modules) do
if container_nodes[i].modules[j] == nodes[i] then
local orig_mod = container_nodes[i].modules[j]
container_nodes[i].modules[j] = replacer(orig_mod)
end
end
end
end
function nn_to_cudnn(net)
local net_cudnn = net:clone():float()
replaceModules(net_cudnn, 'nn.SpatialConvolutionMM',
function(nn_mod)
local cudnn_mod = cudnn.SpatialConvolution(
nn_mod.nInputPlane, nn_mod.nOutputPlane,
nn_mod.kW, nn_mod.kH,
nn_mod.dW, nn_mod.dH,
nn_mod.padW, nn_mod.padH
)
cudnn_mod.weight:copy(nn_mod.weight)
cudnn_mod.bias:copy(nn_mod.bias)
return cudnn_mod
end
)
replaceModules(net_cudnn, 'nn.SpatialAveragePooling',
function(nn_mod)
return cudnn.SpatialAveragePooling(
nn_mod.kW, nn_mod.kH,
nn_mod.dW, nn_mod.dH,
nn_mod.padW, nn_mod.padH
)
end
)
replaceModules(net_cudnn, 'nn.SpatialMaxPooling',
function(nn_mod)
return cudnn.SpatialMaxPooling(
nn_mod.kW, nn_mod.kH,
nn_mod.dW, nn_mod.dH,
nn_mod.padW, nn_mod.padH
)
end
)
replaceModules(net_cudnn, 'nn.ReLU', function() return cudnn.ReLU() end)
replaceModules(net_cudnn, 'nn.SpatialCrossMapLRN',
function(nn_mod)
return cudnn.SpatialCrossMapLRN(nn_mod.size, nn_mod.alpha,
nn_mod.beta, nn_mod.k)
end
)
return net_cudnn
end
if opt.retrain ~= 'none' then
assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
print('Loading model from file: ' .. opt.retrain);
@ -86,13 +27,14 @@ else
model = createModel()
end
if opt.cudnn then
model = nn_to_cudnn(model)
end
criterion = nn.TripletEmbeddingCriterion(opt.alpha)
if opt.cuda then
model = model:cuda()
if opt.cudnn then
cudnn.convert(model,cudnn)
end
criterion:cuda()
end

View File

@ -33,70 +33,9 @@ trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
local batchNumber
local triplet_loss
-- From https://groups.google.com/d/msg/torch7/i8sJYlgQPeA/wiHlPSa5-HYJ
local function replaceModules(net, orig_class_name, replacer)
local nodes, container_nodes = net:findModules(orig_class_name)
for i = 1, #nodes do
for j = 1, #(container_nodes[i].modules) do
if container_nodes[i].modules[j] == nodes[i] then
local orig_mod = container_nodes[i].modules[j]
container_nodes[i].modules[j] = replacer(orig_mod)
end
end
end
end
local function cudnn_to_nn(net)
local net_nn = net:clone():float()
replaceModules(net_nn, 'cudnn.SpatialConvolution',
function(cudnn_mod)
local nn_mod = nn.SpatialConvolutionMM(
cudnn_mod.nInputPlane, cudnn_mod.nOutputPlane,
cudnn_mod.kW, cudnn_mod.kH,
cudnn_mod.dW, cudnn_mod.dH,
cudnn_mod.padW, cudnn_mod.padH
)
nn_mod.weight:copy(cudnn_mod.weight)
nn_mod.bias:copy(cudnn_mod.bias)
return nn_mod
end
)
replaceModules(net_nn, 'cudnn.SpatialAveragePooling',
function(cudnn_mod)
return nn.SpatialAveragePooling(
cudnn_mod.kW, cudnn_mod.kH,
cudnn_mod.dW, cudnn_mod.dH,
cudnn_mod.padW, cudnn_mod.padH
)
end
)
replaceModules(net_nn, 'cudnn.SpatialMaxPooling',
function(cudnn_mod)
return nn.SpatialMaxPooling(
cudnn_mod.kW, cudnn_mod.kH,
cudnn_mod.dW, cudnn_mod.dH,
cudnn_mod.padW, cudnn_mod.padH
)
end
)
replaceModules(net_nn, 'cudnn.ReLU', function() return nn.ReLU() end)
replaceModules(net_nn, 'cudnn.SpatialCrossMapLRN',
function(cudnn_mod)
return nn.SpatialCrossMapLRN(cudnn_mod.size, cudnn_mod.alpha,
cudnn_mod.beta, cudnn_mod.K)
end
)
return net_nn
end
function train()
print('==> doing epoch on training data:')
print("==> online epoch # " .. epoch)
batchNumber = 0
if opt.cuda then
cutorch.synchronize()
@ -106,7 +45,6 @@ function train()
if opt.cuda then
model:cuda()
end
local tm = torch.Timer()
triplet_loss = 0
@ -147,8 +85,11 @@ function train()
print('\n')
collectgarbage()
local nnModel = cudnn_to_nn(sanitize(model:float()))
local nnModel = sanitize(model:float():clone())
if opt.cudnn then
cudnn.convert(nnModel,nn)
end
torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), nnModel)
torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
collectgarbage()