Training: Use e-lab's torch-toolbox sanitize.
This commit is contained in:
parent
6029baf48e
commit
8155840989
|
@ -0,0 +1,95 @@
|
||||||
|
-- From https://github.com/e-lab/torch-toolbox/blob/master/Sanitize/sanitize.lua
|
||||||
|
|
||||||
|
require('torch')
|
||||||
|
require('nn')
|
||||||
|
require('cunn')
|
||||||
|
require('cudnn')
|
||||||
|
|
||||||
|
|
||||||
|
-- common obj name to be freed
|
||||||
|
local common = {'output', 'gradInput'}
|
||||||
|
|
||||||
|
-- temporary buffer name other than output/gradInput
|
||||||
|
local t = {
|
||||||
|
-- convolution
|
||||||
|
['nn.SpatialConvolution'] = {'finput', 'fgradInput'},
|
||||||
|
['nn.SpatialConvolutionMM'] = {'finput', 'fgradInput'},
|
||||||
|
|
||||||
|
-- pooling
|
||||||
|
['nn.SpatialMaxPooling'] = {'indices'},
|
||||||
|
['nn.TemporalMaxPooling'] = {'indices'},
|
||||||
|
['nn.VolumetricMaxPooling'] = {'indices'},
|
||||||
|
['nn.SpatialFractionalMaxPooling'] = {'indices'},
|
||||||
|
|
||||||
|
-- regularizer
|
||||||
|
['nn.BatchNormalization'] = {'buffer', 'buffer2', 'centered', 'normalized'},
|
||||||
|
['nn.SpatialBatchNormalization'] = {'buffer', 'buffer2','centered', 'normalized'},
|
||||||
|
['nn.Dropout'] = {'noise'},
|
||||||
|
['nn.SpatialDropout'] = {'noise'},
|
||||||
|
|
||||||
|
-- transfer
|
||||||
|
['nn.PReLU'] = {'gradWeightBuf', 'gradWeightBuf2'},
|
||||||
|
['nn.LogSigmoid'] = {'buffer'},
|
||||||
|
|
||||||
|
-- etc
|
||||||
|
['nn.Mean'] = {'_gradInput'},
|
||||||
|
['nn.Normalize'] = {'_output', 'norm', 'normp'},
|
||||||
|
['nn.PairwiseDistance'] = {'diff'},
|
||||||
|
['nn.Reshape'] = {'_input', '_gradOutput'},
|
||||||
|
|
||||||
|
-- fbcunn
|
||||||
|
['nn.AbstractParallel'] = {'homeGradBuffers', 'input_gpu', 'gradOutput_gpu', 'gradInput_gpu'},
|
||||||
|
['nn.DataParallel'] = {'homeGradBuffers', 'input_gpu', 'gradOutput_gpu', 'gradInput_gpu'},
|
||||||
|
['nn.ModelParallel'] = {'homeGradBuffers', 'input_gpu', 'gradOutput_gpu', 'gradInput_gpu'},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
local function free_table_or_tensor(val, name, field)
|
||||||
|
if type(val[name]) == 'table' then
|
||||||
|
val[name] = {}
|
||||||
|
elseif type(val[name]) == 'userdata' then
|
||||||
|
val[name] = field.new()
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
local function is_member(name, t)
|
||||||
|
if t == nil then
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
for _, value in pairs(t) do
|
||||||
|
if name == value then
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
-- Taken and modified from Soumith's imagenet-multiGPU.torch code
|
||||||
|
-- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/train.lua
|
||||||
|
local function sanitize(model)
|
||||||
|
local list = model:listModules()
|
||||||
|
for _,val in ipairs(list) do
|
||||||
|
for name,field in pairs(val) do
|
||||||
|
|
||||||
|
-- remove ffi obj
|
||||||
|
if torch.type(field) == 'cdata' then
|
||||||
|
val[name] = nil
|
||||||
|
|
||||||
|
-- remove common obj
|
||||||
|
elseif is_member(name, common) then
|
||||||
|
free_table_or_tensor(val, name, field)
|
||||||
|
|
||||||
|
-- remove specific obj
|
||||||
|
elseif is_member(name, t[val.__typename]) then
|
||||||
|
free_table_or_tensor(val, name, field)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return model
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
return sanitize
|
|
@ -23,6 +23,7 @@ require 'torchx' --for concetration the table of tensors
|
||||||
|
|
||||||
paths.dofile("OpenFaceOptim.lua")
|
paths.dofile("OpenFaceOptim.lua")
|
||||||
|
|
||||||
|
local sanitize = paths.dofile('sanitize.lua')
|
||||||
|
|
||||||
local optimMethod = optim.adadelta
|
local optimMethod = optim.adadelta
|
||||||
local optimState = {} -- Use for other algorithms like SGD
|
local optimState = {} -- Use for other algorithms like SGD
|
||||||
|
@ -34,22 +35,6 @@ local batchNumber
|
||||||
local triplet_loss
|
local triplet_loss
|
||||||
|
|
||||||
|
|
||||||
local function sanitize(net)
|
|
||||||
net:apply(function (val)
|
|
||||||
for name,field in pairs(val) do
|
|
||||||
if torch.type(field) == 'cdata' then val[name] = nil end
|
|
||||||
if name == 'homeGradBuffers' then val[name] = nil end
|
|
||||||
if name == 'input_gpu' then val['input_gpu'] = {} end
|
|
||||||
if name == 'gradOutput_gpu' then val['gradOutput_gpu'] = {} end
|
|
||||||
if name == 'gradInput_gpu' then val['gradInput_gpu'] = {} end
|
|
||||||
if (name == 'output' or name == 'gradInput')
|
|
||||||
and torch.type(field) == 'torch.CudaTensor' then
|
|
||||||
cutorch.withDevice(field:getDevice(), function() val[name] = field.new() end)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
|
|
||||||
-- From https://groups.google.com/d/msg/torch7/i8sJYlgQPeA/wiHlPSa5-HYJ
|
-- From https://groups.google.com/d/msg/torch7/i8sJYlgQPeA/wiHlPSa5-HYJ
|
||||||
local function replaceModules(net, orig_class_name, replacer)
|
local function replaceModules(net, orig_class_name, replacer)
|
||||||
local nodes, container_nodes = net:findModules(orig_class_name)
|
local nodes, container_nodes = net:findModules(orig_class_name)
|
||||||
|
|
Loading…
Reference in New Issue