Initial commit of VGG model definitions.

This commit is contained in:
Brandon Amos 2016-06-08 12:13:27 -04:00
parent 288fd8a3bd
commit e534fa9c7f
2 changed files with 171 additions and 0 deletions

View File

@ -0,0 +1,79 @@
-- Model: vgg-face.def.lua
-- Description: VGG Face's network:
-- http://www.robots.ox.ac.uk/~vgg/publications/2015/Parkhi15/parkhi15.pdf
--
-- Input size: 3x224x224
-- Number of Parameters from net:getParameters() with embSize=128: 118003648
-- Components: Mostly `nn`
-- Devices: CPU and CUDA
--
-- Brandon Amos <http://bamos.github.io>
-- 2016-06-08
--
-- Copyright 2016 Carnegie Mellon University
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
imgDim = 224
local conv = nn.SpatialConvolutionMM
local relu = nn.ReLU
local mp = nn.SpatialMaxPooling
function createModel()
local net = nn.Sequential()
net:add(conv(3, 64, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(conv(64, 64, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(mp(2,2, 2,2))
net:add(conv(64, 128, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(conv(128, 128, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(mp(2,2, 2,2))
net:add(conv(128, 256, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(conv(256, 256, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(conv(256, 256, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(mp(2,2, 2,2))
net:add(conv(256, 512, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(mp(2,2, 2,2))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(relu(true))
net:add(mp(2,2, 2,2))
-- Validate shape with:
-- net:add(nn.Reshape(25088))
net:add(nn.View(25088))
net:add(nn.Linear(25088, 4096))
net:add(relu(true))
net:add(nn.Linear(4096, opt.embSize))
net:add(nn.Normalize(2))
return net
end

View File

@ -0,0 +1,92 @@
-- Model: vgg-face.small1.def.lua
-- Description: Modified VGG Face network. Smaller and with batch normalization.
-- !! In progress, may change.
-- Input size: 3x96x96
-- Number of Parameters from net:getParameters() with embSize=128: TODO
-- Components: Mostly `nn`
-- Devices: CPU and CUDA
--
-- Brandon Amos <http://bamos.github.io>
-- 2016-06-08
--
-- Copyright 2016 Carnegie Mellon University
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
imgDim = 96
local conv = nn.SpatialConvolutionMM
local sbn = nn.SpatialBatchNormalization
local relu = nn.ReLU
local mp = nn.SpatialMaxPooling
function createModel()
local net = nn.Sequential()
net:add(conv(3, 64, 3,3, 1,1, 1,1))
net:add(sbn(64))
net:add(relu(true))
net:add(conv(64, 64, 3,3, 1,1, 1,1))
net:add(sbn(64))
net:add(relu(true))
net:add(mp(2,2, 2,2))
net:add(conv(64, 128, 3,3, 1,1, 1,1))
net:add(sbn(128))
net:add(relu(true))
net:add(conv(128, 128, 3,3, 1,1, 1,1))
net:add(sbn(128))
net:add(relu(true))
net:add(mp(2,2, 2,2))
net:add(conv(128, 256, 3,3, 1,1, 1,1))
net:add(sbn(256))
net:add(relu(true))
net:add(conv(256, 256, 3,3, 1,1, 1,1))
net:add(sbn(256))
net:add(relu(true))
net:add(conv(256, 256, 3,3, 1,1, 1,1))
net:add(sbn(256))
net:add(relu(true))
net:add(mp(2,2, 2,2))
net:add(conv(256, 512, 3,3, 1,1, 1,1))
net:add(sbn(512))
net:add(relu(true))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(sbn(512))
net:add(relu(true))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(sbn(512))
net:add(relu(true))
net:add(mp(2,2, 2,2))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(sbn(512))
net:add(relu(true))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(sbn(512))
net:add(relu(true))
net:add(conv(512, 512, 3,3, 1,1, 1,1))
net:add(sbn(512))
net:add(relu(true))
net:add(mp(2,2, 2,2))
-- Validate shape with:
net:add(nn.Reshape(4608))
net:add(nn.View(4608))
net:add(nn.Linear(4608, 1024))
net:add(relu(true))
net:add(nn.Linear(1024, opt.embSize))
net:add(nn.Normalize(2))
return net
end