Add print-network-table

This commit is contained in:
Brandon Amos 2016-02-06 18:56:32 -05:00
parent 3fefe5825c
commit c206fa20c5
1 changed files with 38 additions and 0 deletions

38
util/print-network-table.lua Executable file
View File

@ -0,0 +1,38 @@
#!/usr/bin/env th
require 'torch'
require 'nn'
require 'dpnn'
torch.setdefaulttensortype('torch.FloatTensor')
local cmd = torch.CmdLine()
cmd:text()
cmd:text('Print network table.')
cmd:text()
cmd:text('Options:')
cmd:option('-modelDef', '/home/bamos/repos/openface/models/openface/nn4.small2.def.lua', 'Path to model definition.')
cmd:option('-imgDim', 96, 'Image dimension. nn1=224, nn4=96')
cmd:option('-embSize', 128)
cmd:text()
opt = cmd:parse(arg or {})
paths.dofile(opt.modelDef)
local net = createModel()
local img = torch.randn(1, 3, opt.imgDim, opt.imgDim)
net:forward(img)
-- for i,module in ipairs(net:listModules()) do
for i=1,#net.modules do
local module = net.modules[i]
local out = torch.typename(module) .. ": "
for j, sz in ipairs(torch.totable(module.output:size())) do
-- print(sz)
out = out .. sz .. ', '
end
out = string.sub(out, 1, -3)
print(out)
end