2016-01-14 03:52:52 +08:00
|
|
|
-- Copyright 2015-2016 Carnegie Mellon University
|
2015-09-24 07:49:45 +08:00
|
|
|
--
|
|
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
-- you may not use this file except in compliance with the License.
|
|
|
|
-- You may obtain a copy of the License at
|
|
|
|
--
|
|
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
--
|
|
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
-- See the License for the specific language governing permissions and
|
|
|
|
-- limitations under the License.
|
|
|
|
|
2016-01-07 05:04:10 +08:00
|
|
|
-- 2015-08-09: [Brandon Amos] Initial implementation.
|
|
|
|
-- 2016-01-04: [Bartosz Ludwiczuk] Substantial improvements at
|
|
|
|
-- https://github.com/melgor/Triplet-Learning
|
2015-11-24 05:12:59 +08:00
|
|
|
|
2015-09-24 07:49:45 +08:00
|
|
|
require 'optim'
|
|
|
|
require 'image'
|
2016-01-07 05:04:10 +08:00
|
|
|
require 'torchx' -- for concatenating the table of tensors
|
2016-06-29 17:48:03 +08:00
|
|
|
local optnet_loaded, optnet = pcall(require,'optnet')
|
|
|
|
local models = require 'model'
|
|
|
|
local openFaceOptim = require 'OpenFaceOptim'
|
2015-09-24 07:49:45 +08:00
|
|
|
|
|
|
|
|
2016-06-04 03:29:26 +08:00
|
|
|
-- Optimization algorithm handed to optimator:optimizeTriplet for every batch.
local optimMethod = optim.adam
-- State table given to the optimizer wrapper and saved alongside each model.
local optimState = {} -- Use for other algorithms like SGD
-- Optimizer wrapper; (re)created at the start of every train() call.
local optimator

-- Global logger that receives the average triplet loss of each epoch.
trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))

-- Per-epoch counters shared between train() and trainBatch():
-- number of batches actually trained on, and the accumulated triplet error.
local batchNumber, triplet_loss
|
|
|
|
|
|
|
|
--- Run one training epoch.
-- Sets up the model/criterion and the optimizer wrapper, then keeps queueing
-- sampling jobs on the `donkeys` worker pool until `opt.epochSize` batches
-- have been trained by `trainBatch`. Finally logs and prints the average
-- triplet loss per batch.
-- Reads the globals `epoch`, `opt`, `model`, `donkeys`; writes the globals
-- `model` and `criterion` and the file-level epoch counters.
function train()
   print('==> doing epoch on training data:')
   print("==> online epoch # " .. epoch)

   -- Reset the per-epoch state shared with trainBatch().
   batchNumber = 0
   model, criterion = models.modelSetup(model)
   optimator = openFaceOptim:__init(model, optimState)
   if opt.cuda then
      cutorch.synchronize()
   end
   model:training()

   local epochTimer = torch.Timer()
   triplet_loss = 0

   -- batchNumber is advanced by trainBatch() (the end callback), and only for
   -- batches that produced at least one triplet, so more than opt.epochSize
   -- jobs may end up being queued here.
   local jobsQueued = 0
   while batchNumber < opt.epochSize do
      -- queue jobs to data-workers
      donkeys:addjob(
         -- the job callback (runs in data-worker thread)
         function()
            local people, perClass = trainLoader:samplePeople(
               opt.peoplePerBatch, opt.imagesPerPerson)
            people = people:float()
            perClass = perClass:float()
            return sendTensor(people), sendTensor(perClass)
         end,
         -- the end callback (runs in the main thread)
         trainBatch
      )

      -- Drain the queue after every fifth job so end callbacks get to run.
      jobsQueued = jobsQueued + 1
      if jobsQueued % 5 == 0 then
         donkeys:synchronize()
      end
   end

   -- Wait for all outstanding jobs (and GPU work) before summarizing.
   donkeys:synchronize()
   if opt.cuda then
      cutorch.synchronize()
   end

   triplet_loss = triplet_loss / batchNumber

   trainLogger:add{
      ['avg triplet loss (train set)'] = triplet_loss,
   }

   local summaryFormat = 'Epoch: [%d][TRAINING SUMMARY] Total Time(s): %.2f\t'
      .. 'average triplet loss (per batch): %.2f'
   print(string.format(summaryFormat, epoch, epochTimer:time().real, triplet_loss))
   print('\n')

   collectgarbage()
