Training: Skip battch if nTripsFound == 0.

This commit is contained in:
Brandon Amos 2016-01-13 14:32:43 -05:00
parent 6b6ed234f9
commit 8bbfd4da05
1 changed files with 8 additions and 4 deletions

View File

@ -149,9 +149,7 @@ function train()
collectgarbage()
sanitize(model)
local nnModel = cudnn_to_nn(model):float()
local nnModel = sanitize(cudnn_to_nn(model)):float()
torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), nnModel)
torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
collectgarbage()
@ -235,7 +233,13 @@ function trainBatch(inputsThread, numPerClassThread)
embStartIdx = embStartIdx + n
end
assert(embStartIdx - 1 == numImages)
print((' + (nTrips, nTripsRight) = (%d, %d)'):format(numTrips,table.getn(as_table)))
local nTripsFound = table.getn(as_table)
print((' + (nTrips, nTripsFound) = (%d, %d)'):format(numTrips, nTripsFound))
if nTripsFound == 0 then
print("Warning: nTripsFound == 0. Skipping batch.")
return
end
local as = torch.concat(as_table):view(table.getn(as_table), opt.embSize)
local ps = torch.concat(ps_table):view(table.getn(ps_table), opt.embSize)