Training: Suggestion 2 from @melgor in #48.

This commit is contained in:
Brandon Amos 2015-11-08 17:08:06 -05:00
parent 30f754176c
commit aed184d860
1 changed files with 11 additions and 2 deletions

View File

@ -149,6 +149,7 @@ function trainBatch(inputsThread, numPerClassThread)
local tripIdx = 1
local shuffle = torch.randperm(numTrips)
local embStartIdx = 1
local nRandomNegs = 0
for i = 1,opt.peoplePerBatch do
local n = numPerClass[i]
for j = 1,n-1 do
@ -166,16 +167,23 @@ function trainBatch(inputsThread, numPerClassThread)
selNegIdx = (torch.random() % numImages) + 1
end
local selNegDist = dist(embeddings[aIdx], embeddings[selNegIdx])
local randomNeg = true
for k = 1,numImages do
if k < embStartIdx or k > embStartIdx+n-1 then
local negDist = dist(embeddings[aIdx], embeddings[k])
if posDist < negDist and negDist < selNegDist and
math.abs(posDist - negDist) < alpha then
if posDist < negDist and negDist < selNegDist then
randomNeg = false
selNegDist = negDist
selNegIdx = k
if math.abs(posDist-negDist) < alpha then
break
end
end
end
end
if randomNeg then
nRandomNegs = nRandomNegs + 1
end
ns[shuffle[tripIdx]] = inputsCPU[selNegIdx]
@ -185,6 +193,7 @@ function trainBatch(inputsThread, numPerClassThread)
end
assert(embStartIdx - 1 == numImages)
assert(tripIdx - 1 == numTrips)
print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips))
local beginIdx = 1