Training: Skip battch if nTripsFound == 0.
This commit is contained in:
parent
6b6ed234f9
commit
8bbfd4da05
|
@ -149,9 +149,7 @@ function train()
|
|||
|
||||
collectgarbage()
|
||||
|
||||
|
||||
sanitize(model)
|
||||
local nnModel = cudnn_to_nn(model):float()
|
||||
local nnModel = sanitize(cudnn_to_nn(model)):float()
|
||||
torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), nnModel)
|
||||
torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)
|
||||
collectgarbage()
|
||||
|
@ -235,7 +233,13 @@ function trainBatch(inputsThread, numPerClassThread)
|
|||
embStartIdx = embStartIdx + n
|
||||
end
|
||||
assert(embStartIdx - 1 == numImages)
|
||||
print((' + (nTrips, nTripsRight) = (%d, %d)'):format(numTrips,table.getn(as_table)))
|
||||
local nTripsFound = table.getn(as_table)
|
||||
print((' + (nTrips, nTripsFound) = (%d, %d)'):format(numTrips, nTripsFound))
|
||||
|
||||
if nTripsFound == 0 then
|
||||
print("Warning: nTripsFound == 0. Skipping batch.")
|
||||
return
|
||||
end
|
||||
|
||||
local as = torch.concat(as_table):view(table.getn(as_table), opt.embSize)
|
||||
local ps = torch.concat(ps_table):view(table.getn(ps_table), opt.embSize)
|
||||
|
|
Loading…
Reference in New Issue