openface/training/main.lua

50 lines
949 B
Lua
Raw Normal View History

#!/usr/bin/env th
require 'torch'
require 'optim'
require 'paths'
require 'xlua'
-- Load opts.lua through paths.dofile and parse the command-line arguments.
local opts = paths.dofile('opts.lua')
-- `opt` is assigned as a global on purpose here.
-- NOTE(review): presumably the scripts loaded later via paths.dofile read
-- this global -- confirm before turning it into a local.
opt = opts.parse(arg)
-- Echo the parsed options.
print(opt)
2016-01-12 04:36:58 +08:00
-- When opt.cuda is set, load cutorch and select device 1.
-- NOTE(review): the device index is hard-coded; confirm whether it should
-- come from an option instead.
if opt.cuda then
require 'cutorch'
cutorch.setDevice(1)
end
-- Create the output directory, snapshot the parsed options into it as
-- 'opts.t7' (ASCII format), and report where results are being written.
--
-- The directory path is wrapped in single quotes (with embedded single quotes
-- escaped) before being handed to the shell, so that spaces or shell
-- metacharacters in opt.save cannot split or alter the mkdir command. This
-- also makes mkdir operate on the same literal path that torch.save uses.
local quotedSaveDir = "'" .. string.gsub(opt.save, "'", "'\\''") .. "'"
os.execute('mkdir -p ' .. quotedSaveDir)
torch.save(paths.concat(opt.save, 'opts.t7'), opt, 'ascii')
print('Saving everything to: ' .. opt.save)
-- Make FloatTensor the default tensor type and seed Torch's random number
-- generator from opt.manualSeed.
torch.setdefaulttensortype('torch.FloatTensor')
torch.manualSeed(opt.manualSeed)
-- Load the remaining scripts, in this order.
-- NOTE(review): presumably these define the globals used further down
-- (nClasses, train, test) -- confirm against the individual files before
-- reordering.
paths.dofile('data.lua')
paths.dofile('model.lua')
paths.dofile('train.lua')
paths.dofile('test.lua')
paths.dofile('util.lua')
-- Sanity check: stop with exit status -1 when the requested number of people
-- per batch exceeds the number of classes, after printing both values.
-- nClasses is not defined in this file.
-- NOTE(review): presumably it is set by one of the scripts loaded above -- confirm.
local peoplePerBatch = opt.peoplePerBatch
if nClasses < peoplePerBatch then
  print('\n\nError: opt.peoplePerBatch > number of classes. Please decrease this value.')
  print(' + opt.peoplePerBatch: ', peoplePerBatch)
  print(' + number of classes: ', nClasses)
  os.exit(-1)
end
-- Main loop: run opt.nEpochs iterations, each calling train() and, when
-- opt.test is set, test() afterwards. The epoch counter starts at
-- opt.epochNumber and is incremented after every iteration.
--
-- `epoch` is assigned as a global on purpose.
-- NOTE(review): presumably train()/test() read it -- confirm before making
-- it local.
epoch = opt.epochNumber
for _ = 1, opt.nEpochs do
  train()
  if opt.test then
    test()
  end
  epoch = epoch + 1
end