Training: Suggestion 2 from @melgor in #48.
This commit is contained in:
parent
30f754176c
commit
aed184d860
|
@ -149,6 +149,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
||||||
local tripIdx = 1
|
local tripIdx = 1
|
||||||
local shuffle = torch.randperm(numTrips)
|
local shuffle = torch.randperm(numTrips)
|
||||||
local embStartIdx = 1
|
local embStartIdx = 1
|
||||||
|
local nRandomNegs = 0
|
||||||
for i = 1,opt.peoplePerBatch do
|
for i = 1,opt.peoplePerBatch do
|
||||||
local n = numPerClass[i]
|
local n = numPerClass[i]
|
||||||
for j = 1,n-1 do
|
for j = 1,n-1 do
|
||||||
|
@ -166,16 +167,23 @@ function trainBatch(inputsThread, numPerClassThread)
|
||||||
selNegIdx = (torch.random() % numImages) + 1
|
selNegIdx = (torch.random() % numImages) + 1
|
||||||
end
|
end
|
||||||
local selNegDist = dist(embeddings[aIdx], embeddings[selNegIdx])
|
local selNegDist = dist(embeddings[aIdx], embeddings[selNegIdx])
|
||||||
|
local randomNeg = true
|
||||||
for k = 1,numImages do
|
for k = 1,numImages do
|
||||||
if k < embStartIdx or k > embStartIdx+n-1 then
|
if k < embStartIdx or k > embStartIdx+n-1 then
|
||||||
local negDist = dist(embeddings[aIdx], embeddings[k])
|
local negDist = dist(embeddings[aIdx], embeddings[k])
|
||||||
if posDist < negDist and negDist < selNegDist and
|
if posDist < negDist and negDist < selNegDist then
|
||||||
math.abs(posDist - negDist) < alpha then
|
randomNeg = false
|
||||||
selNegDist = negDist
|
selNegDist = negDist
|
||||||
selNegIdx = k
|
selNegIdx = k
|
||||||
|
if math.abs(posDist-negDist) < alpha then
|
||||||
|
break
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
if randomNeg then
|
||||||
|
nRandomNegs = nRandomNegs + 1
|
||||||
|
end
|
||||||
|
|
||||||
ns[shuffle[tripIdx]] = inputsCPU[selNegIdx]
|
ns[shuffle[tripIdx]] = inputsCPU[selNegIdx]
|
||||||
|
|
||||||
|
@ -185,6 +193,7 @@ function trainBatch(inputsThread, numPerClassThread)
|
||||||
end
|
end
|
||||||
assert(embStartIdx - 1 == numImages)
|
assert(embStartIdx - 1 == numImages)
|
||||||
assert(tripIdx - 1 == numTrips)
|
assert(tripIdx - 1 == numTrips)
|
||||||
|
print((' + (nRandomNegs, nTrips) = (%d, %d)'):format(nRandomNegs, numTrips))
|
||||||
|
|
||||||
|
|
||||||
local beginIdx = 1
|
local beginIdx = 1
|
||||||
|
|
Loading…
Reference in New Issue