Replace nn_to_cudnn by cudnn.convert #106
This commit is contained in:
parent
3a66bc2345
commit
f63acb637e
|
@ -16,65 +16,6 @@ end
|
|||
|
||||
paths.dofile('torch-TripletEmbedding/TripletEmbedding.lua')
|
||||
|
||||
-- From https://groups.google.com/d/msg/torch7/i8sJYlgQPeA/wiHlPSa5-HYJ
|
||||
function replaceModules(net, orig_class_name, replacer)
|
||||
local nodes, container_nodes = net:findModules(orig_class_name)
|
||||
for i = 1, #nodes do
|
||||
for j = 1, #(container_nodes[i].modules) do
|
||||
if container_nodes[i].modules[j] == nodes[i] then
|
||||
local orig_mod = container_nodes[i].modules[j]
|
||||
container_nodes[i].modules[j] = replacer(orig_mod)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function nn_to_cudnn(net)
|
||||
local net_cudnn = net:clone():float()
|
||||
|
||||
replaceModules(net_cudnn, 'nn.SpatialConvolutionMM',
|
||||
function(nn_mod)
|
||||
local cudnn_mod = cudnn.SpatialConvolution(
|
||||
nn_mod.nInputPlane, nn_mod.nOutputPlane,
|
||||
nn_mod.kW, nn_mod.kH,
|
||||
nn_mod.dW, nn_mod.dH,
|
||||
nn_mod.padW, nn_mod.padH
|
||||
)
|
||||
cudnn_mod.weight:copy(nn_mod.weight)
|
||||
cudnn_mod.bias:copy(nn_mod.bias)
|
||||
return cudnn_mod
|
||||
end
|
||||
)
|
||||
replaceModules(net_cudnn, 'nn.SpatialAveragePooling',
|
||||
function(nn_mod)
|
||||
return cudnn.SpatialAveragePooling(
|
||||
nn_mod.kW, nn_mod.kH,
|
||||
nn_mod.dW, nn_mod.dH,
|
||||
nn_mod.padW, nn_mod.padH
|
||||
)
|
||||
end
|
||||
)
|
||||
replaceModules(net_cudnn, 'nn.SpatialMaxPooling',
|
||||
function(nn_mod)
|
||||
return cudnn.SpatialMaxPooling(
|
||||
nn_mod.kW, nn_mod.kH,
|
||||
nn_mod.dW, nn_mod.dH,
|
||||
nn_mod.padW, nn_mod.padH
|
||||
)
|
||||
end
|
||||
)
|
||||
|
||||
replaceModules(net_cudnn, 'nn.ReLU', function() return cudnn.ReLU() end)
|
||||
replaceModules(net_cudnn, 'nn.SpatialCrossMapLRN',
|
||||
function(nn_mod)
|
||||
return cudnn.SpatialCrossMapLRN(nn_mod.size, nn_mod.alpha,
|
||||
nn_mod.beta, nn_mod.k)
|
||||
end
|
||||
)
|
||||
|
||||
return net_cudnn
|
||||
end
|
||||
|
||||
if opt.retrain ~= 'none' then
|
||||
assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
|
||||
print('Loading model from file: ' .. opt.retrain);
|
||||
|
@ -86,13 +27,14 @@ else
|
|||
model = createModel()
|
||||
end
|
||||
|
||||
if opt.cudnn then
|
||||
model = nn_to_cudnn(model)
|
||||
end
|
||||
|
||||
criterion = nn.TripletEmbeddingCriterion(opt.alpha)
|
||||
|
||||
if opt.cuda then
|
||||
model = model:cuda()
|
||||
if opt.cudnn then
|
||||
cudnn.convert(model,cudnn)
|
||||
end
|
||||
criterion:cuda()
|
||||
end
|
||||
|
||||
|
|
|
@ -33,70 +33,9 @@ trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
|
|||
local batchNumber
|
||||
local triplet_loss
|
||||
|
||||
|
||||
-- From https://groups.google.com/d/msg/torch7/i8sJYlgQPeA/wiHlPSa5-HYJ
|
||||
local function replaceModules(net, orig_class_name, replacer)
|
||||
local nodes, container_nodes = net:findModules(orig_class_name)
|
||||
for i = 1, #nodes do
|
||||
for j = 1, #(container_nodes[i].modules) do
|
||||
if container_nodes[i].modules[j] == nodes[i] then
|
||||
local orig_mod = container_nodes[i].modules[j]
|
||||
container_nodes[i].modules[j] = replacer(orig_mod)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function cudnn_to_nn(net)
|
||||
local net_nn = net:clone():float()
|
||||
|
||||
replaceModules(net_nn, 'cudnn.SpatialConvolution',
|
||||
function(cudnn_mod)
|
||||
local nn_mod = nn.SpatialConvolutionMM(
|
||||
cudnn_mod.nInputPlane, cudnn_mod.nOutputPlane,
|
||||
cudnn_mod.kW, cudnn_mod.kH,
|
||||
cudnn_mod.dW, cudnn_mod.dH,
|
||||
cudnn_mod.padW, cudnn_mod.padH
|
||||
)
|
||||
nn_mod.weight:copy(cudnn_mod.weight)
|
||||
nn_mod.bias:copy(cudnn_mod.bias)
|
||||
return nn_mod
|
||||
end
|
||||
)
|
||||
replaceModules(net_nn, 'cudnn.SpatialAveragePooling',
|
||||
function(cudnn_mod)
|
||||
return nn.SpatialAveragePooling(
|
||||
cudnn_mod.kW, cudnn_mod.kH,
|
||||
cudnn_mod.dW, cudnn_mod.dH,
|
||||
cudnn_mod.padW, cudnn_mod.padH
|
||||
)
|
||||
end
|
||||
)
|
||||
replaceModules(net_nn, 'cudnn.SpatialMaxPooling',
|
||||
function(cudnn_mod)
|
||||
return nn.SpatialMaxPooling(
|
||||
cudnn_mod.kW, cudnn_mod.kH,
|
||||
cudnn_mod.dW, cudnn_mod.dH,
|
||||
cudnn_mod.padW, cudnn_mod.padH
|
||||
)
|
||||
end
|
||||
)
|
||||
|
||||
replaceModules(net_nn, 'cudnn.ReLU', function() return nn.ReLU() end)
|
||||
replaceModules(net_nn, 'cudnn.SpatialCrossMapLRN',
|
||||
function(cudnn_mod)
|
||||
return nn.SpatialCrossMapLRN(cudnn_mod.size, cudnn_mod.alpha,
|
||||
cudnn_mod.beta, cudnn_mod.K)
|
||||
end
|
||||
)
|
||||
|
||||
return net_nn
|
||||
end
|
||||
|
||||
function train()
|
||||
print('==> doing epoch on training data:')
|
||||
print("==> online epoch # " .. epoch)
|
||||
|
||||
batchNumber = 0
|
||||
if opt.cuda then
|
||||
cutorch.synchronize()
|
||||
|
@ -106,7 +45,6 @@ function train()
|
|||
if opt.cuda then
|
||||
model:cuda()
|
||||
end
|
||||
|
||||
local tm = torch.Timer()
|
||||
triplet_loss = 0
|
||||
|
||||
|
@ -147,8 +85,11 @@ function train()
|
|||
print('\n')
|
||||
|
||||
collectgarbage()
|
||||
|
||||
local nnModel = cudnn_to_nn(sanitize(model:float()))
|
||||
|
||||
local nnModel = sanitize(model:float():clone())
|
||||
if opt.cudnn then
|
||||
cudnn.convert(nnModel,nn)
|
||||
end
|
||||
torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), nnModel)
|
||||
torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
|
||||
collectgarbage()
|
||||
|
|
Loading…
Reference in New Issue