2015-09-24 07:49:45 +08:00
|
|
|
#!/usr/bin/env th
|
|
|
|
|
|
|
|
require 'torch'
|
|
|
|
require 'optim'
|
|
|
|
|
|
|
|
require 'paths'
|
|
|
|
|
|
|
|
require 'xlua'
|
|
|
|
|
|
|
|
-- Build the run configuration from the command-line arguments using the
-- parser module returned by opts.lua, then echo it for the log.
-- NOTE: `opt` is assigned without `local`, i.e. as a global; presumably the
-- scripts loaded later via paths.dofile read it — confirm before localizing.
local opts = paths.dofile('opts.lua')
opt = opts.parse(arg)
print(opt)
|
|
|
|
|
2016-01-12 04:36:58 +08:00
|
|
|
-- GPU setup, only when opt.cuda is set: load cutorch and select the device
-- given by opt.device.
if opt.cuda then
   require 'cutorch'
   cutorch.setDevice(opt.device)
   -- torch.manualSeed (called later in this script) only seeds torch's CPU
   -- generator. Seed the cutorch generators from the same opt.manualSeed so
   -- GPU-side randomness is reproducible across runs as well.
   cutorch.manualSeedAll(opt.manualSeed)
end
|
|
|
|
|
2015-09-24 07:49:45 +08:00
|
|
|
-- Serialize the parsed options into <opt.save>/opts.t7 using torch's
-- 'ascii' format, then report the output directory.
local optsFile = paths.concat(opt.save, 'opts.t7')
torch.save(optsFile, opt, 'ascii')
print('Saving everything to: ' .. opt.save)

-- Make new tensors default to torch.FloatTensor and seed torch's default
-- random generator from opt.manualSeed.
torch.setdefaulttensortype('torch.FloatTensor')
torch.manualSeed(opt.manualSeed)
|
|
|
|
|
|
|
|
-- Execute the supporting scripts in order. data.lua presumably defines the
-- global nClasses used by the batch-size check that follows (TODO confirm).
paths.dofile('data.lua')
paths.dofile('util.lua')

-- Clear the `model` and `criterion` globals before train.lua and test.lua
-- are executed; the epoch loop later stores saveModel's result in `model`.
model, criterion = nil, nil

paths.dofile('train.lua')
paths.dofile('test.lua')
|
2015-09-24 07:49:45 +08:00
|
|
|
|
2015-11-08 03:35:49 +08:00
|
|
|
-- Abort when opt.peoplePerBatch exceeds nClasses (a global presumably set
-- by data.lua). Diagnostics go to stderr so they are not lost when stdout
-- is redirected to a training log. The exit status (-1) is unchanged.
if opt.peoplePerBatch > nClasses then
   io.stderr:write('\n\nError: opt.peoplePerBatch > number of classes. Please decrease this value.\n')
   io.stderr:write(string.format(' + opt.peoplePerBatch: \t%s\n', tostring(opt.peoplePerBatch)))
   io.stderr:write(string.format(' + number of classes: \t%s\n', tostring(nClasses)))
   os.exit(-1)
end
|
|
|
|
|
2015-09-24 07:49:45 +08:00
|
|
|
-- One training epoch: train, optionally evaluate (opt.testing), store the
-- result of saveModel back into the global `model`, then advance `epoch`.
-- train, test and saveModel are globals looked up at call time.
local function runEpoch()
   train()
   if opt.testing then
      test()
   end
   model = saveModel(model)
   epoch = epoch + 1
end

-- `epoch` is a global counter starting at opt.epochNumber; run exactly
-- opt.nEpochs epochs (the bound is evaluated once, before the first one).
epoch = opt.epochNumber
for _ = 1, opt.nEpochs do
   runEpoch()
end
|