mirror of https://github.com/davisking/dlib.git
Improved example
This commit is contained in:
parent
f28d2f7329
commit
f4b3c7ee0f
|
@ -14,7 +14,6 @@
|
||||||
space it's very easy to do face recognition with some kind of k-nearest
|
space it's very easy to do face recognition with some kind of k-nearest
|
||||||
neighbor classifier.
|
neighbor classifier.
|
||||||
|
|
||||||
|
|
||||||
To keep this example as simple as possible we won't do face recognition.
|
To keep this example as simple as possible we won't do face recognition.
|
||||||
Instead, we will create a very simple network and use it to learn a mapping
|
Instead, we will create a very simple network and use it to learn a mapping
|
||||||
from 8D vectors to 2D vectors such that vectors with the same class labels
|
from 8D vectors to 2D vectors such that vectors with the same class labels
|
||||||
|
@ -65,15 +64,20 @@ int main() try
|
||||||
// vectors.
|
// vectors.
|
||||||
using net_type = loss_metric<fc<2,input<matrix<double,0,1>>>>;
|
using net_type = loss_metric<fc<2,input<matrix<double,0,1>>>>;
|
||||||
net_type net;
|
net_type net;
|
||||||
// Now setup the trainer and train the network using our data.
|
|
||||||
dnn_trainer<net_type> trainer(net);
|
dnn_trainer<net_type> trainer(net);
|
||||||
trainer.set_learning_rate(0.1);
|
trainer.set_learning_rate(0.1);
|
||||||
trainer.set_min_learning_rate(0.001);
|
|
||||||
trainer.set_mini_batch_size(128);
|
|
||||||
trainer.be_verbose();
|
|
||||||
trainer.set_iterations_without_progress_threshold(100);
|
|
||||||
trainer.train(samples, labels);
|
|
||||||
|
|
||||||
|
// It should be emphasized out that it's really important that each mini-batch contain
|
||||||
|
// multiple instances of each class of object. This is because the metric learning
|
||||||
|
// algorithm needs to consider pairs of objects that should be close as well as pairs
|
||||||
|
// of objects that should be far apart during each training step. Here we just keep
|
||||||
|
// training on the same small batch so this constraint is trivially satisfied.
|
||||||
|
while(trainer.get_learning_rate() >= 1e-4)
|
||||||
|
trainer.train_one_step(samples, labels);
|
||||||
|
|
||||||
|
// Wait for training threads to stop
|
||||||
|
trainer.get_net();
|
||||||
|
cout << "done training" << endl;
|
||||||
|
|
||||||
|
|
||||||
// Run all the samples through the network to get their 2D vector embeddings.
|
// Run all the samples through the network to get their 2D vector embeddings.
|
||||||
|
|
Loading…
Reference in New Issue