Training: Fix error with tensors being unexpectedly doubles.
This commit is contained in:
parent
e2b72e3b7c
commit
ae348c22f9
|
@ -39,6 +39,7 @@ function test()
|
|||
donkeys:addjob(
|
||||
function()
|
||||
local inputs, labels = testLoader:sampleTriplet(opt.batchSize)
|
||||
inputs = inputs:float()
|
||||
return sendTensor(inputs)
|
||||
end,
|
||||
testBatch
|
||||
|
|
|
@ -50,6 +50,8 @@ function train()
|
|||
function()
|
||||
local inputs, numPerClass = trainLoader:samplePeople(opt.peoplePerBatch,
|
||||
opt.imagesPerPerson)
|
||||
inputs = inputs:float()
|
||||
numPerClass = numPerClass:float()
|
||||
return sendTensor(inputs), sendTensor(numPerClass)
|
||||
end,
|
||||
-- the end callback (runs in the main thread)
|
||||
|
|
Loading…
Reference in New Issue