Added a narrative to this example.

Davis King 2016-04-10 17:30:45 -04:00
parent 6dbc78df03
commit 7d7c932f29
1 changed file with 81 additions and 14 deletions


@@ -1,10 +1,20 @@
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This is an example illustrating the use of the deep learning tools from the
dlib C++ Library. In it, we will train the venerable LeNet convolutional
neural network to recognize hand written digits. The network will take as
input a small image and classify it as one of the 10 numeric digits between
0 and 9.
The specific network we will run is from the paper
LeCun, Yann, et al. "Gradient-based learning applied to document recognition."
Proceedings of the IEEE 86.11 (1998): 2278-2324.
except that we replace the sigmoid non-linearities with rectified linear units.
These tools will use CUDA and cuDNN to drastically accelerate network
training and testing. CMake should automatically find them if they are
installed and configure things appropriately. If not, the program will
still run but will be much slower to execute.
*/
@@ -15,17 +25,22 @@
using namespace std;
using namespace dlib;

int main(int argc, char** argv) try
{
// This example is going to run on the MNIST dataset.
if (argc != 2)
{
cout << "This example needs the MNIST dataset to run!" << endl;
cout << "You can get MNIST from http://yann.lecun.com/exdb/mnist/" << endl;
cout << "Download the 4 files that comprise the dataset, decompress them, and" << endl;
cout << "put them in a folder. Then give that folder as input to this program." << endl;
return 1;
}
// MNIST is broken into two parts, a training set of 60000 images and a test set of
// 10000 images. Each image is labeled so we know what hand written digit is depicted.
// These next statements load the dataset into memory.
std::vector<matrix<unsigned char>> training_images;
std::vector<unsigned long> training_labels;
std::vector<matrix<unsigned char>> testing_images;
@@ -33,30 +48,80 @@ int main(int argc, char** argv) try
load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);
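// As a quick sanity check (an illustrative addition, not part of the original
// example), we can print how many images were loaded. For standard MNIST this
// should report 60000 training and 10000 testing images.
cout << "loaded " << training_images.size() << " training and " << testing_images.size() << " testing images" << endl;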
// Now let's define the LeNet. Broadly speaking, there are 3 parts to a network
// definition. The loss layer, a bunch of computational layers, and then an input
// layer. You can see these components in the network definition below.
//
// The input layer here says the network expects to be given matrix<unsigned char>
// objects as input. In general, you can use any dlib image or matrix type here, or
// even define your own types by creating custom input layers.
//
// Then the middle layers define the computation the network will do to transform the
// input into whatever we want. Here we run the image through multiple convolutions, ReLU
// units, max pooling operations, and then finally a fully connected layer that converts
// the whole thing into just 10 numbers.
//
// Finally, the loss layer defines the relationship between the network outputs, our 10
// numbers, and the labels in our dataset. Since we selected loss_multiclass_log it
// means we want to do multiclass classification with our network. Moreover, the
// number of network outputs (i.e. 10) is the number of possible labels and whichever
// network output is biggest is the predicted label. So for example, if the first
// network output is largest then the predicted digit is 0, if the last network output
// is largest then the predicted digit is 9.
using net_type = loss_multiclass_log<
fc<10,
relu<fc<84,
relu<fc<120,
max_pool<2,2,2,2,relu<con<16,5,5,1,1,
max_pool<2,2,2,2,relu<con<6,5,5,1,1,
input<matrix<unsigned char>>
>>>>>>>>>>>>;
// This net_type defines the entire network architecture. For example, the block
// relu<fc<84,SUBNET>> means we take the output from the subnetwork, pass it through a
// fully connected layer with 84 outputs, then apply ReLU. Similarly, a block of
// max_pool<2,2,2,2,relu<con<16,5,5,1,1,SUBNET>>> means we apply 16 convolutions with a
// 5x5 filter size and 1x1 stride to the output of a subnetwork, then apply ReLU, then
// perform max pooling with a 2x2 window and 2x2 stride.
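// As an aside (an illustrative sketch, not part of the original example; the names
// conv_block1, conv_block2, and net_type2 are hypothetical), the same stack could be
// written with template aliases declared at namespace scope, above main(), which
// makes the SUBNET nesting explicit:
//
//    template <typename SUBNET> using conv_block1 = max_pool<2,2,2,2,relu<con<6,5,5,1,1,SUBNET>>>;
//    template <typename SUBNET> using conv_block2 = max_pool<2,2,2,2,relu<con<16,5,5,1,1,SUBNET>>>;
//    using net_type2 = loss_multiclass_log<fc<10,relu<fc<84,relu<fc<120,
//                          conv_block2<conv_block1<input<matrix<unsigned char>>>>>>>>>>;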
// So with that out of the way, we can make a network instance.
net_type net;
// And then train it using the MNIST data. The code below uses mini-batch stochastic
// gradient descent with an initial learning rate of 0.01 to accomplish this.
dnn_trainer<net_type> trainer(net,sgd(0.01));
trainer.set_mini_batch_size(128);
trainer.be_verbose();
// Since DNN training can take a long time, we can ask the trainer to save its state to
// a file named "mnist_sync" every 20 seconds. This way, if we kill this program and
// start it again it will begin where it left off rather than restarting the training
// from scratch.
trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
// Finally, this line begins training. By default, it runs SGD with our specified step
// size until the loss stops decreasing. Then it reduces the step size by a factor of
// 10 and continues running until loss stops decreasing again. It will reduce the step
// size 3 times and then terminate. For a longer discussion see the documentation for
// the dnn_trainer object.
trainer.train(training_images, training_labels);
// At this point our net object should have learned how to classify MNIST images. But
// before we try it out let's save it to disk. Note that, since the trainer has been
// running images through the network, net will have a bunch of state in it related to
// the last image it processed (e.g. outputs from each layer). Since we don't care
// about saving that kind of stuff to disk we can tell the network to forget about that
// kind of transient data so that our file will be smaller. We do this by "cleaning"
// the network before saving it.
net.clean();
serialize("mnist_network.dat") << net;
// Now let's run the training images through the network. This statement runs all the
// images through it and asks the loss layer to convert the network's raw output into
// labels. In our case, these labels are the numbers between 0 and 9.
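// As an aside (an illustrative note, not shown in the original example), recent dlib
// also accepts a single image, in which case the loss layer returns one label directly:
//    unsigned long first_prediction = net(training_images[0]);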
std::vector<unsigned long> predicted_labels = net(training_images);
int num_right = 0;
int num_wrong = 0;
// And then let's see if it classified them correctly.
for (size_t i = 0; i < training_images.size(); ++i)
{
if (predicted_labels[i] == training_labels[i])
@@ -69,6 +134,8 @@ int main(int argc, char** argv) try
cout << "training num_wrong: " << num_wrong << endl; cout << "training num_wrong: " << num_wrong << endl;
cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl; cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
// Let's also see if the network can correctly classify the testing images. Since
// MNIST is an easy dataset, we should see at least 99% accuracy.
predicted_labels = net(testing_images);
num_right = 0;
num_wrong = 0;