mirror of https://github.com/davisking/dlib.git
Updated examples to work with new ridge regression interface.
This commit is contained in:
parent
068bf89d34
commit
8d6cec1d36
|
@ -85,9 +85,11 @@ int main()
|
|||
krr_trainer<kernel_type> trainer;
|
||||
|
||||
// The krr_trainer has the ability to perform leave-one-out cross-validation.
|
||||
// This function tells it to measure errors in terms of the number of classification
|
||||
// mistakes instead of mean squared error between decision function output values
|
||||
// and labels. Which is what we want to do since we are performing classification.
|
||||
// It does this to automatically determine the regularization parameter. Since
|
||||
// we are performing classification instead of regression we should be sure to
|
||||
// call use_classification_loss_for_loo_cv(). This function tells it to measure
|
||||
// errors in terms of the number of classification mistakes instead of mean squared
|
||||
// error between decision function output values and labels.
|
||||
trainer.use_classification_loss_for_loo_cv();
|
||||
|
||||
|
||||
|
@ -98,11 +100,14 @@ int main()
|
|||
// tell the trainer the parameters we want to use
|
||||
trainer.set_kernel(kernel_type(gamma));
|
||||
|
||||
double loo_error;
|
||||
trainer.train(samples, labels, loo_error);
|
||||
// loo_values will contain the LOO predictions for each sample. In the case
|
||||
// of perfect prediction it will end up being a copy of labels.
|
||||
std::vector<double> loo_values;
|
||||
trainer.train(samples, labels, loo_values);
|
||||
|
||||
// Print gamma and the fraction of samples misclassified during LOO cross-validation.
|
||||
cout << "gamma: " << gamma << " LOO error: " << loo_error << endl;
|
||||
// Print gamma and the fraction of samples correctly classified during LOO cross-validation.
|
||||
const double classification_accuracy = mean_sign_agreement(labels, loo_values);
|
||||
cout << "gamma: " << gamma << " LOO accuracy: " << classification_accuracy << endl;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -78,14 +78,16 @@ int main()
|
|||
// column is the output from the krr estimate.
|
||||
|
||||
|
||||
// Note that the krr_trainer has the ability to tell us the leave-one-out cross-validation
|
||||
// accuracy. The train() function has an optional 3rd argument and if we give it a double
|
||||
// it will give us back the LOO error.
|
||||
double loo_error;
|
||||
trainer.train(samples, labels, loo_error);
|
||||
cout << "mean squared LOO error: " << loo_error << endl;
|
||||
// Note that the krr_trainer has the ability to tell us the leave-one-out predictions
|
||||
// for each sample.
|
||||
std::vector<double> loo_values;
|
||||
trainer.train(samples, labels, loo_values);
|
||||
cout << "mean squared LOO error: " << mean_squared_error(labels,loo_values) << endl;
|
||||
cout << "R^2 LOO value: " << r_squared(labels,loo_values) << endl;
|
||||
// Which outputs the following:
|
||||
// mean squared LOO error: 8.29563e-07
|
||||
// mean squared LOO error: 8.29575e-07
|
||||
// R^2 LOO value: 0.999995
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue