diff --git a/training/main.lua b/training/main.lua index dcc324a..00e99d0 100755 --- a/training/main.lua +++ b/training/main.lua @@ -28,6 +28,14 @@ paths.dofile('train.lua') paths.dofile('test.lua') paths.dofile('util.lua') +if opt.peoplePerBatch > #trainLoader.classes then + print('\n\nWarning: opt.peoplePerBatch > number of classes.') + print(' + opt.peoplePerBatch: ', opt.peoplePerBatch) + print(' + number of classes: ', #trainLoader.classes) + print('Setting opt.peoplePerBatch to the number of classes.\n\n') + opt.peoplePerBatch = #trainLoader.classes +end + epoch = opt.epochNumber -- test() diff --git a/training/train.lua b/training/train.lua index 2c2dee4..3638fe5 100644 --- a/training/train.lua +++ b/training/train.lua @@ -167,7 +167,8 @@ function trainBatch(inputsThread, numPerClassThread) for k = 1,numImages do if k < embStartIdx or k > embStartIdx+n-1 then local negDist = dist(embeddings[aIdx], embeddings[k]) - if posDist < negDist and negDist < selNegDist and negDist < alpha then + if posDist < negDist and negDist < selNegDist and + math.abs(posDist - negDist) < alpha then selNegDist = negDist selNegIdx = k end