From c206fa20c5c16a01491332b4b9574076d58b4c17 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Sat, 6 Feb 2016 18:56:32 -0500 Subject: [PATCH] Add print-network-table --- util/print-network-table.lua | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100755 util/print-network-table.lua diff --git a/util/print-network-table.lua b/util/print-network-table.lua new file mode 100755 index 0000000..6f2c7e2 --- /dev/null +++ b/util/print-network-table.lua @@ -0,0 +1,38 @@ +#!/usr/bin/env th + +require 'torch' +require 'nn' +require 'dpnn' + +torch.setdefaulttensortype('torch.FloatTensor') + +local cmd = torch.CmdLine() +cmd:text() +cmd:text('Print network table.') +cmd:text() +cmd:text('Options:') + +cmd:option('-modelDef', '/home/bamos/repos/openface/models/openface/nn4.small2.def.lua', 'Path to model definition.') +cmd:option('-imgDim', 96, 'Image dimension. nn1=224, nn4=96') +cmd:option('-embSize', 128) +cmd:text() + +opt = cmd:parse(arg or {}) + +paths.dofile(opt.modelDef) +local net = createModel() + +local img = torch.randn(1, 3, opt.imgDim, opt.imgDim) +net:forward(img) + +-- for i,module in ipairs(net:listModules()) do +for i=1,#net.modules do + local module = net.modules[i] + local out = torch.typename(module) .. ": " + for j, sz in ipairs(torch.totable(module.output:size())) do + -- print(sz) + out = out .. sz .. ', ' + end + out = string.sub(out, 1, -3) + print(out) +end