-- Source: https://github.com/soumith/imagenet-multiGPU.torch/blob/master/util.lua

local ffi = require 'ffi'

------ Some FFI stuff used to pass storages between threads ------------------
ffi.cdef[[
void THFloatStorage_free(THFloatStorage *self);
void THLongStorage_free(THLongStorage *self);
]]

local function setFloatStorage(tensor, storage_p)
   -- Point `tensor` at the THFloatStorage received from another thread,
   -- freeing the storage it currently owns so it is not leaked.
   assert(storage_p and storage_p ~= 0, "FloatStorage is NULL pointer");
   local cstorage = ffi.cast('THFloatStorage*', torch.pointer(tensor:storage()))
   if cstorage ~= nil then
      ffi.C['THFloatStorage_free'](cstorage)
   end
   local storage = ffi.cast('THFloatStorage*', storage_p)
   tensor:cdata().storage = storage
end

local function setLongStorage(tensor, storage_p)
   -- Point `tensor` at the THLongStorage received from another thread,
   -- freeing the storage it currently owns so it is not leaked.
   assert(storage_p and storage_p ~= 0, "LongStorage is NULL pointer");
   local cstorage = ffi.cast('THLongStorage*', torch.pointer(tensor:storage()))
   if cstorage ~= nil then
      ffi.C['THLongStorage_free'](cstorage)
   end
   local storage = ffi.cast('THLongStorage*', storage_p)
   tensor:cdata().storage = storage
end

function sendTensor(inputs)
   -- Detach the tensor's storage and return {storage pointer, size, type}:
   -- a plain Lua table that can be shipped to another thread without copying the data.
   local size = inputs:size()
   local ttype = inputs:type()
   local i_stg = tonumber(ffi.cast('intptr_t', torch.pointer(inputs:storage())))
   inputs:cdata().storage = nil
   return {i_stg, size, ttype}
end

function receiveTensor(obj, buffer)
   -- Rebuild a tensor from the {pointer, size, type} triple produced by sendTensor,
   -- reusing `buffer` when one is supplied.
   local pointer = obj[1]
   local size = obj[2]
   local ttype = obj[3]
   if buffer then
      buffer:resize(size)
      assert(buffer:type() == ttype, 'Buffer is wrong type')
   else
      -- torch.factory maps the full type name (e.g. 'torch.FloatTensor') to its constructor
      buffer = torch.factory(ttype)():resize(size)
   end
   if ttype == 'torch.FloatTensor' then
      setFloatStorage(buffer, pointer)
   elseif ttype == 'torch.LongTensor' then
      setLongStorage(buffer, pointer)
   else
      error('Unknown type')
   end
   return buffer
end
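
--[[
Hedged usage sketch (not part of the original file): sendTensor/receiveTensor
hand a tensor's storage to another thread without copying the data. The sketch
below assumes the 'threads' package and that this util file is loaded in each
worker so sendTensor is visible there; `pool`, `inputsBuffer`, and the batch
shape are hypothetical.

local threads = require 'threads'
local pool = threads.Threads(2, function()
   require 'torch'
   paths.dofile('util.lua')   -- assumes this file is named util.lua
end)

local inputsBuffer = torch.FloatTensor()
pool:addjob(
   function()
      local batch = torch.FloatTensor(16, 3, 96, 96):uniform()
      return sendTensor(batch)       -- detach the storage and ship its pointer
   end,
   function(obj)
      inputsBuffer = receiveTensor(obj, inputsBuffer)   -- re-attach it on this side
   end
)
pool:synchronize()
--]]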

-- Reduce the model's memory consumption by sharing internal buffers (via 'optnet')
function optimizeNet(model, inputSize)
   local optnet_loaded, optnet = pcall(require, 'optnet')
   if optnet_loaded then
      local opts = {inplace=true, mode='training', removeGradParams=false}
      local input = torch.rand(2, 3, inputSize, inputSize)
      if opt.cuda then
         input = input:cuda()
      end
      optnet.optimizeMemory(model, input, opts)
   else
      print("'optnet' package not found; install it to reduce memory consumption.")
      print("Repo: https://github.com/fmassa/optimize-net")
   end
end
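
--[[
Hedged usage sketch (not part of the original file): optimizeNet reads the
global `opt` table for its `cuda` flag and takes the side length of the square
inputs fed to the network. The tiny model below is a hypothetical stand-in.

require 'nn'
opt = opt or {cuda = false}
local model = nn.Sequential()
   :add(nn.SpatialConvolution(3, 16, 3, 3, 1, 1, 1, 1))
   :add(nn.ReLU())
optimizeNet(model, 96)   -- shares internal buffers in place when 'optnet' is installed
--]]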

function makeDataParallel(model, nGPU)
   -- Wrap the model with DataParallelTable if using more than one GPU
   if nGPU > 1 then
      local gpus = torch.range(1, nGPU):totable()
      local fastest, benchmark = cudnn.fastest, cudnn.benchmark

      local dpt = nn.DataParallelTable(1, true, true)
         :add(model, gpus)
         :threads(function()
            require 'dpnn'
            local cudnn = require 'cudnn'
            cudnn.fastest, cudnn.benchmark = fastest, benchmark
         end)
      dpt.gradInput = nil

      model = dpt:cuda()
   end
   return model
end
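
--[[
Hedged usage sketch (not part of the original file): makeDataParallel assumes
'cunn'/'cudnn' are installed (plus 'dpnn' for the worker threads) and that the
model already lives on the GPU; with nGPU == 1 it is returned unchanged. The
tiny model below is a hypothetical stand-in.

require 'cunn'
require 'cudnn'
local model = nn.Sequential()
   :add(cudnn.SpatialConvolution(3, 16, 3, 3))
   :add(cudnn.ReLU())
   :cuda()
model = makeDataParallel(model, cutorch.getDeviceCount())
--]]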