Actually address NaNs from #127: exit with an error if they appear.

This commit is contained in:
Brandon Amos 2016-06-14 13:51:18 -04:00
parent 2bad4e503e
commit a1b3251f96
2 changed files with 8 additions and 7 deletions

View File

@@ -87,19 +87,20 @@ function train()
 collectgarbage()
--- Fix nans from https://github.com/cmusatyalab/openface/issues/127
-local function fixNans(x, tag)
+-- Check for nans from https://github.com/cmusatyalab/openface/issues/127
+local function checkNans(x, tag)
 local I = torch.ne(x,x)
 if torch.any(I) then
-print("Correcting NaNs in: ", tag)
-x[I] = 0.0
+print("train.lua: Error: NaNs found in: ", tag)
+os.exit(-1)
+-- x[I] = 0.0
 end
 end
 for j, mod in ipairs(model:listModules()) do
 if torch.typename(mod) == 'nn.SpatialBatchNormalization' then
-fixNans(mod.running_mean, string.format("%d-%s-%s", j, mod, 'running_mean'))
-fixNans(mod.running_var, string.format("%d-%s-%s", j, mod, 'running_var'))
+checkNans(mod.running_mean, string.format("%d-%s-%s", j, mod, 'running_mean'))
+checkNans(mod.running_var, string.format("%d-%s-%s", j, mod, 'running_var'))
 end
 end

View File

@@ -60,7 +60,7 @@ function optimizeNet( model, inputSize )
 local optnet_loaded, optnet = pcall(require,'optnet')
 if optnet_loaded then
 local opts = {inplace=true, mode='training', removeGradParams=false}
-local input = torch.Tensor(1,3,inputSize,inputSize)
+local input = torch.rand(2,3,inputSize,inputSize)
 if opt.cuda then
 input = input:cuda()
 end