-- Copyright 2015 Carnegie Mellon University
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

require 'optim'
require 'fbnn'
require 'image'

paths.dofile("OpenFaceOptim.lua")
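
-- Globals such as `opt`, `model`, `criterion`, `donkeys`, `trainLoader`,
-- and `epoch` are assumed to be set up by the surrounding training harness
-- before this file runs.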

local optimMethod = optim.adadelta
local optimState = {} -- Use for other algorithms like SGD
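-- optim.adadelta adapts per-parameter step sizes from running gradient
-- statistics, so an empty optimState works here; optimizers like SGD need
-- fields such as learningRate set before use.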
local optimator = OpenFaceOptim(model, optimState)

trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
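
-- Per-epoch counters shared between train() and the trainBatch callback.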
local batchNumber
local triplet_loss

function train()
   print('==> doing epoch on training data:')
   print("==> online epoch # " .. epoch)

   batchNumber = 0
   cutorch.synchronize()

   -- set the dropouts to training mode
   model:training()
   model:cuda() -- get it back on the right GPUs.

   local tm = torch.Timer()
   triplet_loss = 0

   local i = 1
   while batchNumber < opt.epochSize do
      -- queue jobs to data-workers
      donkeys:addjob(
         -- the job callback (runs in data-worker thread)
         function()
            local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch,
                                                                 opt.imagesPerPerson)
            inputs = inputs:float()
            numPerClass = numPerClass:float()
            return sendTensor(inputs), sendTensor(numPerClass)
         end,
         -- the end callback (runs in the main thread)
         trainBatch
      )
      if i % 5 == 0 then
         donkeys:synchronize()
      end
      i = i + 1
   end
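
   -- The periodic synchronize above bounds how many jobs can pile up:
   -- batchNumber only advances in the main-thread trainBatch callback, so
   -- without it this loop would queue far more batches than the epoch needs.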

   donkeys:synchronize()
   cutorch.synchronize()

   triplet_loss = triplet_loss / batchNumber

   trainLogger:add{
      ['avg triplet loss (train set)'] = triplet_loss,
   }
   print(string.format('Epoch: [%d][TRAINING SUMMARY] Total Time(s): %.2f\t'
                          .. 'average triplet loss (per batch): %.2f',
                       epoch, tm:time().real, triplet_loss))
   print('\n')

   collectgarbage()
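
   -- Strip CUDA buffers and raw cdata from the network before serializing;
   -- otherwise torch.save would drag large GPU tensors into the checkpoint
   -- or fail outright on cdata fields.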
   local function sanitize(net)
      net:apply(function (val)
         for name,field in pairs(val) do
            if torch.type(field) == 'cdata' then val[name] = nil end
            if name == 'homeGradBuffers' then val[name] = nil end
            if name == 'input_gpu' then val['input_gpu'] = {} end
            if name == 'gradOutput_gpu' then val['gradOutput_gpu'] = {} end
            if name == 'gradInput_gpu' then val['gradInput_gpu'] = {} end
            if (name == 'output' or name == 'gradInput')
                  and torch.type(field) == 'torch.CudaTensor' then
               cutorch.withDevice(field:getDevice(), function() val[name] = field.new() end)
            end
         end
      end)
   end
   sanitize(model)
   torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'),
              model.modules[1]:float())
   torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
   collectgarbage()
end -- of train()
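
-- Reusable CPU-side buffers for the tensors handed back by the donkey
-- threads through sendTensor/receiveTensor.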
local inputsCPU = torch.FloatTensor()
local numPerClass = torch.FloatTensor()

local timer = torch.Timer()
function trainBatch(inputsThread, numPerClassThread)
   if batchNumber >= opt.epochSize then
      return
   end

   cutorch.synchronize()
   timer:reset()
   receiveTensor(inputsThread, inputsCPU)
   receiveTensor(numPerClassThread, numPerClass)

   -- inputs:resize(inputsCPU:size()):copy(inputsCPU)
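
   -- Embed every sampled image in chunks of opt.batchSize through the
   -- underlying embedding network (model.modules[1], the same module that
   -- train() checkpoints), so the whole sample never has to sit on the GPU
   -- at once. The 128 is the network's embedding dimensionality.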
   local numImages = inputsCPU:size(1)
   local embeddings = torch.Tensor(numImages, 128)
   local singleNet = model.modules[1]
   local beginIdx = 1
   local inputs = torch.CudaTensor()
   while beginIdx <= numImages do
      local endIdx = math.min(beginIdx+opt.batchSize-1, numImages)
      local range = {{beginIdx,endIdx}}
      local sz = inputsCPU[range]:size()
      inputs:resize(sz):copy(inputsCPU[range])
      local reps = singleNet:forward(inputs):float()
      embeddings[range] = reps

      beginIdx = endIdx + 1
   end
   assert(beginIdx - 1 == numImages)
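
   -- Each person with n images yields n-1 (anchor, positive) pairs, so the
   -- triplet count is sum(n_i - 1) = numImages - opt.peoplePerBatch.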
   local numTrips = numImages - opt.peoplePerBatch
   local as = torch.Tensor(numTrips, inputs:size(2),
                           inputs:size(3), inputs:size(4))
   local ps = torch.Tensor(numTrips, inputs:size(2),
                           inputs:size(3), inputs:size(4))
   local ns = torch.Tensor(numTrips, inputs:size(2),
                           inputs:size(3), inputs:size(4))

   -- Squared Euclidean distance between two embeddings. Declared local so
   -- a global isn't recreated on every call to trainBatch.
   local function dist(emb1, emb2)
      local d = emb1 - emb2
      return d:cmul(d):sum()
   end

   local tripIdx = 1
   local shuffle = torch.randperm(numTrips)
   local embStartIdx = 1
   for i = 1,opt.peoplePerBatch do
      local n = numPerClass[i]
      for j = 1,n-1 do
         local aIdx = embStartIdx
         local pIdx = embStartIdx+j
         as[shuffle[tripIdx]] = inputsCPU[aIdx]
         ps[shuffle[tripIdx]] = inputsCPU[pIdx]

         -- Select a semi-hard negative that lies further from the anchor
         -- than the positive exemplar does.
         local posDist = dist(embeddings[aIdx], embeddings[pIdx])

         -- Start from a random image outside this person's block of
         -- embeddings; the loop re-rolls until the index falls outside
         -- [embStartIdx, embStartIdx+n-1].
         local selNegIdx = embStartIdx
         while selNegIdx >= embStartIdx and selNegIdx <= embStartIdx+n-1 do
            selNegIdx = (torch.random() % numImages) + 1
         end
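
         -- Then scan every other person's image for a better candidate:
         -- the closest negative that is still further from the anchor than
         -- the positive and within the alpha margin ("semi-hard" in
         -- FaceNet terms).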
         local selNegDist = dist(embeddings[aIdx], embeddings[selNegIdx])
         for k = 1,numImages do
            if k < embStartIdx or k > embStartIdx+n-1 then
               local negDist = dist(embeddings[aIdx], embeddings[k])
               if posDist < negDist and negDist < selNegDist and
                     math.abs(posDist - negDist) < alpha then
                  selNegDist = negDist
                  selNegIdx = k
               end
            end
         end

         ns[shuffle[tripIdx]] = inputsCPU[selNegIdx]

         tripIdx = tripIdx + 1
      end
      embStartIdx = embStartIdx + n
   end
   assert(embStartIdx - 1 == numImages)
   assert(tripIdx - 1 == numTrips)

   local beginIdx = 1
   local asCuda = torch.CudaTensor()
   local psCuda = torch.CudaTensor()
   local nsCuda = torch.CudaTensor()

   -- Return early if the loss is 0 for `numZeros` consecutive iterations.
   local numZeros = 4
   local zeroCounts = torch.IntTensor(numZeros):zero()
   local zeroIdx = 1

   -- Return early if the loss shrinks too much.
   -- local firstLoss = nil

   -- TODO: Should be <=, but batches with just one image cause errors.
   while beginIdx < numTrips do
      local endIdx = math.min(beginIdx+opt.batchSize, numTrips)

      local range = {{beginIdx,endIdx}}
      local sz = as[range]:size()
      asCuda:resize(sz):copy(as[range])
      psCuda:resize(sz):copy(ps[range])
      nsCuda:resize(sz):copy(ns[range])
      local err, outputs = optimator:optimizeTriplet(optimMethod,
                                                     {asCuda, psCuda, nsCuda},
                                                     criterion)
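
      -- optimizeTriplet, provided by OpenFaceOptim.lua, is expected to run
      -- the forward/backward pass for the three batches through the shared
      -- network and apply one optimMethod step, returning this chunk's
      -- triplet loss.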

      cutorch.synchronize()
      batchNumber = batchNumber + 1
      print(('Epoch: [%d][%d/%d]\tTime %.3f\ttripErr %.2e'):format(
            epoch, batchNumber, opt.epochSize, timer:time().real, err))
      timer:reset()
      triplet_loss = triplet_loss + err

      -- Return early if the epoch is over.
      if batchNumber >= opt.epochSize then
         return
      end

      -- Return early if the loss is 0 for `numZeros` consecutive iterations.
      zeroCounts[zeroIdx] = (err == 0.0) and 1 or 0 -- Boolean to int.
      zeroIdx = (zeroIdx % numZeros) + 1
      if zeroCounts:sum() == numZeros then
         return
      end
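
      -- (zeroCounts acts as a ring buffer over the last numZeros losses,
      -- so its sum reaches numZeros exactly when all of them were zero.)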

      -- Return early if the loss shrinks too much.
      -- if firstLoss == nil then
      --    firstLoss = err
      -- else
      --    -- Triplets trivially satisfied if err=0
      --    if err ~= 0 and firstLoss/err > 4 then
      --       return
      --    end
      -- end

      beginIdx = endIdx + 1
   end
   assert(beginIdx - 1 == numTrips or beginIdx == numTrips)
end