Clarify training procedure.
Summary of content from @Strateus and @melgor at https://groups.google.com/forum/#!topic/cmu-openface/XgxfN8Xy9nA
This commit is contained in:
parent
bcf69ec2bb
commit
df89d749d1
|
@ -12,6 +12,75 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
-- This code samples images and trains a triplet network with the
|
||||||
|
-- following steps, which are referenced inline.
|
||||||
|
--
|
||||||
|
-- [Step 1]
|
||||||
|
-- Sample at most opt.peoplePerBatch * opt.imagesPerPerson
|
||||||
|
-- images by choosing random people and images from the
|
||||||
|
-- training set.
|
||||||
|
--
|
||||||
|
-- [Step 2]
|
||||||
|
-- Compute the embeddings of all of these images by doing forward
|
||||||
|
-- passes with the current state of the network.
|
||||||
|
-- This is done offline and the network is not modified.
|
||||||
|
-- Since not all of the images will fit in GPU memory, this is
|
||||||
|
-- split into minibatches.
|
||||||
|
--
|
||||||
|
-- [Step 3]
|
||||||
|
-- Select the semi-hard triplets as described in the FaceNet paper.
|
||||||
|
--
|
||||||
|
-- [Step 4]
|
||||||
|
-- Google is able to do a single forward and backward pass to process
|
||||||
|
-- all the triplets and update the network's parameters at once since
|
||||||
|
-- they use a distributed system.
|
||||||
|
-- With a memory-limited GPU, OpenFace uses smaller mini-batches and
|
||||||
|
-- does many forward and backward passes to iteratively update the
|
||||||
|
-- network's parameters.
|
||||||
|
--
|
||||||
|
--
|
||||||
|
--
|
||||||
|
-- Some other useful references for models with shared weights are:
|
||||||
|
--
|
||||||
|
-- 1. Weinberger, K. Q., & Saul, L. K. (2009).
|
||||||
|
-- Distance metric learning for large margin
|
||||||
|
-- nearest neighbor classification.
|
||||||
|
-- The Journal of Machine Learning Research, 10, 207-244.
|
||||||
|
--
|
||||||
|
-- http://machinelearning.wustl.edu/mlpapers/paper_files/jmlr10_weinberger09a.pdf
|
||||||
|
--
|
||||||
|
--
|
||||||
|
-- Citation from the FaceNet paper on their motivation for
|
||||||
|
-- using the triplet loss.
|
||||||
|
--
|
||||||
|
--
|
||||||
|
-- 2. Chopra, S., Hadsell, R., & LeCun, Y. (2005, June).
|
||||||
|
-- Learning a similarity metric discriminatively, with application
|
||||||
|
-- to face verification.
|
||||||
|
-- In Computer Vision and Pattern Recognition, 2005. CVPR 2005.
|
||||||
|
-- IEEE Computer Society Conference on (Vol. 1, pp. 539-546). IEEE.
|
||||||
|
--
|
||||||
|
-- http://yann.lecun.com/exdb/publis/pdf/chopra-05.pdf
|
||||||
|
--
|
||||||
|
--
|
||||||
|
-- The idea is to just look at pairs of images at a time
|
||||||
|
-- rather than triplets, which they train with two networks
|
||||||
|
-- in parallel with shared weights.
|
||||||
|
--
|
||||||
|
-- 3. Hoffer, E., & Ailon, N. (2014).
|
||||||
|
-- Deep metric learning using Triplet network.
|
||||||
|
-- arXiv preprint arXiv:1412.6622.
|
||||||
|
--
|
||||||
|
-- http://arxiv.org/abs/1412.6622
|
||||||
|
--
|
||||||
|
--
|
||||||
|
-- Not used in OpenFace or FaceNet, but another view of triplet
|
||||||
|
-- networks that provides slightly more details about training using
|
||||||
|
-- three networks with shared weights.
|
||||||
|
-- The code uses Torch and is available on GitHub at
|
||||||
|
-- https://github.com/eladhoffer/TripletNet
|
||||||
|
|
||||||
require 'optim'
|
require 'optim'
|
||||||
require 'fbnn'
|
require 'fbnn'
|
||||||
require 'image'
|
require 'image'
|
||||||
|
@ -46,15 +115,14 @@ function train()
|
||||||
while batchNumber < opt.epochSize do
|
while batchNumber < opt.epochSize do
|
||||||
-- queue jobs to data-workers
|
-- queue jobs to data-workers
|
||||||
donkeys:addjob(
|
donkeys:addjob(
|
||||||
-- the job callback (runs in data-worker thread)
|
|
||||||
function()
|
function()
|
||||||
|
-- [Step 1]: Sample people/images from the dataset.
|
||||||
local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch,
|
local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch,
|
||||||
opt.imagesPerPerson)
|
opt.imagesPerPerson)
|
||||||
inputs = inputs:float()
|
inputs = inputs:float()
|
||||||
numPerClass = numPerClass:float()
|
numPerClass = numPerClass:float()
|
||||||
return sendTensor(inputs), sendTensor(numPerClass)
|
return sendTensor(inputs), sendTensor(numPerClass)
|
||||||
end,
|
end,
|
||||||
-- the end callback (runs in the main thread)
|
|
||||||
trainBatch
|
trainBatch
|
||||||
)
|
)
|
||||||
if i % 5 == 0 then
|
if i % 5 == 0 then
|
||||||
|
@ -114,8 +182,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
||||||
receiveTensor(inputsThread, inputsCPU)
|
receiveTensor(inputsThread, inputsCPU)
|
||||||
receiveTensor(numPerClassThread, numPerClass)
|
receiveTensor(numPerClassThread, numPerClass)
|
||||||
|
|
||||||
-- inputs:resize(inputsCPU:size()):copy(inputsCPU)
|
-- [Step 2]: Compute embeddings.
|
||||||
|
|
||||||
local numImages = inputsCPU:size(1)
|
local numImages = inputsCPU:size(1)
|
||||||
local embeddings = torch.Tensor(numImages, 128)
|
local embeddings = torch.Tensor(numImages, 128)
|
||||||
local singleNet = model.modules[1]
|
local singleNet = model.modules[1]
|
||||||
|
@ -133,6 +200,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
||||||
end
|
end
|
||||||
assert(beginIdx - 1 == numImages)
|
assert(beginIdx - 1 == numImages)
|
||||||
|
|
||||||
|
-- [Step 3]: Select semi-hard triplets.
|
||||||
local numTrips = numImages - opt.peoplePerBatch
|
local numTrips = numImages - opt.peoplePerBatch
|
||||||
local as = torch.Tensor(numTrips, inputs:size(2),
|
local as = torch.Tensor(numTrips, inputs:size(2),
|
||||||
inputs:size(3), inputs:size(4))
|
inputs:size(3), inputs:size(4))
|
||||||
|
@ -194,6 +262,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
||||||
print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips))
|
print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips))
|
||||||
|
|
||||||
|
|
||||||
|
-- [Step 4]: Update network parameters.
|
||||||
local beginIdx = 1
|
local beginIdx = 1
|
||||||
local asCuda = torch.CudaTensor()
|
local asCuda = torch.CudaTensor()
|
||||||
local psCuda = torch.CudaTensor()
|
local psCuda = torch.CudaTensor()
|
||||||
|
|
Loading…
Reference in New Issue