Updated examples to work with new ridge regression interface.

This commit is contained in:
Davis King 2011-08-24 21:36:50 -04:00
parent 068bf89d34
commit 8d6cec1d36
2 changed files with 21 additions and 14 deletions

View File

@ -85,9 +85,11 @@ int main()
krr_trainer<kernel_type> trainer;
// The krr_trainer has the ability to perform leave-one-out cross-validation.
// This function tells it to measure errors in terms of the number of classification
// mistakes instead of mean squared error between decision function output values
// and labels. Which is what we want to do since we are performing classification.
// It does this to automatically determine the regularization parameter. Since
// we are performing classification instead of regression we should be sure to
// call use_classification_loss_for_loo_cv(). This function tells it to measure
// errors in terms of the number of classification mistakes instead of mean squared
// error between decision function output values and labels.
trainer.use_classification_loss_for_loo_cv();
@ -98,11 +100,14 @@ int main()
// tell the trainer the parameters we want to use
trainer.set_kernel(kernel_type(gamma));
double loo_error;
trainer.train(samples, labels, loo_error);
// loo_values will contain the LOO predictions for each sample. In the case
// of perfect prediction it will end up being a copy of labels.
std::vector<double> loo_values;
trainer.train(samples, labels, loo_values);
// Print gamma and the fraction of samples misclassified during LOO cross-validation.
cout << "gamma: " << gamma << " LOO error: " << loo_error << endl;
// Print gamma and the fraction of samples correctly classified during LOO cross-validation.
const double classification_accuracy = mean_sign_agreement(labels, loo_values);
cout << "gamma: " << gamma << " LOO accuracy: " << classification_accuracy << endl;
}

View File

@ -78,14 +78,16 @@ int main()
// column is the output from the krr estimate.
// Note that the krr_trainer has the ability to tell us the leave-one-out cross-validation
// accuracy. The train() function has an optional 3rd argument and if we give it a double
// it will give us back the LOO error.
double loo_error;
trainer.train(samples, labels, loo_error);
cout << "mean squared LOO error: " << loo_error << endl;
// Note that the krr_trainer has the ability to tell us the leave-one-out predictions
// for each sample.
std::vector<double> loo_values;
trainer.train(samples, labels, loo_values);
cout << "mean squared LOO error: " << mean_squared_error(labels,loo_values) << endl;
cout << "R^2 LOO value: " << r_squared(labels,loo_values) << endl;
// Which outputs the following:
// mean squared LOO error: 8.29563e-07
// mean squared LOO error: 8.29575e-07
// R^2 LOO value: 0.999995