diff --git a/examples/model_selection_ex.cpp b/examples/model_selection_ex.cpp
index 5bd360c7c..81a975c18 100644
--- a/examples/model_selection_ex.cpp
+++ b/examples/model_selection_ex.cpp
@@ -78,36 +78,37 @@ int main() try
     // Now that we have some data we want to train on it.  We are going to train a
-    // binary SVM with the RBF kernel to classify the data.  However, there are two
-    // parameters to the training.  These are the nu and gamma parameters.  Our choice
-    // for these parameters will influence how good the resulting decision function is.
-    // To test how good a particular choice of these parameters is we can use the
+    // binary SVM with the RBF kernel to classify the data.  However, there are
+    // three parameters to the training.  These are the SVM C parameters for each
+    // class and the RBF kernel's gamma parameter.  Our choice for these
+    // parameters will influence how good the resulting decision function is.  To
+    // test how good a particular choice of these parameters is we can use the
     // cross_validate_trainer() function to perform n-fold cross validation on our
-    // training data.  However, there is a problem with the way we have sampled our
-    // distribution above.  The problem is that there is a definite ordering to the
-    // samples.  That is, the first half of the samples look like they are from a
-    // different distribution than the second half.  This would screw up the cross
-    // validation process, but we can fix it by randomizing the order of the samples
-    // with the following function call.
+    // training data.  However, there is a problem with the way we have sampled
+    // our distribution above.  The problem is that there is a definite ordering
+    // to the samples.  That is, the first half of the samples look like they are
+    // from a different distribution than the second half.  This would screw up
+    // the cross validation process, but we can fix it by randomizing the order of
+    // the samples with the following function call.
     randomize_samples(samples, labels);
 
-
     // And now we get to the important bit.  Here we define a function,
     // cross_validation_score(), that will do the cross-validation we
     // mentioned and return a number indicating how good a particular setting
-    // of gamma and nu is.
-    auto cross_validation_score = [&](const double gamma, const double nu)
+    // of gamma, c1, and c2 is.
+    auto cross_validation_score = [&](const double gamma, const double c1, const double c2)
     {
         // Make a RBF SVM trainer and tell it what the parameters are supposed to be.
         typedef radial_basis_kernel<sample_type> kernel_type;
-        svm_nu_trainer<kernel_type> trainer;
+        svm_c_trainer<kernel_type> trainer;
         trainer.set_kernel(kernel_type(gamma));
-        trainer.set_nu(nu);
+        trainer.set_c_class1(c1);
+        trainer.set_c_class2(c2);
 
         // Finally, perform 10-fold cross validation and then print and return the results.
         matrix<double> result = cross_validate_trainer(trainer, samples, labels, 10);
-        cout << "gamma: " << setw(11) << gamma << "  nu: " << setw(11) << nu << "  cross validation accuracy: " << result;
+        cout << "gamma: " << setw(11) << gamma << "  c1: " << setw(11) << c1 << "  c2: " << setw(11) << c2 << "  cross validation accuracy: " << result;
 
         // Now return a number indicating how good the parameters are.  Bigger is
         // better in this example.  Here I'm returning the harmonic mean between the
@@ -119,33 +120,26 @@ int main() try
         return 2*prod(result)/sum(result);
     };
 
-    // The nu parameter has a maximum value that is dependent on the ratio of the +1 to -1
-    // labels in the training data.  This function finds that value.  The 0.999 is here
-    // because the maximum allowable nu is strictly less than the value returned by
-    // maximum_nu().  So shrinking the limit a little will prevent us from hitting it.
-    const double max_nu = 0.999*maximum_nu(labels);
-
     // And finally, we call this global optimizer that will search for the best parameters.
-    // It will call cross_validation_score() 50 times with different settings and return
+    // It will call cross_validation_score() 30 times with different settings and return
     // the best parameter setting it finds.  find_max_global() uses a global optimization
     // method based on a combination of non-parametric global function modeling and
     // quadratic trust region modeling to efficiently find a global maximizer.  It usually
     // does a good job with a relatively small number of calls to cross_validation_score().
     // In this example, you should observe that it finds settings that give perfect binary
-    // classification on the data.
+    // classification of the data.
     auto result = find_max_global(cross_validation_score,
-                                  {1e-5, 1e-5},   // lower bound constraints on gamma and nu, respectively
-                                  {100, max_nu},  // upper bound constraints on gamma and nu, respectively
-                                  max_function_calls(50));
+                                  {1e-5, 1e-5, 1e-5},  // lower bound constraints on gamma, c1, and c2, respectively
+                                  {100, 1e6, 1e6},     // upper bound constraints on gamma, c1, and c2, respectively
+                                  max_function_calls(30));
 
     double best_gamma = result.x(0);
-    double best_nu = result.x(1);
+    double best_c1 = result.x(1);
+    double best_c2 = result.x(2);
 
     cout << " best cross-validation score: " << result.y << endl;
-    cout << " best gamma: " << best_gamma << "   best nu: " << best_nu << endl;
-
-
+    cout << " best gamma: " << best_gamma << "   best c1: " << best_c1 << "   best c2: " << best_c2 << endl;
 }
 catch (exception& e)
 {
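
As a usage sketch (not part of the patch): once find_max_global() returns, the chosen parameters would typically be used to train one final classifier on the full training set. The following assumes the sample_type typedef (matrix<double,2,1>) defined earlier in model_selection_ex.cpp; svm_c_trainer::train() and decision_function are the same dlib types already used in the diff above.

    // Train a final decision function on all the data using the best
    // parameters found by find_max_global().
    typedef radial_basis_kernel<sample_type> kernel_type;
    svm_c_trainer<kernel_type> trainer;
    trainer.set_kernel(kernel_type(best_gamma));
    trainer.set_c_class1(best_c1);
    trainer.set_c_class2(best_c2);
    decision_function<kernel_type> df = trainer.train(samples, labels);

    // df can now classify new points: an output > 0 predicts the +1 class.
    sample_type s;
    s = 3.123, 4.567;   // dlib's comma initializer for matrix types
    cout << "prediction: " << df(s) << endl;

Training separate C values per class (set_c_class1/set_c_class2) lets the optimizer trade off errors on the two classes independently, which is why the patch searches a 3-dimensional parameter space instead of the 2-dimensional (gamma, nu) space used before.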