mirror of https://github.com/davisking/dlib.git
added more stuff to example
This commit is contained in:
parent
bd79b8778a
commit
47bdf95fbe
|
@ -60,8 +60,7 @@ int main(int argc, char** argv) try
|
|||
);
|
||||
|
||||
|
||||
//dnn_trainer<net_type,adam> trainer(net,adam(0.001));
|
||||
dnn_trainer<net_type> trainer(net,sgd(0.1));
|
||||
dnn_trainer<net_type,adam> trainer(net,adam(0.001));
|
||||
trainer.be_verbose();
|
||||
trainer.set_synchronization_file("mnist_resnet_sync", std::chrono::seconds(100));
|
||||
std::vector<matrix<unsigned char>> mini_batch_samples;
|
||||
|
@ -86,11 +85,29 @@ int main(int argc, char** argv) try
|
|||
// wait for threaded processing to stop.
|
||||
trainer.get_net();
|
||||
|
||||
// You can access sub layers of the network like this:
|
||||
net.subnet().subnet().get_output();
|
||||
layer<avg_pool>(net).get_output();
|
||||
|
||||
net.clean();
|
||||
serialize("mnist_network.dat") << net;
|
||||
serialize("mnist_res_network.dat") << net;
|
||||
|
||||
|
||||
|
||||
typedef loss_multiclass_log<fc<avg_pool<
|
||||
ares<ares<ares<ares<
|
||||
repeat<10,ares,
|
||||
ares<
|
||||
ares<
|
||||
input<matrix<unsigned char>
|
||||
>>>>>>>>>>> test_net_type;
|
||||
test_net_type tnet = net;
|
||||
// or you could deserialize the saved network
|
||||
deserialize("mnist_res_network.dat") >> tnet;
|
||||
|
||||
|
||||
// Run the net on all the data to get predictions
|
||||
std::vector<unsigned long> predicted_labels = net(training_images);
|
||||
std::vector<unsigned long> predicted_labels = tnet(training_images);
|
||||
int num_right = 0;
|
||||
int num_wrong = 0;
|
||||
for (size_t i = 0; i < training_images.size(); ++i)
|
||||
|
@ -105,7 +122,7 @@ int main(int argc, char** argv) try
|
|||
cout << "training num_wrong: " << num_wrong << endl;
|
||||
cout << "training accuracy: " << num_right/(double)(num_right+num_wrong) << endl;
|
||||
|
||||
predicted_labels = net(testing_images);
|
||||
predicted_labels = tnet(testing_images);
|
||||
num_right = 0;
|
||||
num_wrong = 0;
|
||||
for (size_t i = 0; i < testing_images.size(); ++i)
|
||||
|
|
Loading…
Reference in New Issue