end -- of train()
|
|
|
|
|
2016-06-04 03:29:26 +08:00
|
|
|
|
2016-06-29 17:48:03 +08:00
|
|
|
--- Check the model for NaNs and write it, plus the optimizer state, to disk.
-- Writes `model_<epoch>.t7` and `optimState_<epoch>.t7` into `opt.save`
-- (`epoch` and `opt` are globals). The process exits with status -1 if any
-- nn.SpatialBatchNormalization module has NaNs in its running statistics.
-- @param model the network to save; an nn.DataParallelTable is unwrapped to
--        its first child module before saving.
-- @return the (possibly unwrapped) model after `:float():clearState()`.
function saveModel(model)
   -- Check for nans from https://github.com/cmusatyalab/openface/issues/127
   local function checkNans(x, tag)
      -- NaN is the only value that compares unequal to itself.
      local I = torch.ne(x,x)
      if torch.any(I) then
         print("train.lua: Error: NaNs found in: ", tag)
         os.exit(-1)
         -- x[I] = 0.0
      end
   end

   for j, mod in ipairs(model:listModules()) do
      if torch.typename(mod) == 'nn.SpatialBatchNormalization' then
         -- string.format's %s only accepts strings/numbers on Lua 5.1 and
         -- LuaJIT 2.0, so convert the module object explicitly.
         local modName = tostring(mod)
         checkNans(mod.running_mean, string.format("%d-%s-%s", j, modName, 'running_mean'))
         checkNans(mod.running_var, string.format("%d-%s-%s", j, modName, 'running_var'))
      end
   end

   -- Convert cudnn modules back to plain nn modules before serializing.
   if opt.cuda then
      if opt.cudnn then
         cudnn.convert(model, nn)
      end
   end

   -- Save only the wrapped network, not the DataParallelTable container.
   local dpt
   if torch.type(model) == 'nn.DataParallelTable' then
      dpt = model
      model = model:get(1)
   end

   -- optnet is optional (loaded via pcall at the top of the file).
   if optnet_loaded then
      optnet.removeOptimization(model)
   end

   torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model:float():clearState())
   torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)

   if dpt then -- OOM without this
      dpt:clearState()
   end

   collectgarbage()

   return model
end
|
2015-09-24 07:49:45 +08:00
|
|
|
|
|
|
|
-- Host-side buffers that trainBatch() fills via receiveTensor() on every
-- batch: the sampled images and the number of images per person.
local inputsCPU, numPerClass = torch.FloatTensor(), torch.FloatTensor()

