Clarify training procedure.
Summary of content from @Strateus and @melgor at https://groups.google.com/forum/#!topic/cmu-openface/XgxfN8Xy9nA
This commit is contained in:
parent
bcf69ec2bb
commit
df89d749d1
|
@ -12,6 +12,75 @@
|
|||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
|
||||
-- This code samples images and trains a triplet network with the
|
||||
-- following steps, which are referenced inline.
|
||||
--
|
||||
-- [Step 1]
|
||||
-- Sample at most opt.peoplePerBatch * opt.imagesPerPerson
|
||||
-- images by choosing random people and images from the
|
||||
-- training set.
|
||||
--
|
||||
-- [Step 2]
|
||||
-- Compute the embeddings of all of these images by doing forward
|
||||
-- passes with the current state of the network.
|
||||
-- This is done offline and the network is not modified.
|
||||
-- Since not all of the images will fit in GPU memory, this is
|
||||
-- split into minibatches.
|
||||
--
|
||||
-- [Step 3]
|
||||
-- Select the semi-hard triplets as described in the FaceNet paper.
|
||||
--
|
||||
-- [Step 4]
|
||||
-- Google is able to do a single forward and backward pass to process
|
||||
-- all the triplets and update the network's parameters at once since
|
||||
-- they use a distributed system.
|
||||
-- With a memory-limited GPU, OpenFace uses smaller mini-batches and
|
||||
-- does many forward and backward passes to iteratively update the
|
||||
-- network's parameters.
|
||||
--
|
||||
--
|
||||
--
|
||||
-- Some other useful references for models with shared weights are:
|
||||
--
|
||||
-- 1. Weinberger, K. Q., & Saul, L. K. (2009).
|
||||
-- Distance metric learning for large margin
|
||||
-- nearest neighbor classification.
|
||||
-- The Journal of Machine Learning Research, 10, 207-244.
|
||||
--
|
||||
-- http://machinelearning.wustl.edu/mlpapers/paper_files/jmlr10_weinberger09a.pdf
|
||||
--
|
||||
--
|
||||
-- Cited by the FaceNet paper as the motivation for
|
||||
-- using the triplet loss.
|
||||
--
|
||||
--
|
||||
-- 2. Chopra, S., Hadsell, R., & LeCun, Y. (2005, June).
|
||||
-- Learning a similarity metric discriminatively, with application
|
||||
-- to face verification.
|
||||
-- In Computer Vision and Pattern Recognition, 2005. CVPR 2005.
|
||||
-- IEEE Computer Society Conference on (Vol. 1, pp. 539-546). IEEE.
|
||||
--
|
||||
-- http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf
|
||||
--
|
||||
--
|
||||
-- The idea is to just look at pairs of images at a time
|
||||
-- rather than triplets, training with two networks
|
||||
-- in parallel with shared weights.
|
||||
--
|
||||
-- 3. Hoffer, E., & Ailon, N. (2014).
|
||||
-- Deep metric learning using Triplet network.
|
||||
-- arXiv preprint arXiv:1412.6622.
|
||||
--
|
||||
-- http://arxiv.org/abs/1412.6622
|
||||
--
|
||||
--
|
||||
-- Not used in OpenFace or FaceNet, but another view of triplet
|
||||
-- networks that provides slightly more details about training using
|
||||
-- three networks with shared weights.
|
||||
-- The code uses Torch and is available on GitHub at
|
||||
-- https://github.com/eladhoffer/TripletNet
|
||||
|
||||
require 'optim'
|
||||
require 'fbnn'
|
||||
require 'image'
|
||||
|
@ -46,15 +115,14 @@ function train()
|
|||
while batchNumber < opt.epochSize do
|
||||
-- queue jobs to data-workers
|
||||
donkeys:addjob(
|
||||
-- the job callback (runs in data-worker thread)
|
||||
function()
|
||||
-- [Step 1]: Sample people/images from the dataset.
|
||||
local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch,
|
||||
opt.imagesPerPerson)
|
||||
inputs = inputs:float()
|
||||
numPerClass = numPerClass:float()
|
||||
return sendTensor(inputs), sendTensor(numPerClass)
|
||||
end,
|
||||
-- the end callback (runs in the main thread)
|
||||
trainBatch
|
||||
)
|
||||
if i % 5 == 0 then
|
||||
|
@ -114,8 +182,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
|||
receiveTensor(inputsThread, inputsCPU)
|
||||
receiveTensor(numPerClassThread, numPerClass)
|
||||
|
||||
-- inputs:resize(inputsCPU:size()):copy(inputsCPU)
|
||||
|
||||
-- [Step 2]: Compute embeddings.
|
||||
local numImages = inputsCPU:size(1)
|
||||
local embeddings = torch.Tensor(numImages, 128)
|
||||
local singleNet = model.modules[1]
|
||||
|
@ -133,6 +200,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
|||
end
|
||||
assert(beginIdx - 1 == numImages)
|
||||
|
||||
-- [Step 3]: Select semi-hard triplets.
|
||||
local numTrips = numImages - opt.peoplePerBatch
|
||||
local as = torch.Tensor(numTrips, inputs:size(2),
|
||||
inputs:size(3), inputs:size(4))
|
||||
|
@ -194,6 +262,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
|||
print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips))
|
||||
|
||||
|
||||
-- [Step 4]: Update network parameters.
|
||||
local beginIdx = 1
|
||||
local asCuda = torch.CudaTensor()
|
||||
local psCuda = torch.CudaTensor()
|
||||
|
|
Loading…
Reference in New Issue