-- Copyright 2015 Carnegie Mellon University
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- 2015-08-09: [Brandon Amos] Initial implementation.
-- 2016-01-04: [Bartosz Ludwiczuk] Substantial improvements at
-- https://github.com/melgor/Triplet-Learning
require 'optim'
require 'fbnn'
require 'image'
require 'torchx' -- for torch.find and torch.concat (concatenating a table of tensors)
paths.dofile("OpenFaceOptim.lua")
local sanitize = paths.dofile('sanitize.lua')
local optimMethod = optim.adadelta
local optimState = {} -- Use for other algorithms like SGD
local optimator = OpenFaceOptim(model, optimState)
trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
local batchNumber
local triplet_loss
-- From https://groups.google.com/d/msg/torch7/i8sJYlgQPeA/wiHlPSa5-HYJ
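-- Replaces, in place, every module of class `orig_class_name` found in `net`
-- with the module returned by `replacer(orig_mod)`.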
local function replaceModules(net, orig_class_name, replacer)
   local nodes, container_nodes = net:findModules(orig_class_name)
   for i = 1, #nodes do
      for j = 1, #(container_nodes[i].modules) do
         if container_nodes[i].modules[j] == nodes[i] then
            local orig_mod = container_nodes[i].modules[j]
            container_nodes[i].modules[j] = replacer(orig_mod)
         end
      end
   end
end
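
-- Converts a cudnn-backed network into a CPU-only nn equivalent: clones the
-- network to float storage and swaps each cudnn module for its nn counterpart,
-- so that saved models can be loaded on machines without CUDA or cudnn.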
local function cudnn_to_nn(net)
   local net_nn = net:clone():float()

   replaceModules(net_nn, 'cudnn.SpatialConvolution',
      function(cudnn_mod)
         local nn_mod = nn.SpatialConvolutionMM(
            cudnn_mod.nInputPlane, cudnn_mod.nOutputPlane,
            cudnn_mod.kW, cudnn_mod.kH,
            cudnn_mod.dW, cudnn_mod.dH,
            cudnn_mod.padW, cudnn_mod.padH
         )
         nn_mod.weight:copy(cudnn_mod.weight)
         nn_mod.bias:copy(cudnn_mod.bias)
         return nn_mod
      end
   )
   replaceModules(net_nn, 'cudnn.SpatialAveragePooling',
      function(cudnn_mod)
         return nn.SpatialAveragePooling(
            cudnn_mod.kW, cudnn_mod.kH,
            cudnn_mod.dW, cudnn_mod.dH,
            cudnn_mod.padW, cudnn_mod.padH
         )
      end
   )
   replaceModules(net_nn, 'cudnn.SpatialMaxPooling',
      function(cudnn_mod)
         return nn.SpatialMaxPooling(
            cudnn_mod.kW, cudnn_mod.kH,
            cudnn_mod.dW, cudnn_mod.dH,
            cudnn_mod.padW, cudnn_mod.padH
         )
      end
   )
   replaceModules(net_nn, 'cudnn.ReLU', function() return nn.ReLU() end)
   replaceModules(net_nn, 'cudnn.SpatialCrossMapLRN',
      function(cudnn_mod)
         return nn.SpatialCrossMapLRN(cudnn_mod.size, cudnn_mod.alpha,
                                      cudnn_mod.beta, cudnn_mod.k)
      end
   )
   return net_nn
end
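
-- A minimal usage sketch (hypothetical path; assumes OpenFace's default
-- 96x96 RGB input) for loading a converted model on a CUDA-free machine:
--   local m = torch.load('work/model_1.t7')
--   m:evaluate()
--   local emb = m:forward(torch.FloatTensor(1, 3, 96, 96):zero())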
function train()
   print('==> doing epoch on training data:')
   print("==> online epoch # " .. epoch)
   batchNumber = 0
   cutorch.synchronize()

   -- set the dropouts to training mode
   model:training()
   model:cuda() -- get it back on the right GPUs.

   local tm = torch.Timer()
   triplet_loss = 0

   local i = 1
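   -- Each job samples one batch of people in a data-worker thread; trainBatch
   -- then consumes it on the main thread. Synchronizing every fifth job keeps
   -- the queue of outstanding batches bounded.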
   while batchNumber < opt.epochSize do
      -- queue jobs to data-workers
      donkeys:addjob(
         -- the job callback (runs in data-worker thread)
         function()
            local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch,
                                                                 opt.imagesPerPerson)
            inputs = inputs:float()
            numPerClass = numPerClass:float()
            return sendTensor(inputs), sendTensor(numPerClass)
         end,
         -- the end callback (runs in the main thread)
         trainBatch
      )
      if i % 5 == 0 then
         donkeys:synchronize()
      end
      i = i + 1
   end

   donkeys:synchronize()
   cutorch.synchronize()

   triplet_loss = triplet_loss / batchNumber
   trainLogger:add{
      ['avg triplet loss (train set)'] = triplet_loss,
   }
   print(string.format('Epoch: [%d][TRAINING SUMMARY] Total Time(s): %.2f\t'
                          .. 'average triplet loss (per batch): %.2f',
                       epoch, tm:time().real, triplet_loss))
   print('\n')

   collectgarbage()
   sanitize(model)
   local nnModel = cudnn_to_nn(model):float()
   torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), nnModel)
   torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
   collectgarbage()
end -- of train()

local inputsCPU = torch.FloatTensor()
local numPerClass = torch.FloatTensor()

local timer = torch.Timer()
function trainBatch(inputsThread, numPerClassThread)
   if batchNumber >= opt.epochSize then
      return
   end
   cutorch.synchronize()
   timer:reset()

   receiveTensor(inputsThread, inputsCPU)
   receiveTensor(numPerClassThread, numPerClass)

   local numImages = inputsCPU:size(1)
   local embeddings = model:forward(inputsCPU:cuda()):float()

   local as_table = {}
   local ps_table = {}
   local ns_table = {}

   local triplet_idx = {}
   local num_example_per_idx = torch.Tensor(embeddings:size(1))
   num_example_per_idx:zero()

   local tripIdx = 1
   local embStartIdx = 1
   local numTrips = 0
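   -- Semi-hard triplet mining: for every anchor-positive pair within a class,
   -- sample a random negative from another class whose squared distance to
   -- the anchor is within opt.alpha of the squared anchor-positive distance.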
   for i = 1, opt.peoplePerBatch do
      local n = numPerClass[i]
      for j = 1, n - 1 do -- for every image in the batch
         local aIdx = embStartIdx + j - 1
         local diff = embeddings - embeddings[{ {aIdx} }]:expandAs(embeddings)
         local norms = diff:norm(2, 2):pow(2):squeeze() -- squared L2 distances to the anchor
         for pair = j, n - 1 do -- create all possible positive pairs
            local pIdx = embStartIdx + pair
            -- Select a semi-hard negative that is further from the anchor
            -- than the positive exemplar (the Oxford-Face idea):
            -- choose a random example that lies inside the margin.
            local fff = (embeddings[aIdx] - embeddings[pIdx]):norm(2)
            local normsP = norms - torch.Tensor(embeddings:size(1)):fill(fff * fff) -- squared, to match norms
            -- exclude indices of the same class by setting them to the max value
            normsP[{ {embStartIdx, embStartIdx + n - 1} }] = normsP:max()
            -- get the indices of examples that are inside the margin
            local in_margin = normsP:lt(opt.alpha)
            local allNeg = torch.find(in_margin, 1)
            -- Use only non-random triplets: random triplets (which are beyond
            -- the margin) just produce a gradient of 0, which would dilute
            -- the average gradient.
            if #allNeg ~= 0 then
               local selNegIdx = allNeg[math.random(#allNeg)]
               -- get the embedding of each example
               table.insert(as_table, embeddings[aIdx])
               table.insert(ps_table, embeddings[pIdx])
               table.insert(ns_table, embeddings[selNegIdx])
               -- remember the original indices of the triplet
               table.insert(triplet_idx, {aIdx, pIdx, selNegIdx})
               -- count how often each example is used; needed for averaging later
               num_example_per_idx[aIdx] = num_example_per_idx[aIdx] + 1
               num_example_per_idx[pIdx] = num_example_per_idx[pIdx] + 1
               num_example_per_idx[selNegIdx] = num_example_per_idx[selNegIdx] + 1
               tripIdx = tripIdx + 1
            end
            numTrips = numTrips + 1
         end
      end
      embStartIdx = embStartIdx + n
   end
   assert(embStartIdx - 1 == numImages)
   print(('+ (nTrips, nTripsRight) = (%d, %d)'):format(numTrips, #as_table))
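   -- Flatten the tables of 1D embedding vectors into (numTriplets x embSize)
   -- tensors, the layout the triplet criterion expects.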
   local as = torch.concat(as_table):view(#as_table, opt.embSize)
   local ps = torch.concat(ps_table):view(#ps_table, opt.embSize)
   local ns = torch.concat(ns_table):view(#ns_table, opt.embSize)

   local asCuda = torch.CudaTensor()
   local psCuda = torch.CudaTensor()
   local nsCuda = torch.CudaTensor()

   local sz = as:size()
   local inCuda = inputsCPU:cuda()
   asCuda:resize(sz):copy(as)
   psCuda:resize(sz):copy(ps)
   nsCuda:resize(sz):copy(ns)
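   -- Forward/backward through the triplet criterion; triplet_idx maps each
   -- selected triplet back to row indices of the original input batch so the
   -- gradients can be routed to the corresponding examples.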
   local err, _ = optimator:optimizeTriplet(
      optimMethod, inCuda, {asCuda, psCuda, nsCuda}, criterion,
      triplet_idx -- , num_example_per_idx
   )
   -- DataParallelTable's syncParameters
   model:apply(function(m) if m.syncParameters then m:syncParameters() end end)
   cutorch.synchronize()

   batchNumber = batchNumber + 1
   print(('Epoch: [%d][%d/%d]\tTime %.3f\ttripErr %.2e'):format(
         epoch, batchNumber, opt.epochSize, timer:time().real, err))
   timer:reset()
   triplet_loss = triplet_loss + err
end