mirror of https://github.com/davisking/dlib.git
Made test_layers() a little more robust.
This commit is contained in:
parent
9065f08c35
commit
cbd57be677
|
@ -2016,7 +2016,11 @@ namespace dlib
|
||||||
// compare it to the one output by the layer and make sure they match.
|
// compare it to the one output by the layer and make sure they match.
|
||||||
double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps);
|
double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps);
|
||||||
double output_derivative = params_grad.host()[i];
|
double output_derivative = params_grad.host()[i];
|
||||||
double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100);
|
double relative_error;
|
||||||
|
if (reference_derivative != 0)
|
||||||
|
relative_error = (reference_derivative - output_derivative)/(reference_derivative);
|
||||||
|
else
|
||||||
|
relative_error = (reference_derivative - output_derivative);
|
||||||
double absolute_error = (reference_derivative - output_derivative);
|
double absolute_error = (reference_derivative - output_derivative);
|
||||||
rs_params.add(std::abs(relative_error));
|
rs_params.add(std::abs(relative_error));
|
||||||
if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006)
|
if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006)
|
||||||
|
@ -2049,7 +2053,11 @@ namespace dlib
|
||||||
double output_derivative = subnetwork.get_gradient_input_element(i);
|
double output_derivative = subnetwork.get_gradient_input_element(i);
|
||||||
if (!impl::is_inplace_layer(l,subnetwork))
|
if (!impl::is_inplace_layer(l,subnetwork))
|
||||||
output_derivative -= initial_gradient_input[i];
|
output_derivative -= initial_gradient_input[i];
|
||||||
double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100);
|
double relative_error;
|
||||||
|
if (reference_derivative != 0)
|
||||||
|
relative_error = (reference_derivative - output_derivative)/(reference_derivative);
|
||||||
|
else
|
||||||
|
relative_error = (reference_derivative - output_derivative);
|
||||||
double absolute_error = (reference_derivative - output_derivative);
|
double absolute_error = (reference_derivative - output_derivative);
|
||||||
rs_data.add(std::abs(relative_error));
|
rs_data.add(std::abs(relative_error));
|
||||||
if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006)
|
if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.006)
|
||||||
|
|
Loading…
Reference in New Issue