Train: Detect and correct edge case when peoplePerBatch > number of classes.

Thanks @venturamartin90
This commit is contained in:
Brandon Amos 2015-11-06 12:10:25 -05:00
parent 27e2172ea0
commit 9bef1e8a60
2 changed files with 10 additions and 1 deletions

View File

@ -28,6 +28,14 @@ paths.dofile('train.lua')
paths.dofile('test.lua')
paths.dofile('util.lua')
if opt.peoplePerBatch > #trainLoader.classes then
print('\n\nWarning: opt.peoplePerBatch > number of classes.')
print(' + opt.peoplePerBatch: ', opt.peoplePerBatch)
print(' + number of classes: ', #trainLoader.classes)
print('Setting opt.peoplePerBatch to the number of classes.\n\n')
opt.peoplePerBatch = #trainLoader.classes
end
epoch = opt.epochNumber
-- test()

View File

@ -167,7 +167,8 @@ function trainBatch(inputsThread, numPerClassThread)
for k = 1,numImages do
if k < embStartIdx or k > embStartIdx+n-1 then
local negDist = dist(embeddings[aIdx], embeddings[k])
if posDist < negDist and negDist < selNegDist and negDist < alpha then
if posDist < negDist and negDist < selNegDist and
math.abs(posDist - negDist) < alpha then
selNegDist = negDist
selNegIdx = k
end