train.lua: Pass on comments.

This commit is contained in:
Brandon Amos 2016-01-16 23:50:54 -05:00
parent ca15ee5176
commit 9b04bca65d
1 changed files with 16 additions and 14 deletions

View File

@ -194,33 +194,35 @@ function trainBatch(inputsThread, numPerClassThread)
local numTrips = 0
for i = 1,opt.peoplePerBatch do
local n = numPerClass[i]
for j = 1,n-1 do --for every image in batch
for j = 1,n-1 do -- For every image in the batch.
local aIdx = embStartIdx + j - 1
local diff = embeddings - embeddings[{ {aIdx} }]:expandAs(embeddings)
local norms = diff:norm(2, 2):pow(2):squeeze() --L2 norm have be squared
for pair = j,n-1 do --create all posible positive pairs
local norms = diff:norm(2, 2):pow(2):squeeze()
for pair = j, n-1 do -- For every possible positive pair.
local pIdx = embStartIdx + pair
-- Select a semi-hard negative that has a distance
-- further away from the positive exemplar. Oxford-Face Idea
--choose random example which is in margin
local fff = (embeddings[aIdx]-embeddings[pIdx]):norm(2)
local normsP = norms - torch.Tensor(embeddings:size(1)):fill(fff*fff) --L2 norm have be squared
--clean the idx of same class by setting to them max value
local normsP = norms - torch.Tensor(embeddings:size(1)):fill(fff*fff)
-- Set the indices of the same class to the max so they are ignored.
normsP[{{embStartIdx,embStartIdx +n-1}}] = normsP:max()
-- get indexes of example which are inside margin
-- Get indices of images within the margin.
local in_margin = normsP:lt(opt.alpha)
local allNeg = torch.find(in_margin, 1)
if table.getn(allNeg) ~= 0 then --use only non-random triplets. Random triples (which are beyond margin) will just produce gradient = 0, so average gradient will decrease
-- Use only non-random triplets.
-- Random triples (which are beyond the margin) will just produce gradient = 0,
-- so the average gradient will decrease.
if table.getn(allNeg) ~= 0 then
selNegIdx = allNeg[math.random (table.getn(allNeg))]
--get embeding of each example
-- Add the embeding of each example.
table.insert(as_table,embeddings[aIdx])
table.insert(ps_table,embeddings[pIdx])
table.insert(ns_table,embeddings[selNegIdx])
-- get original idx of triplets
table.insert(triplet_idx,{aIdx,pIdx,selNegIdx})
-- increase number of times of using each example, need for averaging then
-- Add the original index of triplets.
table.insert(triplet_idx, {aIdx,pIdx,selNegIdx})
-- Increase the number of times of using each example.
num_example_per_idx[aIdx] = num_example_per_idx[aIdx] + 1
num_example_per_idx[pIdx] = num_example_per_idx[pIdx] + 1
num_example_per_idx[selNegIdx] = num_example_per_idx[selNegIdx] + 1