Train: Detect and correct edge case when peoplePerBatch > number of classes.
Thanks @venturamartin90
This commit is contained in:
parent
27e2172ea0
commit
9bef1e8a60
|
@ -28,6 +28,14 @@ paths.dofile('train.lua')
|
||||||
paths.dofile('test.lua')
|
paths.dofile('test.lua')
|
||||||
paths.dofile('util.lua')
|
paths.dofile('util.lua')
|
||||||
|
|
||||||
|
if opt.peoplePerBatch > #trainLoader.classes then
|
||||||
|
print('\n\nWarning: opt.peoplePerBatch > number of classes.')
|
||||||
|
print(' + opt.peoplePerBatch: ', opt.peoplePerBatch)
|
||||||
|
print(' + number of classes: ', #trainLoader.classes)
|
||||||
|
print('Setting opt.peoplePerBatch to the number of classes.\n\n')
|
||||||
|
opt.peoplePerBatch = #trainLoader.classes
|
||||||
|
end
|
||||||
|
|
||||||
epoch = opt.epochNumber
|
epoch = opt.epochNumber
|
||||||
|
|
||||||
-- test()
|
-- test()
|
||||||
|
|
|
@ -167,7 +167,8 @@ function trainBatch(inputsThread, numPerClassThread)
|
||||||
for k = 1,numImages do
|
for k = 1,numImages do
|
||||||
if k < embStartIdx or k > embStartIdx+n-1 then
|
if k < embStartIdx or k > embStartIdx+n-1 then
|
||||||
local negDist = dist(embeddings[aIdx], embeddings[k])
|
local negDist = dist(embeddings[aIdx], embeddings[k])
|
||||||
if posDist < negDist and negDist < selNegDist and negDist < alpha then
|
if posDist < negDist and negDist < selNegDist and
|
||||||
|
math.abs(posDist - negDist) < alpha then
|
||||||
selNegDist = negDist
|
selNegDist = negDist
|
||||||
selNegIdx = k
|
selNegIdx = k
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue