Changed test_regression_function() and cross_validate_regression_trainer() to

output 2 more statistics, which are the mean absolute error and the standard
deviation of the absolute error.  This means these functions now return
4-element rather than 2-element row vectors (matrix<double,1,4> instead of
matrix<double,1,2>).

I also made test_regression_function() take a non-const reference to the
regression function so that DNN objects can be tested.
This commit is contained in:
Davis King 2017-11-10 16:56:37 -05:00
parent 5a0c09c775
commit 6137540b27
3 changed files with 21 additions and 13 deletions

View File

@ -18,9 +18,9 @@ namespace dlib
typename sample_type,
typename label_type
>
matrix<double,1,2>
matrix<double,1,4>
test_regression_function (
const reg_funct_type& reg_funct,
reg_funct_type& reg_funct,
const std::vector<sample_type>& x_test,
const std::vector<label_type>& y_test
)
@ -33,7 +33,7 @@ namespace dlib
<< "\n\t is_learning_problem(x_test,y_test): "
<< is_learning_problem(x_test,y_test));
running_stats<double> rs;
running_stats<double> rs, rs_mae;
running_scalar_covariance<double> rc;
for (unsigned long i = 0; i < x_test.size(); ++i)
@ -42,12 +42,13 @@ namespace dlib
const double output = reg_funct(x_test[i]);
const double temp = output - y_test[i];
rs_mae.add(std::abs(temp));
rs.add(temp*temp);
rc.add(output, y_test[i]);
}
matrix<double,1,2> result;
result = rs.mean(), std::pow(rc.correlation(),2);
matrix<double,1,4> result;
result = rs.mean(), std::pow(rc.correlation(),2), rs_mae.mean(), rs_mae.stddev();
return result;
}
@ -58,7 +59,7 @@ namespace dlib
typename sample_type,
typename label_type
>
matrix<double,1,2>
matrix<double,1,4>
cross_validate_regression_trainer (
const trainer_type& trainer,
const std::vector<sample_type>& x,
@ -82,7 +83,7 @@ namespace dlib
const long num_in_test = x.size()/folds;
const long num_in_train = x.size() - num_in_test;
running_stats<double> rs;
running_stats<double> rs, rs_mae;
running_scalar_covariance<double> rc;
std::vector<sample_type> x_test, x_train;
@ -128,6 +129,7 @@ namespace dlib
const double output = df(x_test[j]);
const double temp = output - y_test[j];
rs_mae.add(std::abs(temp));
rs.add(temp*temp);
rc.add(output, y_test[j]);
}
@ -139,8 +141,8 @@ namespace dlib
} // for (long i = 0; i < folds; ++i)
matrix<double,1,2> result;
result = rs.mean(), std::pow(rc.correlation(),2);
matrix<double,1,4> result;
result = rs.mean(), std::pow(rc.correlation(),2), rs_mae.mean(), rs_mae.stddev();
return result;
}

View File

@ -16,9 +16,9 @@ namespace dlib
typename sample_type,
typename label_type
>
matrix<double,1,2>
matrix<double,1,4>
test_regression_function (
const reg_funct_type& reg_funct,
reg_funct_type& reg_funct,
const std::vector<sample_type>& x_test,
const std::vector<label_type>& y_test
);
@ -35,6 +35,9 @@ namespace dlib
- M(1) == the R-squared value (i.e. the squared correlation between
reg_funct(x_test[i]) and y_test[i]). This is a number between 0
and 1.
- M(2) == the mean absolute error.
This is given by: average over i: abs(reg_funct(x_test[i]) - y_test[i])
- M(3) == the standard deviation of the absolute error.
!*/
// ----------------------------------------------------------------------------------------
@ -44,7 +47,7 @@ namespace dlib
typename sample_type,
typename label_type
>
matrix<double,1,2>
matrix<double,1,4>
cross_validate_regression_trainer (
const trainer_type& trainer,
const std::vector<sample_type>& x,
@ -66,6 +69,9 @@ namespace dlib
- M(1) == the R-squared value (i.e. the squared correlation between
a predicted y value and its true value). This is a number between
0 and 1.
- M(2) == the mean absolute error.
This is given by: the average, over all the test samples from all folds,
of abs(predicted y value - true y value)
- M(3) == the standard deviation of the absolute error.
!*/
}

View File

@ -247,7 +247,7 @@ namespace
randomize_samples(samples, labels);
dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6);
dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6);
matrix<double,1,2> cv = cross_validate_regression_trainer(krr_test, samples, labels, 6);
matrix<double,1,4> cv = cross_validate_regression_trainer(krr_test, samples, labels, 6);
DLIB_TEST(cv(0) < 1e-4);
DLIB_TEST(cv(1) > 0.99);
cv = cross_validate_regression_trainer(svr_test, samples, labels, 6);