diff --git a/batch-represent/main.lua b/batch-represent/main.lua index 28c8faf..a2fdb13 100755 --- a/batch-represent/main.lua +++ b/batch-represent/main.lua @@ -21,7 +21,7 @@ torch.setdefaulttensortype('torch.FloatTensor') if opt.cuda then require 'cutorch' require 'cunn' - cutorch.setDevice(1) + cutorch.setDevice(opt.device) end opt.manualSeed = 2 diff --git a/batch-represent/opts.lua b/batch-represent/opts.lua index 2813d70..7841c28 100644 --- a/batch-represent/opts.lua +++ b/batch-represent/opts.lua @@ -24,6 +24,7 @@ function M.parse(arg) cmd:option('-imgDim', 96, 'Image dimension. nn1=224, nn4=96') cmd:option('-batchSize', 50, 'mini-batch size') cmd:option('-cuda', false, 'Use cuda') + cmd:option('-device', 1, 'Cuda device to use') cmd:option('-cache', false, 'Cache loaded data.') cmd:text()