mirror of https://github.com/davisking/dlib.git
Switched this example to use the SVM C trainer instead of the nu trainer.
This commit is contained in:
parent 0e7e433096
commit 1aa6667481
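In code terms, the change swaps dlib's svm_nu_trainer (a single nu parameter, capped by maximum_nu()) for svm_c_trainer with an independent soft-margin C per class, so both parameters can simply range over (0, inf). Below is a minimal sketch of the new trainer setup, assuming the example's 2-D column-vector sample_type; the make_rbf_trainer() helper is illustrative only, not part of the example:

#include <dlib/svm.h>
using namespace dlib;

typedef matrix<double, 2, 1> sample_type;              // assumed sample type
typedef radial_basis_kernel<sample_type> kernel_type;

// Build the trainer the way the updated example does: one soft-margin C
// per class instead of a single nu parameter.
svm_c_trainer<kernel_type> make_rbf_trainer(double gamma, double c1, double c2)
{
    svm_c_trainer<kernel_type> trainer;
    trainer.set_kernel(kernel_type(gamma));
    trainer.set_c_class1(c1);   // C applied to the +1 labeled samples
    trainer.set_c_class2(c2);   // C applied to the -1 labeled samples
    return trainer;
}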
@@ -78,36 +78,37 @@ int main() try
 
 
     // Now that we have some data we want to train on it. We are going to train a
-    // binary SVM with the RBF kernel to classify the data. However, there are two
-    // parameters to the training. These are the nu and gamma parameters. Our choice
-    // for these parameters will influence how good the resulting decision function is.
-    // To test how good a particular choice of these parameters is we can use the
+    // binary SVM with the RBF kernel to classify the data. However, there are
+    // three parameters to the training. These are the SVM C parameters for each
+    // class and the RBF kernel's gamma parameter. Our choice for these
+    // parameters will influence how good the resulting decision function is. To
+    // test how good a particular choice of these parameters is we can use the
     // cross_validate_trainer() function to perform n-fold cross validation on our
-    // training data. However, there is a problem with the way we have sampled our
-    // distribution above. The problem is that there is a definite ordering to the
-    // samples. That is, the first half of the samples look like they are from a
-    // different distribution than the second half. This would screw up the cross
-    // validation process, but we can fix it by randomizing the order of the samples
-    // with the following function call.
+    // training data. However, there is a problem with the way we have sampled
+    // our distribution above. The problem is that there is a definite ordering
+    // to the samples. That is, the first half of the samples look like they are
+    // from a different distribution than the second half. This would screw up
+    // the cross validation process, but we can fix it by randomizing the order of
+    // the samples with the following function call.
     randomize_samples(samples, labels);
 
 
 
     // And now we get to the important bit. Here we define a function,
     // cross_validation_score(), that will do the cross-validation we
     // mentioned and return a number indicating how good a particular setting
-    // of gamma and nu is.
-    auto cross_validation_score = [&](const double gamma, const double nu)
+    // of gamma, c1, and c2 is.
+    auto cross_validation_score = [&](const double gamma, const double c1, const double c2)
     {
         // Make a RBF SVM trainer and tell it what the parameters are supposed to be.
         typedef radial_basis_kernel<sample_type> kernel_type;
-        svm_nu_trainer<kernel_type> trainer;
+        svm_c_trainer<kernel_type> trainer;
         trainer.set_kernel(kernel_type(gamma));
-        trainer.set_nu(nu);
+        trainer.set_c_class1(c1);
+        trainer.set_c_class2(c2);
 
         // Finally, perform 10-fold cross validation and then print and return the results.
         matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
-        cout << "gamma: " << setw(11) << gamma << " nu: " << setw(11) << nu << " cross validation accuracy: " << result;
+        cout << "gamma: " << setw(11) << gamma << " c1: " << setw(11) << c1 << " c2: " << setw(11) << c2 << " cross validation accuracy: " << result;
 
         // Now return a number indicating how good the parameters are. Bigger is
         // better in this example. Here I'm returning the harmonic mean between the
@@ -119,33 +120,26 @@ int main() try
         return 2*prod(result)/sum(result);
     };
 
-    // The nu parameter has a maximum value that is dependent on the ratio of the +1 to -1
-    // labels in the training data. This function finds that value. The 0.999 is here
-    // because the maximum allowable nu is strictly less than the value returned by
-    // maximum_nu(). So shrinking the limit a little will prevent us from hitting it.
-    const double max_nu = 0.999*maximum_nu(labels);
 
 
     // And finally, we call this global optimizer that will search for the best parameters.
-    // It will call cross_validation_score() 50 times with different settings and return
+    // It will call cross_validation_score() 30 times with different settings and return
     // the best parameter setting it finds. find_max_global() uses a global optimization
     // method based on a combination of non-parametric global function modeling and
     // quadratic trust region modeling to efficiently find a global maximizer. It usually
     // does a good job with a relatively small number of calls to cross_validation_score().
     // In this example, you should observe that it finds settings that give perfect binary
-    // classification on the data.
+    // classification of the data.
     auto result = find_max_global(cross_validation_score,
-                                  {1e-5, 1e-5}, // lower bound constraints on gamma and nu, respectively
-                                  {100, max_nu}, // upper bound constraints on gamma and nu, respectively
-                                  max_function_calls(50));
+                                  {1e-5, 1e-5, 1e-5}, // lower bound constraints on gamma, c1, and c2, respectively
+                                  {100, 1e6, 1e6}, // upper bound constraints on gamma, c1, and c2, respectively
+                                  max_function_calls(30));
 
     double best_gamma = result.x(0);
-    double best_nu = result.x(1);
+    double best_c1 = result.x(1);
+    double best_c2 = result.x(2);
 
     cout << " best cross-validation score: " << result.y << endl;
-    cout << " best gamma: " << best_gamma << " best nu: " << best_nu << endl;
+    cout << " best gamma: " << best_gamma << " best c1: " << best_c1 << " best c2: " << best_c2 << endl;
 
 
 }
 catch (exception& e)
 {
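A note on the score the lambda returns: cross_validate_trainer() reports a 1x2 matrix holding the fraction of +1 and of -1 test samples classified correctly, so 2*prod(result)/sum(result) is the harmonic mean 2ab/(a+b) of the two per-class accuracies, which stays low unless both classes are classified well. A small standalone check of that identity, with made-up accuracy values:

#include <dlib/matrix.h>
#include <cassert>
#include <cmath>
using namespace dlib;

int main()
{
    // Hypothetical per-class accuracies of the kind cross_validate_trainer() reports.
    matrix<double, 1, 2> result;
    result = 0.9, 0.6;

    // Harmonic mean: 2*prod/sum == 2ab/(a+b). Here 2*0.54/1.5 = 0.72, well
    // below the arithmetic mean of 0.75, penalizing the weaker class.
    double score = 2*prod(result)/sum(result);
    assert(std::abs(score - 0.72) < 1e-12);
    return 0;
}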
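And a self-contained sketch of the find_max_global() call pattern, with a hypothetical toy_score standing in for the real cross-validation so it runs instantly. The bound lists pair up positionally with the lambda's arguments, and the returned function_evaluation carries the best arguments found in .x and the best score in .y:

#include <dlib/global_optimization.h>
#include <dlib/matrix.h>
#include <iostream>
using namespace dlib;

int main()
{
    // Toy stand-in for cross_validation_score(): smooth, with its maximum
    // of 1.0 at gamma = 2, c1 = 0.5, c2 = 0.5.
    auto toy_score = [](double gamma, double c1, double c2)
    {
        return 1.0/(1 + (gamma-2)*(gamma-2) + (c1-0.5)*(c1-0.5) + (c2-0.5)*(c2-0.5));
    };

    auto result = find_max_global(toy_score,
                                  {1e-5, 1e-5, 1e-5},      // lower bounds on gamma, c1, c2
                                  {100, 1e6, 1e6},         // upper bounds on gamma, c1, c2
                                  max_function_calls(30)); // evaluation budget

    std::cout << "best score:  " << result.y << "\n";
    std::cout << "best params: " << trans(result.x);  // gamma, c1, c2
    return 0;
}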