Actually fix nans from #127. Error if they appear.
This commit is contained in:
parent
2bad4e503e
commit
a1b3251f96
|
@ -87,19 +87,20 @@ function train()
|
|||
|
||||
collectgarbage()
|
||||
|
||||
-- Fix nans from https://github.com/cmusatyalab/openface/issues/127
|
||||
local function fixNans(x, tag)
|
||||
-- Check for nans from https://github.com/cmusatyalab/openface/issues/127
|
||||
local function checkNans(x, tag)
|
||||
local I = torch.ne(x,x)
|
||||
if torch.any(I) then
|
||||
print("Correcting NaNs in: ", tag)
|
||||
x[I] = 0.0
|
||||
print("train.lua: Error: NaNs found in: ", tag)
|
||||
os.exit(-1)
|
||||
-- x[I] = 0.0
|
||||
end
|
||||
end
|
||||
|
||||
for j, mod in ipairs(model:listModules()) do
|
||||
if torch.typename(mod) == 'nn.SpatialBatchNormalization' then
|
||||
fixNans(mod.running_mean, string.format("%d-%s-%s", j, mod, 'running_mean'))
|
||||
fixNans(mod.running_var, string.format("%d-%s-%s", j, mod, 'running_var'))
|
||||
checkNans(mod.running_mean, string.format("%d-%s-%s", j, mod, 'running_mean'))
|
||||
checkNans(mod.running_var, string.format("%d-%s-%s", j, mod, 'running_var'))
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ function optimizeNet( model, inputSize )
|
|||
local optnet_loaded, optnet = pcall(require,'optnet')
|
||||
if optnet_loaded then
|
||||
local opts = {inplace=true, mode='training', removeGradParams=false}
|
||||
local input = torch.Tensor(1,3,inputSize,inputSize)
|
||||
local input = torch.rand(2,3,inputSize,inputSize)
|
||||
if opt.cuda then
|
||||
input = input:cuda()
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue