Training: Suggestion 2 from @melgor in #48.
This commit is contained in:
parent
30f754176c
commit
aed184d860
|
@ -149,6 +149,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
|||
local tripIdx = 1
|
||||
local shuffle = torch.randperm(numTrips)
|
||||
local embStartIdx = 1
|
||||
local nRandomNegs = 0
|
||||
for i = 1,opt.peoplePerBatch do
|
||||
local n = numPerClass[i]
|
||||
for j = 1,n-1 do
|
||||
|
@ -166,16 +167,23 @@ function trainBatch(inputsThread, numPerClassThread)
|
|||
selNegIdx = (torch.random() % numImages) + 1
|
||||
end
|
||||
local selNegDist = dist(embeddings[aIdx], embeddings[selNegIdx])
|
||||
local randomNeg = true
|
||||
for k = 1,numImages do
|
||||
if k < embStartIdx or k > embStartIdx+n-1 then
|
||||
local negDist = dist(embeddings[aIdx], embeddings[k])
|
||||
if posDist < negDist and negDist < selNegDist and
|
||||
math.abs(posDist - negDist) < alpha then
|
||||
if posDist < negDist and negDist < selNegDist then
|
||||
randomNeg = false
|
||||
selNegDist = negDist
|
||||
selNegIdx = k
|
||||
if math.abs(posDist-negDist) < alpha then
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if randomNeg then
|
||||
nRandomNegs = nRandomNegs + 1
|
||||
end
|
||||
|
||||
ns[shuffle[tripIdx]] = inputsCPU[selNegIdx]
|
||||
|
||||
|
@ -185,6 +193,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
|||
end
|
||||
assert(embStartIdx - 1 == numImages)
|
||||
assert(tripIdx - 1 == numTrips)
|
||||
print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips))
|
||||
|
||||
|
||||
local beginIdx = 1
|
||||
|
|
Loading…
Reference in New Issue