Made test_layers() a little more robust.

This commit is contained in:
Davis King 2015-12-12 12:51:29 -05:00
parent 9065f08c35
commit cbd57be677
1 changed files with 10 additions and 2 deletions

View File

@ -2016,7 +2016,11 @@ namespace dlib
// compare it to the one output by the layer and make sure they match. // compare it to the one output by the layer and make sure they match.
double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps); double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps);
double output_derivative = params_grad.host()[i]; double output_derivative = params_grad.host()[i];
double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100); double relative_error;
if (reference_derivative != 0)
relative_error = (reference_derivative - output_derivative)/(reference_derivative);
else
relative_error = (reference_derivative - output_derivative);
double absolute_error = (reference_derivative - output_derivative); double absolute_error = (reference_derivative - output_derivative);
rs_params.add(std::abs(relative_error)); rs_params.add(std::abs(relative_error));
if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006) if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006)
@ -2049,7 +2053,11 @@ namespace dlib
double output_derivative = subnetwork.get_gradient_input_element(i); double output_derivative = subnetwork.get_gradient_input_element(i);
if (!impl::is_inplace_layer(l,subnetwork)) if (!impl::is_inplace_layer(l,subnetwork))
output_derivative -= initial_gradient_input[i]; output_derivative -= initial_gradient_input[i];
double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100); double relative_error;
if (reference_derivative != 0)
relative_error = (reference_derivative - output_derivative)/(reference_derivative);
else
relative_error = (reference_derivative - output_derivative);
double absolute_error = (reference_derivative - output_derivative); double absolute_error = (reference_derivative - output_derivative);
rs_data.add(std::abs(relative_error)); rs_data.add(std::abs(relative_error));
if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006) if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006)