Fixed old tests. Also added more max pooling tests.

This commit is contained in:
Davis King 2015-12-12 12:59:29 -05:00
parent 18d0f0f4d3
commit b7e127f212
1 changed files with 57 additions and 3 deletions

View File

@ -340,7 +340,7 @@ namespace
DLIB_TEST(max(abs(truth3-mat(dest))) < 1e-5);
matrix<float> truth4 = pointwise_multiply(mat(A), mat(B));
multiply(A, B);
multiply(A, A, B);
DLIB_TEST(max(abs(truth4-mat(A))) < 1e-5);
matrix<float> truth5 = mat(B) > 0.1;
@ -497,8 +497,8 @@ namespace
rnd.fill_uniform(dest);
rnd.fill_uniform(src);
dest2 = dest; src2 = src;
cuda::multiply(dest, src);
cpu::multiply(dest2, src2);
cuda::multiply(dest, dest, src);
cpu::multiply(dest2, dest2, src2);
DLIB_TEST(equal(mat(dest),mat(dest2)));
@ -692,6 +692,56 @@ namespace
}
#endif
// ----------------------------------------------------------------------------------------
void test_max_pool(
const int window_height,
const int window_width,
const int stride_y,
const int stride_x
)
{
print_spinner();
resizable_tensor A, B, gradient_input;
A.set_size(2,2,16,7);
B.copy_size(A);
gradient_input.copy_size(A);
tt::tensor_rand rnd;
rnd.fill_gaussian(A,0,1);
rnd.fill_gaussian(B,0,1);
rnd.fill_gaussian(gradient_input,0,1);
tt::max_pool mp;
mp.setup(window_height,window_width,stride_y,stride_x);
mp(A, B);
// make sure max_pool does what it's spec says it should.
DLIB_TEST( A.num_samples() == B.num_samples());
DLIB_TEST( A.k() == B.k());
DLIB_TEST( A.nr() == B.nr()/stride_y);
DLIB_TEST( A.nc() == B.nc()/stride_x);
for (long s = 0; s < A.num_samples(); ++s)
{
for (long k = 0; k < A.k(); ++k)
{
for (long r = 0; r < A.nr(); ++r)
{
for (long c = 0; c < A.nc(); ++c)
{
DLIB_TEST(image_plane(A,s,k)(r,c) == max(subm_clipped(image_plane(B,s,k),
r*stride_y,
c*stride_x,
window_height,
window_width)));
}
}
}
}
}
// ----------------------------------------------------------------------------------------
void test_layers()
@ -785,6 +835,10 @@ namespace
compare_bn_gpu_and_cpu();
compare_bn_conv_gpu_and_cpu();
#endif
test_max_pool(1,1,2,3);
test_max_pool(3,3,1,1);
test_max_pool(3,3,2,2);
test_max_pool(4,5,3,1);
test_tanh();
test_softmax();
test_sigmoid();