Train: Detect and correct edge case when peoplePerBatch > number of classes.
Thanks @venturamartin90
This commit is contained in:
parent
27e2172ea0
commit
9bef1e8a60
|
@ -28,6 +28,14 @@ paths.dofile('train.lua')
|
|||
paths.dofile('test.lua')
|
||||
paths.dofile('util.lua')
|
||||
|
||||
if opt.peoplePerBatch > #trainLoader.classes then
|
||||
print('\n\nWarning: opt.peoplePerBatch > number of classes.')
|
||||
print(' + opt.peoplePerBatch: ', opt.peoplePerBatch)
|
||||
print(' + number of classes: ', #trainLoader.classes)
|
||||
print('Setting opt.peoplePerBatch to the number of classes.\n\n')
|
||||
opt.peoplePerBatch = #trainLoader.classes
|
||||
end
|
||||
|
||||
epoch = opt.epochNumber
|
||||
|
||||
-- test()
|
||||
|
|
|
@ -167,7 +167,8 @@ function trainBatch(inputsThread, numPerClassThread)
|
|||
for k = 1,numImages do
|
||||
if k < embStartIdx or k > embStartIdx+n-1 then
|
||||
local negDist = dist(embeddings[aIdx], embeddings[k])
|
||||
if posDist < negDist and negDist < selNegDist and negDist < alpha then
|
||||
if posDist < negDist and negDist < selNegDist and
|
||||
math.abs(posDist - negDist) < alpha then
|
||||
selNegDist = negDist
|
||||
selNegIdx = k
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue