added more stuff to example

Davis King 2016-03-27 10:29:30 -04:00
parent bd79b8778a
commit 47bdf95fbe
1 changed file with 22 additions and 5 deletions


@@ -60,8 +60,7 @@ int main(int argc, char** argv) try
);
//dnn_trainer<net_type,adam> trainer(net,adam(0.001));
dnn_trainer<net_type> trainer(net,sgd(0.1));
dnn_trainer<net_type,adam> trainer(net,adam(0.001));
trainer.be_verbose();
trainer.set_synchronization_file("mnist_resnet_sync", std::chrono::seconds(100));
std::vector<matrix<unsigned char>> mini_batch_samples;
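For context, this first hunk swaps the plain SGD trainer for one driven by the adam solver while keeping the periodic checkpointing. A minimal sketch of the resulting setup, assuming net_type is the residual network defined earlier in the example, a net object is already constructed, and dlib/dnn.h is included (the solver parameters are the ones shown in the diff):

// Sketch only: the commit switches the example from SGD to the adam solver.
//dnn_trainer<net_type> trainer(net, sgd(0.1));          // previous version
dnn_trainer<net_type,adam> trainer(net, adam(0.001));    // new version
trainer.be_verbose();
// Checkpoint training state every 100 seconds so an interrupted run can resume.
trainer.set_synchronization_file("mnist_resnet_sync", std::chrono::seconds(100));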
@@ -86,11 +85,29 @@ int main(int argc, char** argv) try
// wait for threaded processing to stop.
trainer.get_net();
// You can access sub layers of the network like this:
net.subnet().subnet().get_output();
layer<avg_pool>(net).get_output();
net.clean();
serialize("mnist_network.dat") << net;
serialize("mnist_res_network.dat") << net;
typedef loss_multiclass_log<fc<avg_pool<
ares<ares<ares<ares<
repeat<10,ares,
ares<
ares<
input<matrix<unsigned char>
>>>>>>>>>>> test_net_type;
test_net_type tnet = net;
// or you could deserialize the saved network
deserialize("mnist_res_network.dat") >> tnet;
// Run the net on all the data to get predictions
std::vector<unsigned long> predicted_labels = net(training_images);
std::vector<unsigned long> predicted_labels = tnet(training_images);
int num_right = 0;
int num_wrong = 0;
for (size_t i = 0; i < training_images.size(); ++i)
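Taken together, the added code converts the trained network into the test_net_type defined above, reloads it from disk, and scores it. A hedged sketch of that flow, assuming training_labels holds the ground-truth digits for training_images (the loop body is implied by the num_right/num_wrong counters printed below, not copied from this commit):

// Sketch: reload the saved network and measure training-set accuracy.
test_net_type tnet;
deserialize("mnist_res_network.dat") >> tnet;
std::vector<unsigned long> predicted_labels = tnet(training_images);
int num_right = 0;
int num_wrong = 0;
for (size_t i = 0; i < training_images.size(); ++i)
{
    if (predicted_labels[i] == training_labels[i])
        ++num_right;
    else
        ++num_wrong;
}
cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;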
@@ -105,7 +122,7 @@ int main(int argc, char** argv) try
cout << "training num_wrong: " << num_wrong << endl;
cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
predicted_labels = net(testing_images);
predicted_labels = tnet(testing_images);
num_right = 0;
num_wrong = 0;
for (size_t i = 0; i < testing_images.size(); ++i)
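The hunk headers show that main is written as a function-try-block. A sketch of how such an example typically closes (the catch body is illustrative and not part of this commit):

int main(int argc, char** argv) try
{
    // ... body shown in the hunks above ...
    return 0;
}
catch (std::exception& e)
{
    // Report the error and exit; dlib's examples conventionally end this way.
    cout << e.what() << endl;
}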