Added a few more tests

This commit is contained in:
Davis King 2016-04-16 06:22:43 -04:00
parent c48d0973c7
commit 79adbae2f5
1 changed files with 7 additions and 2 deletions

View File

@ -154,7 +154,7 @@ namespace
{
using namespace dlib::tt;
print_spinner();
resizable_tensor src(5,5), gamma(1,5), beta(1,5), dest, dest2, means, vars, gradient_input(5,5);
resizable_tensor src(5,5), gamma(1,5), beta(1,5), dest, dest2, dest3, means, vars, gradient_input(5,5);
src = matrix_cast<float>(gaussian_randm(5,5, 0));
gamma = matrix_cast<float>(gaussian_randm(1,5, 1));
beta = matrix_cast<float>(gaussian_randm(1,5, 2));
@ -171,6 +171,8 @@ namespace
running_variances = mat(running_variances)/scale;
batch_normalize_inference(dest2, src, gamma, beta, running_means, running_variances);
DLIB_TEST_MSG(max(abs(mat(dest2)-mat(dest))) < 1e-5, max(abs(mat(dest2)-mat(dest))));
cpu::batch_normalize_inference(dest3, src, gamma, beta, running_means, running_variances);
DLIB_TEST_MSG(max(abs(mat(dest3)-mat(dest))) < 1e-5, max(abs(mat(dest3)-mat(dest))));
auto grad_src = [&](long idx) {
@ -237,7 +239,7 @@ namespace
{
using namespace dlib::tt;
print_spinner();
resizable_tensor src(5,5,4,4), gamma(1,5), beta(1,5), dest, dest2, means, vars, gradient_input(5,5,4,4);
resizable_tensor src(5,5,4,4), gamma(1,5), beta(1,5), dest, dest2, dest3, means, vars, gradient_input(5,5,4,4);
src = matrix_cast<float>(gaussian_randm(5,5*4*4, 0));
gamma = matrix_cast<float>(gaussian_randm(1,5, 1));
beta = matrix_cast<float>(gaussian_randm(1,5, 2));
@ -255,6 +257,8 @@ namespace
running_variances = mat(running_variances)/scale;
batch_normalize_conv_inference(dest2, src, gamma, beta, running_means, running_variances);
DLIB_TEST(max(abs(mat(dest2)-mat(dest))) < 1e-5);
cpu::batch_normalize_conv_inference(dest3, src, gamma, beta, running_means, running_variances);
DLIB_TEST(max(abs(mat(dest3)-mat(dest))) < 1e-5);
auto grad_src = [&](long idx) {
@ -1233,6 +1237,7 @@ namespace
test_avg_pool(3,3,2,2);
test_avg_pool(2,2,2,2);
test_avg_pool(4,5,3,1);
test_avg_pool(100,100,100,100);
test_tanh();
test_softmax();
test_sigmoid();