-- Measures the wall-clock time spent on each batch in trainBatch().
local timer = torch.Timer()
|
|
|
|
--- Train on one sampled batch (end callback of the data-worker job).
-- Embeds every image in the batch, then for each (anchor, positive) pair of
-- the same person picks one random negative whose squared distance to the
-- anchor exceeds the anchor-positive squared distance by less than
-- `opt.alpha`. The selected triplets are passed to
-- `optimator:optimizeTriplet`. Batches that yield no triplet are skipped and
-- do not count towards `opt.epochSize`.
-- @param inputsThread serialized image tensor produced by `sendTensor`
-- @param numPerClassThread serialized per-person image counts from `sendTensor`
function trainBatch(inputsThread, numPerClassThread)
   collectgarbage()
   -- Jobs queued beyond the epoch size are dropped.
   if batchNumber >= opt.epochSize then
      return
   end

   if opt.cuda then
      cutorch.synchronize()
   end
   timer:reset()

   receiveTensor(inputsThread, inputsCPU)
   receiveTensor(numPerClassThread, numPerClass)

   local inputs
   if opt.cuda then
      inputs = inputsCPU:cuda()
   else
      inputs = inputsCPU
   end

   local numImages = inputs:size(1)
   local embeddings = model:forward(inputs):float()

   -- Embeddings of the anchors, positives and negatives of each triplet.
   local as_table = {}
   local ps_table = {}
   local ns_table = {}

   -- Batch indices {anchor, positive, negative} of each selected triplet.
   local triplet_idx = {}
   -- How often each image takes part in a triplet. Only consumed by the
   -- commented-out argument of optimizeTriplet below.
   local num_example_per_idx = torch.Tensor(embeddings:size(1))
   num_example_per_idx:zero()

   local embStartIdx = 1 -- Index of the first image of the current person.
   local numTrips = 0    -- Number of (anchor, positive) pairs examined.
   for i = 1,opt.peoplePerBatch do
      local n = numPerClass[i]
      for j = 1,n-1 do -- For every anchor image of this person.
         local aIdx = embStartIdx + j - 1
         local diff = embeddings - embeddings[{ {aIdx} }]:expandAs(embeddings)
         -- Squared L2 distance from the anchor to every embedding.
         local norms = diff:norm(2, 2):pow(2):squeeze()
         for pair = j, n-1 do -- For every possible positive pair.
            local pIdx = embStartIdx + pair

            local fff = (embeddings[aIdx]-embeddings[pIdx]):norm(2)
            -- d(a, x)^2 - d(a, p)^2 for every x in the batch.
            local normsP = norms - torch.Tensor(embeddings:size(1)):fill(fff*fff)

            -- Set the indices of the same class to the max so they are ignored.
            normsP[{{embStartIdx,embStartIdx +n-1}}] = normsP:max()

            -- Get indices of images within the margin.
            local in_margin = normsP:lt(opt.alpha)
            local allNeg = torch.find(in_margin, 1)

            -- Use only non-random triplets.
            -- Random triples (which are beyond the margin) will just produce gradient = 0,
            -- so the average gradient will decrease.
            if #allNeg ~= 0 then
               -- Declared local so it no longer leaks into the global table.
               local selNegIdx = allNeg[math.random(#allNeg)]
               -- Add the embedding of each example.
               as_table[#as_table + 1] = embeddings[aIdx]
               ps_table[#ps_table + 1] = embeddings[pIdx]
               ns_table[#ns_table + 1] = embeddings[selNegIdx]
               -- Add the original index of triplets.
               triplet_idx[#triplet_idx + 1] = {aIdx,pIdx,selNegIdx}
               -- Increase the number of times of using each example.
               num_example_per_idx[aIdx] = num_example_per_idx[aIdx] + 1
               num_example_per_idx[pIdx] = num_example_per_idx[pIdx] + 1
               num_example_per_idx[selNegIdx] = num_example_per_idx[selNegIdx] + 1
            end

            numTrips = numTrips + 1
         end
      end
      embStartIdx = embStartIdx + n
   end
   -- The per-person counts must add up to the number of images in the batch.
   assert(embStartIdx - 1 == numImages)
   local nTripsFound = #as_table
   print(('  + (nTrips, nTripsFound) = (%d, %d)'):format(numTrips, nTripsFound))

   if nTripsFound == 0 then
      print("Warning: nTripsFound == 0. Skipping batch.")
      return
   end

   -- Stack the selected embeddings into nTripsFound x opt.embSize matrices.
   local as = torch.concat(as_table):view(nTripsFound, opt.embSize)
   local ps = torch.concat(ps_table):view(nTripsFound, opt.embSize)
   local ns = torch.concat(ns_table):view(nTripsFound, opt.embSize)

   local apn
   if opt.cuda then
      local asCuda = torch.CudaTensor()
      local psCuda = torch.CudaTensor()
      local nsCuda = torch.CudaTensor()

      local sz = as:size()
      asCuda:resize(sz):copy(as)
      psCuda:resize(sz):copy(ps)
      nsCuda:resize(sz):copy(ns)

      apn = {asCuda, psCuda, nsCuda}
   else
      apn = {as, ps, ns}
   end

   local err = optimator:optimizeTriplet(
      optimMethod, inputs, apn, criterion,
      triplet_idx -- , num_example_per_idx
   )
   if opt.cuda then
      cutorch.synchronize()
   end

   batchNumber = batchNumber + 1
   print(('Epoch: [%d][%d/%d]\tTime %.3f\ttripErr %.2e'):format(
         epoch, batchNumber, opt.epochSize, timer:time().real, err))
   timer:reset()
   triplet_loss = triplet_loss + err
end
|