diff --git a/training/main.lua b/training/main.lua index 5df5c9d..104bd5a 100755 --- a/training/main.lua +++ b/training/main.lua @@ -14,7 +14,7 @@ print(opt) if opt.cuda then require 'cutorch' - cutorch.setDevice(1) + cutorch.setDevice(opt.device) end torch.save(paths.concat(opt.save, 'opts.t7'), opt, 'ascii') diff --git a/training/opts.lua b/training/opts.lua index 9089f85..20f67a0 100644 --- a/training/opts.lua +++ b/training/opts.lua @@ -27,6 +27,7 @@ function M.parse(arg) 'Home of dataset. Split into "train" and "val" directories that separate images by class.') cmd:option('-manualSeed', 2, 'Manually set RNG seed') cmd:option('-cuda', true, 'Use cuda.') + cmd:option('-device', 1, 'Cuda device to use.') cmd:option('-cudnn', true, 'Convert the model to cudnn.') ------------- Data options ------------------------