mirror of https://github.com/davisking/dlib.git
added more stuff to example
This commit is contained in:
parent
bd79b8778a
commit
47bdf95fbe
|
@ -60,8 +60,7 @@ int main(int argc, char** argv) try
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
//dnn_trainer<net_type,adam> trainer(net,adam(0.001));
|
dnn_trainer<net_type,adam> trainer(net,adam(0.001));
|
||||||
dnn_trainer<net_type> trainer(net,sgd(0.1));
|
|
||||||
trainer.be_verbose();
|
trainer.be_verbose();
|
||||||
trainer.set_synchronization_file("mnist_resnet_sync", std::chrono::seconds(100));
|
trainer.set_synchronization_file("mnist_resnet_sync", std::chrono::seconds(100));
|
||||||
std::vector<matrix<unsigned char>> mini_batch_samples;
|
std::vector<matrix<unsigned char>> mini_batch_samples;
|
||||||
|
@ -86,11 +85,29 @@ int main(int argc, char** argv) try
|
||||||
// wait for threaded processing to stop.
|
// wait for threaded processing to stop.
|
||||||
trainer.get_net();
|
trainer.get_net();
|
||||||
|
|
||||||
|
// You can access sub layers of the network like this:
|
||||||
|
net.subnet().subnet().get_output();
|
||||||
|
layer<avg_pool>(net).get_output();
|
||||||
|
|
||||||
net.clean();
|
net.clean();
|
||||||
serialize("mnist_network.dat") << net;
|
serialize("mnist_res_network.dat") << net;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
typedef loss_multiclass_log<fc<avg_pool<
|
||||||
|
ares<ares<ares<ares<
|
||||||
|
repeat<10,ares,
|
||||||
|
ares<
|
||||||
|
ares<
|
||||||
|
input<matrix<unsigned char>
|
||||||
|
>>>>>>>>>>> test_net_type;
|
||||||
|
test_net_type tnet = net;
|
||||||
|
// or you could deserialize the saved network
|
||||||
|
deserialize("mnist_res_network.dat") >> tnet;
|
||||||
|
|
||||||
|
|
||||||
// Run the net on all the data to get predictions
|
// Run the net on all the data to get predictions
|
||||||
std::vector<unsigned long> predicted_labels = net(training_images);
|
std::vector<unsigned long> predicted_labels = tnet(training_images);
|
||||||
int num_right = 0;
|
int num_right = 0;
|
||||||
int num_wrong = 0;
|
int num_wrong = 0;
|
||||||
for (size_t i = 0; i < training_images.size(); ++i)
|
for (size_t i = 0; i < training_images.size(); ++i)
|
||||||
|
@ -105,7 +122,7 @@ int main(int argc, char** argv) try
|
||||||
cout << "training num_wrong: " << num_wrong << endl;
|
cout << "training num_wrong: " << num_wrong << endl;
|
||||||
cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
|
cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
|
||||||
|
|
||||||
predicted_labels = net(testing_images);
|
predicted_labels = tnet(testing_images);
|
||||||
num_right = 0;
|
num_right = 0;
|
||||||
num_wrong = 0;
|
num_wrong = 0;
|
||||||
for (size_t i = 0; i < testing_images.size(); ++i)
|
for (size_t i = 0; i < testing_images.size(); ++i)
|
||||||
|
|
Loading…
Reference in New Issue