Training: Fix error with tensors being unexpectedly doubles.

This commit is contained in:
Brandon Amos 2015-11-06 17:09:47 -05:00
parent e2b72e3b7c
commit ae348c22f9
2 changed files with 3 additions and 0 deletions

View File

@ -39,6 +39,7 @@ function test()
donkeys:addjob(
function()
local inputs, labels = testLoader:sampleTriplet(opt.batchSize)
inputs = inputs:float()
return sendTensor(inputs)
end,
testBatch

View File

@ -50,6 +50,8 @@ function train()
function()
local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch,
opt.imagesPerPerson)
inputs = inputs:float()
numPerClass = numPerClass:float()
return sendTensor(inputs), sendTensor(numPerClass)
end,
-- the end callback (runs in the main thread)