Updated examples to work with new ridge regression interface.

This commit is contained in:
Davis King 2011-08-24 21:36:50 -04:00
parent 068bf89d34
commit 8d6cec1d36
2 changed files with 21 additions and 14 deletions

View File

@ -85,9 +85,11 @@ int main()
krr_trainer<kernel_type> trainer; krr_trainer<kernel_type> trainer;
// The krr_trainer has the ability to perform leave-one-out cross-validation. // The krr_trainer has the ability to perform leave-one-out cross-validation.
// This function tells it to measure errors in terms of the number of classification // It does this to automatically determine the regularization parameter. Since
// mistakes instead of mean squared error between decision function output values // we are performing classification instead of regression we should be sure to
// and labels. Which is what we want to do since we are performing classification. // call use_classification_loss_for_loo_cv(). This function tells it to measure
// errors in terms of the number of classification mistakes instead of mean squared
// error between decision function output values and labels.
trainer.use_classification_loss_for_loo_cv(); trainer.use_classification_loss_for_loo_cv();
@ -98,11 +100,14 @@ int main()
// tell the trainer the parameters we want to use // tell the trainer the parameters we want to use
trainer.set_kernel(kernel_type(gamma)); trainer.set_kernel(kernel_type(gamma));
double loo_error; // loo_values will contain the LOO predictions for each sample. In the case
trainer.train(samples, labels, loo_error); // of perfect prediction it will end up being a copy of labels.
std::vector<double> loo_values;
trainer.train(samples, labels, loo_values);
// Print gamma and the fraction of samples misclassified during LOO cross-validation. // Print gamma and the fraction of samples correctly classified during LOO cross-validation.
cout << "gamma: " << gamma << " LOO error: " << loo_error << endl; const double classification_accuracy = mean_sign_agreement(labels, loo_values);
cout << "gamma: " << gamma << " LOO accuracy: " << classification_accuracy << endl;
} }

View File

@ -78,14 +78,16 @@ int main()
// column is the output from the krr estimate. // column is the output from the krr estimate.
// Note that the krr_trainer has the ability to tell us the leave-one-out cross-validation // Note that the krr_trainer has the ability to tell us the leave-one-out predictions
// accuracy. The train() function has an optional 3rd argument and if we give it a double // for each sample.
// it will give us back the LOO error. std::vector<double> loo_values;
double loo_error; trainer.train(samples, labels, loo_values);
trainer.train(samples, labels, loo_error); cout << "mean squared LOO error: " << mean_squared_error(labels,loo_values) << endl;
cout << "mean squared LOO error: " << loo_error << endl; cout << "R^2 LOO value: " << r_squared(labels,loo_values) << endl;
// Which outputs the following: // Which outputs the following:
// mean squared LOO error: 8.29563e-07 // mean squared LOO error: 8.29575e-07
// R^2 LOO value: 0.999995