diff --git a/training/train.lua b/training/train.lua index b8939c1..4ea7f02 100644 --- a/training/train.lua +++ b/training/train.lua @@ -149,6 +149,7 @@ function trainBatch(inputsThread, numPerClassThread) local tripIdx = 1 local shuffle = torch.randperm(numTrips) local embStartIdx = 1 + local nRandomNegs = 0 for i = 1,opt.peoplePerBatch do local n = numPerClass[i] for j = 1,n-1 do @@ -166,16 +167,23 @@ function trainBatch(inputsThread, numPerClassThread) selNegIdx = (torch.random() % numImages) + 1 end local selNegDist = dist(embeddings[aIdx], embeddings[selNegIdx]) + local randomNeg = true for k = 1,numImages do if k < embStartIdx or k > embStartIdx+n-1 then local negDist = dist(embeddings[aIdx], embeddings[k]) - if posDist < negDist and negDist < selNegDist and - math.abs(posDist - negDist) < alpha then + if posDist < negDist and negDist < selNegDist then + randomNeg = false selNegDist = negDist selNegIdx = k + if math.abs(posDist-negDist) < alpha then + break + end end end end + if randomNeg then + nRandomNegs = nRandomNegs + 1 + end ns[shuffle[tripIdx]] = inputsCPU[selNegIdx] @@ -185,6 +193,7 @@ function trainBatch(inputsThread, numPerClassThread) end assert(embStartIdx - 1 == numImages) assert(tripIdx - 1 == numTrips) + print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips)) local beginIdx = 1