diff --git a/training/train.lua b/training/train.lua index d2a6d58..4df52dc 100644 --- a/training/train.lua +++ b/training/train.lua @@ -12,6 +12,75 @@ -- See the License for the specific language governing permissions and -- limitations under the License. + +-- This code samples images and trains a triplet network with the +-- following steps, which are referenced inline. +-- +-- [Step 1] +-- Sample at most opt.peoplePerBatch * opt.imagesPerPerson +-- images by choosing random people and images from the +-- training set. +-- +-- [Step 2] +-- Compute the embeddings of all of these images by doing forward +-- passes with the current state of a network. +-- This is done offline and the network is not modified. +-- Since not all of the images will fit in GPU memory, this is +-- split into minibatches. +-- +-- [Step 3] +-- Select the semi-hard triplets as described in the FaceNet paper. +-- +-- [Step 4] +-- Google is able to do a single forward and backward pass to process +-- all the triplets and update the network's parameters at once since +-- they use a distributed system. +-- With a memory-limited GPU, OpenFace uses smaller mini-batches and +-- does many forward and backward passes to iteratively update the +-- network's parameters. +-- +-- +-- +-- Some other useful references for models with shared weights are: +-- +-- 1. Weinberger, K. Q., & Saul, L. K. (2009). +-- Distance metric learning for large margin +-- nearest neighbor classification. +-- The Journal of Machine Learning Research, 10, 207-244. +-- +-- http://machinelearning.wustl.edu/mlpapers/paper_files/jmlr10_weinberger09a.pdf +-- +-- +-- Citation from the FaceNet paper on their motivation for +-- using the triplet loss. +-- +-- +-- 2. Chopra, S., Hadsell, R., & LeCun, Y. (2005, June). +-- Learning a similarity metric discriminatively, with application +-- to face verification. +-- In Computer Vision and Pattern Recognition, 2005. CVPR 2005. +-- IEEE Computer Society Conference on (Vol. 1, pp. 539-546). 
IEEE. +-- +-- http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf +-- +-- +-- The idea is to just look at pairs of images at a time +-- rather than triplets, which they train with two networks +-- in parallel with shared weights. +-- +-- 3. Hoffer, E., & Ailon, N. (2014). +-- Deep metric learning using Triplet network. +-- arXiv preprint arXiv:1412.6622. +-- +-- http://arxiv.org/abs/1412.6622 +-- +-- +-- Not used in OpenFace or FaceNet, but another view of triplet +-- networks that provides slightly more details about training using +-- three networks with shared weights. +-- The code uses Torch and is available on GitHub at +-- https://github.com/eladhoffer/TripletNet + require 'optim' require 'fbnn' require 'image' @@ -46,15 +115,14 @@ function train() while batchNumber < opt.epochSize do -- queue jobs to data-workers donkeys:addjob( - -- the job callback (runs in data-worker thread) function() + -- [Step 1]: Sample people/images from the dataset. local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch, opt.imagesPerPerson) inputs = inputs:float() numPerClass = numPerClass:float() return sendTensor(inputs), sendTensor(numPerClass) end, - -- the end callback (runs in the main thread) trainBatch ) if i % 5 == 0 then @@ -114,8 +182,7 @@ function trainBatch(inputsThread, numPerClassThread) receiveTensor(inputsThread, inputsCPU) receiveTensor(numPerClassThread, numPerClass) - -- inputs:resize(inputsCPU:size()):copy(inputsCPU) - + -- [Step 2]: Compute embeddings. local numImages = inputsCPU:size(1) local embeddings = torch.Tensor(numImages, 128) local singleNet = model.modules[1] @@ -133,6 +200,7 @@ function trainBatch(inputsThread, numPerClassThread) end assert(beginIdx - 1 == numImages) + -- [Step 3]: Select semi-hard triplets. 
local numTrips = numImages - opt.peoplePerBatch local as = torch.Tensor(numTrips, inputs:size(2), inputs:size(3), inputs:size(4)) @@ -194,6 +262,7 @@ function trainBatch(inputsThread, numPerClassThread) print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips)) + -- [Step 4]: Update network parameters. local beginIdx = 1 local asCuda = torch.CudaTensor() local psCuda = torch.CudaTensor()