mirror of https://github.com/davisking/dlib.git
Added try/catch block to main().
This commit is contained in:
parent
8c8c5bf3ce
commit
d1b579f09e
|
@ -170,79 +170,87 @@ void make_training_examples(
|
|||
|
||||
int main()
|
||||
{
|
||||
// Get the training samples we defined above.
|
||||
dlib::array<graph_type> samples;
|
||||
std::vector<std::vector<bool> > labels;
|
||||
make_training_examples(samples, labels);
|
||||
try
|
||||
{
|
||||
// Get the training samples we defined above.
|
||||
dlib::array<graph_type> samples;
|
||||
std::vector<std::vector<bool> > labels;
|
||||
make_training_examples(samples, labels);
|
||||
|
||||
|
||||
// Create a structural SVM trainer for graph labeling problems. The vector_type
|
||||
// needs to be set to a type capable of holding node or edge vectors.
|
||||
typedef matrix<double,0,1> vector_type;
|
||||
structural_graph_labeling_trainer<vector_type> trainer;
|
||||
// This is the usual SVM C parameter. Larger values make the trainer try
|
||||
// harder to fit the training data but might result in overfitting. You
|
||||
// should set this value to whatever gives the best cross-validation results.
|
||||
trainer.set_c(10);
|
||||
// Create a structural SVM trainer for graph labeling problems. The vector_type
|
||||
// needs to be set to a type capable of holding node or edge vectors.
|
||||
typedef matrix<double,0,1> vector_type;
|
||||
structural_graph_labeling_trainer<vector_type> trainer;
|
||||
// This is the usual SVM C parameter. Larger values make the trainer try
|
||||
// harder to fit the training data but might result in overfitting. You
|
||||
// should set this value to whatever gives the best cross-validation results.
|
||||
trainer.set_c(10);
|
||||
|
||||
// Do 3-fold cross-validation and print the results. In this case it will
|
||||
// indicate that all nodes were correctly classified.
|
||||
cout << "3-fold cross-validation: " << cross_validate_graph_labeling_trainer(trainer, samples, labels, 3) << endl;
|
||||
// Do 3-fold cross-validation and print the results. In this case it will
|
||||
// indicate that all nodes were correctly classified.
|
||||
cout << "3-fold cross-validation: " << cross_validate_graph_labeling_trainer(trainer, samples, labels, 3) << endl;
|
||||
|
||||
// Since the trainer is working well. Lets have it make a graph_labeler
|
||||
// based on the training data.
|
||||
graph_labeler<vector_type> labeler = trainer.train(samples, labels);
|
||||
// Since the trainer is working well. Lets have it make a graph_labeler
|
||||
// based on the training data.
|
||||
graph_labeler<vector_type> labeler = trainer.train(samples, labels);
|
||||
|
||||
|
||||
/*
|
||||
Lets try the graph_labeler on a new test graph. In particular, lets
|
||||
use one with 5 nodes as shown below:
|
||||
/*
|
||||
Lets try the graph_labeler on a new test graph. In particular, lets
|
||||
use one with 5 nodes as shown below:
|
||||
|
||||
(0 F)-----(1 T)
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
(3 T)-----(2 T)------(4 T)
|
||||
(0 F)-----(1 T)
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
(3 T)-----(2 T)------(4 T)
|
||||
|
||||
I have annotated each node with either T or F to indicate the correct
|
||||
output (true or false).
|
||||
*/
|
||||
graph_type g;
|
||||
g.set_number_of_nodes(5);
|
||||
g.node(0).data = 1, 0; // Node data indicates a false node.
|
||||
g.node(1).data = 0, 1; // Node data indicates a true node.
|
||||
g.node(2).data = 0, 0; // Node data is ambiguous.
|
||||
g.node(3).data = 0, 0; // Node data is ambiguous.
|
||||
g.node(4).data = 0.1, 0; // Node data slightly indicates a false node.
|
||||
I have annotated each node with either T or F to indicate the correct
|
||||
output (true or false).
|
||||
*/
|
||||
graph_type g;
|
||||
g.set_number_of_nodes(5);
|
||||
g.node(0).data = 1, 0; // Node data indicates a false node.
|
||||
g.node(1).data = 0, 1; // Node data indicates a true node.
|
||||
g.node(2).data = 0, 0; // Node data is ambiguous.
|
||||
g.node(3).data = 0, 0; // Node data is ambiguous.
|
||||
g.node(4).data = 0.1, 0; // Node data slightly indicates a false node.
|
||||
|
||||
g.add_edge(0,1);
|
||||
g.add_edge(1,2);
|
||||
g.add_edge(2,3);
|
||||
g.add_edge(3,0);
|
||||
g.add_edge(2,4);
|
||||
g.add_edge(0,1);
|
||||
g.add_edge(1,2);
|
||||
g.add_edge(2,3);
|
||||
g.add_edge(3,0);
|
||||
g.add_edge(2,4);
|
||||
|
||||
// Set the edges up so nodes 1, 2, 3, and 4 are all strongly connected.
|
||||
edge(g,0,1) = 0;
|
||||
edge(g,1,2) = 1;
|
||||
edge(g,2,3) = 1;
|
||||
edge(g,3,0) = 0;
|
||||
edge(g,2,4) = 1;
|
||||
// Set the edges up so nodes 1, 2, 3, and 4 are all strongly connected.
|
||||
edge(g,0,1) = 0;
|
||||
edge(g,1,2) = 1;
|
||||
edge(g,2,3) = 1;
|
||||
edge(g,3,0) = 0;
|
||||
edge(g,2,4) = 1;
|
||||
|
||||
// The output of this shows all the nodes are correctly labeled.
|
||||
cout << "Predicted labels: " << endl;
|
||||
std::vector<bool> temp = labeler(g);
|
||||
for (unsigned long i = 0; i < temp.size(); ++i)
|
||||
cout << " " << i << ": " << temp[i] << endl;
|
||||
// The output of this shows all the nodes are correctly labeled.
|
||||
cout << "Predicted labels: " << endl;
|
||||
std::vector<bool> temp = labeler(g);
|
||||
for (unsigned long i = 0; i < temp.size(); ++i)
|
||||
cout << " " << i << ": " << temp[i] << endl;
|
||||
|
||||
|
||||
|
||||
// Breaking the strong labeling consistency link between node 1 and 2 causes
|
||||
// nodes 2, 3, and 4 to flip to false. This is because of their connection
|
||||
// to node 4 which has a small preference for false.
|
||||
edge(g,1,2) = 0;
|
||||
cout << "Predicted labels: " << endl;
|
||||
temp = labeler(g);
|
||||
for (unsigned long i = 0; i < temp.size(); ++i)
|
||||
cout << " " << i << ": " << temp[i] << endl;
|
||||
// Breaking the strong labeling consistency link between node 1 and 2 causes
|
||||
// nodes 2, 3, and 4 to flip to false. This is because of their connection
|
||||
// to node 4 which has a small preference for false.
|
||||
edge(g,1,2) = 0;
|
||||
cout << "Predicted labels: " << endl;
|
||||
temp = labeler(g);
|
||||
for (unsigned long i = 0; i < temp.size(); ++i)
|
||||
cout << " " << i << ": " << temp[i] << endl;
|
||||
}
|
||||
catch (std::exception& e)
|
||||
{
|
||||
cout << "Error, an exception was thrown!" << endl;
|
||||
cout << e.what() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue