Clarify training procedure.

Summary of content from @Strateus and @melgor at
https://groups.google.com/forum/#!topic/cmu-openface/XgxfN8Xy9nA
This commit is contained in:
Brandon Amos 2015-11-23 16:12:59 -05:00
parent bcf69ec2bb
commit df89d749d1
1 changed files with 73 additions and 4 deletions

View File

@ -12,6 +12,75 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- This code samples images and trains a triplet network with the
-- following steps, which are referenced inline.
--
-- [Step 1]
-- Sample at most opt.peoplePerBatch * opt.imagesPerPerson
-- images by choosing random people and images from the
-- training set.
--
-- [Step 2]
-- Compute the embeddings of all of these images by doing forward
-- passs with the current state of a network.
-- This is done offline and the network is not modified.
-- Since not all of the images will fit in GPU memory, this is
-- split into minibatches.
--
-- [Step 3]
-- Select the semi-hard triplets as described in the FaceNet paper.
--
-- [Step 4]
-- Google is able to do a single forward and backward pass to process
-- all the triplets and update the network's parameters at once since
-- they use a distributed system.
-- With a memory-limited GPU, OpenFace uses smaller mini-batches and
-- does many forward and backward passes to iteratively update the
-- network's parameters.
--
--
--
-- Some other useful references for models with shared weights are:
--
-- 1. Weinberger, K. Q., & Saul, L. K. (2009).
-- Distance metric learning for large margin
-- nearest neighbor classification.
-- The Journal of Machine Learning Research, 10, 207-244.
--
-- http://machinelearning.wustl.edu/mlpapers/paper_files/jmlr10_weinberger09a.pdf
--
--
-- Citation from the FaceNet paper on their motivation for
-- using the triplet loss.
--
--
-- 2. Chopra, S., Hadsell, R., & LeCun, Y. (2005, June).
-- Learning a similarity metric discriminatively, with application
-- to face verification.
-- In Computer Vision and Pattern Recognition, 2005. CVPR 2005.
-- IEEE Computer Society Conference on (Vol. 1, pp. 539-546). IEEE.
--
-- http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf
--
--
-- The idea is to just look at pairs of images at a time
-- rather than triplets, which they train with two networks
-- in parallel with shared weights.
--
-- 3. Hoffer, E., & Ailon, N. (2014).
-- Deep metric learning using Triplet network.
-- arXiv preprint arXiv:1412.6622.
--
-- http://arxiv.org/abs/1412.6622
--
--
-- Not used in OpenFace or FaceNet, but another view of triplet
-- networks that provides slightly more details about training using
-- three networks with shared weights.
-- The code uses Torch and is available on GitHub at
-- https://github.com/eladhoffer/TripletNet
require 'optim'
require 'fbnn'
require 'image'
@ -46,15 +115,14 @@ function train()
while batchNumber < opt.epochSize do
-- queue jobs to data-workers
donkeys:addjob(
-- the job callback (runs in data-worker thread)
function()
-- [Step 1]: Sample people/images from the dataset.
local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch,
opt.imagesPerPerson)
inputs = inputs:float()
numPerClass = numPerClass:float()
return sendTensor(inputs), sendTensor(numPerClass)
end,
-- the end callback (runs in the main thread)
trainBatch
)
if i % 5 == 0 then
@ -114,8 +182,7 @@ function trainBatch(inputsThread, numPerClassThread)
receiveTensor(inputsThread, inputsCPU)
receiveTensor(numPerClassThread, numPerClass)
-- inputs:resize(inputsCPU:size()):copy(inputsCPU)
-- [Step 2]: Compute embeddings.
local numImages = inputsCPU:size(1)
local embeddings = torch.Tensor(numImages, 128)
local singleNet = model.modules[1]
@ -133,6 +200,7 @@ function trainBatch(inputsThread, numPerClassThread)
end
assert(beginIdx - 1 == numImages)
-- [Step 3]: Select semi-hard triplets.
local numTrips = numImages - opt.peoplePerBatch
local as = torch.Tensor(numTrips, inputs:size(2),
inputs:size(3), inputs:size(4))
@ -194,6 +262,7 @@ function trainBatch(inputsThread, numPerClassThread)
print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips))
-- [Step 4]: Upate network parameters.
local beginIdx = 1
local asCuda = torch.CudaTensor()
local psCuda = torch.CudaTensor()