From 6137540b276b475a67dfd0e8fec6c9ee56ab2652 Mon Sep 17 00:00:00 2001 From: Davis King Date: Fri, 10 Nov 2017 16:56:37 -0500 Subject: [PATCH] Changed test_regression_function() and cross_validate_regression_trainer() to output 2 more statistics, which are the mean absolute error and the standard deviation of the absolute error. This means these functions now return 4D rather than 2D vectors. I also made test_regression_function() take a non-const reference to the regression function so that DNN objects can be tested. --- dlib/svm/cross_validate_regression_trainer.h | 20 ++++++++++--------- ...oss_validate_regression_trainer_abstract.h | 12 ++++++++--- dlib/test/svm.cpp | 2 +- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/dlib/svm/cross_validate_regression_trainer.h b/dlib/svm/cross_validate_regression_trainer.h index a352781f0..3c9c8937b 100644 --- a/dlib/svm/cross_validate_regression_trainer.h +++ b/dlib/svm/cross_validate_regression_trainer.h @@ -18,9 +18,9 @@ namespace dlib typename sample_type, typename label_type > - matrix + matrix test_regression_function ( - const reg_funct_type& reg_funct, + reg_funct_type& reg_funct, const std::vector& x_test, const std::vector& y_test ) @@ -33,7 +33,7 @@ namespace dlib << "\n\t is_learning_problem(x_test,y_test): " << is_learning_problem(x_test,y_test)); - running_stats rs; + running_stats rs, rs_mae; running_scalar_covariance rc; for (unsigned long i = 0; i < x_test.size(); ++i) @@ -42,12 +42,13 @@ namespace dlib const double output = reg_funct(x_test[i]); const double temp = output - y_test[i]; + rs_mae.add(std::abs(temp)); rs.add(temp*temp); rc.add(output, y_test[i]); } - matrix result; - result = rs.mean(), std::pow(rc.correlation(),2); + matrix result; + result = rs.mean(), std::pow(rc.correlation(),2), rs_mae.mean(), rs_mae.stddev(); return result; } @@ -58,7 +59,7 @@ namespace dlib typename sample_type, typename label_type > - matrix + matrix cross_validate_regression_trainer ( const trainer_type& trainer, const std::vector& x, @@ -82,7 +83,7 @@ namespace dlib const long num_in_test = x.size()/folds; const long num_in_train = x.size() - num_in_test; - running_stats rs; + running_stats rs, rs_mae; running_scalar_covariance rc; std::vector x_test, x_train; @@ -128,6 +129,7 @@ namespace dlib const double output = df(x_test[j]); const double temp = output - y_test[j]; + rs_mae.add(std::abs(temp)); rs.add(temp*temp); rc.add(output, y_test[j]); } @@ -139,8 +141,8 @@ namespace dlib } // for (long i = 0; i < folds; ++i) - matrix result; - result = rs.mean(), std::pow(rc.correlation(),2); + matrix result; + result = rs.mean(), std::pow(rc.correlation(),2), rs_mae.mean(), rs_mae.stddev(); return result; } diff --git a/dlib/svm/cross_validate_regression_trainer_abstract.h b/dlib/svm/cross_validate_regression_trainer_abstract.h index f73fdb42f..19d1d7791 100644 --- a/dlib/svm/cross_validate_regression_trainer_abstract.h +++ b/dlib/svm/cross_validate_regression_trainer_abstract.h @@ -16,9 +16,9 @@ namespace dlib typename sample_type, typename label_type > - matrix + matrix test_regression_function ( - const reg_funct_type& reg_funct, + reg_funct_type& reg_funct, const std::vector& x_test, const std::vector& y_test ); @@ -35,6 +35,9 @@ namespace dlib - M(1) == the R-squared value (i.e. the squared correlation between reg_funct(x_test[i]) and y_test[i]). This is a number between 0 and 1. + - M(2) == the mean absolute error. + This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i]) + - M(3) == the standard deviation of the absolute error. !*/ // ---------------------------------------------------------------------------------------- @@ -44,7 +47,7 @@ namespace dlib typename sample_type, typename label_type > - matrix + matrix cross_validate_regression_trainer ( const trainer_type& trainer, const std::vector& x, @@ -66,6 +69,9 @@ namespace dlib - M(1) == the R-squared value (i.e. the squared correlation between a predicted y value and its true value). This is a number between 0 and 1. + - M(2) == the mean absolute error. + This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i]) + - M(3) == the standard deviation of the absolute error. !*/ } diff --git a/dlib/test/svm.cpp b/dlib/test/svm.cpp index e2af3eb5c..b46d44331 100644 --- a/dlib/test/svm.cpp +++ b/dlib/test/svm.cpp @@ -247,7 +247,7 @@ namespace randomize_samples(samples, labels); dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6); dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6); - matrix cv = cross_validate_regression_trainer(krr_test, samples, labels, 6); + matrix cv = cross_validate_regression_trainer(krr_test, samples, labels, 6); DLIB_TEST(cv(0) < 1e-4); DLIB_TEST(cv(1) > 0.99); cv = cross_validate_regression_trainer(svr_test, samples, labels, 6);