Added tests for some of the new tensor operators.

This commit is contained in:
Davis King 2015-11-09 17:37:15 -05:00
parent 1d35a89653
commit 00e18a8d59
1 changed files with 48 additions and 0 deletions

View File

@ -41,6 +41,7 @@ namespace
void test_batch_normalize()
{
print_spinner();
resizable_tensor src(5,5), gamma(1,5), beta(1,5), dest, means, vars, gradient_input(5,5);
src = matrix_cast<float>(gaussian_randm(5,5, 0));
gamma = matrix_cast<float>(gaussian_randm(1,5, 1));
@ -115,6 +116,7 @@ namespace
void test_batch_normalize_conv()
{
print_spinner();
resizable_tensor src(5,5,4,4), gamma(1,5), beta(1,5), dest, means, vars, gradient_input(5,5,4,4);
src = matrix_cast<float>(gaussian_randm(5,5*4*4, 0));
gamma = matrix_cast<float>(gaussian_randm(1,5, 1));
@ -190,6 +192,51 @@ namespace
}
// ----------------------------------------------------------------------------------------
void test_basic_tensor_ops()
{
print_spinner();
resizable_tensor dest, src(3,4), A(1,4), B(1,4);
src = 2;
affine_transform(dest, src, 2, 3);
dlog << LINFO << mat(dest);
matrix<float> truth1(3,4), truth2(3,4);
truth1 = 7;
truth2 = 7, 10, 7, 7,
7, 10, 7, 7,
7, 10, 7, 7;
DLIB_TEST(max(abs(truth1-mat(dest))) < 1e-5);
A = 2;
B = 3;
A.host()[1] = 3;
B.host()[1] = 4;
dest = 0;
affine_transform(dest, src, A, B);
dlog << LINFO << mat(dest);
DLIB_TEST(max(abs(truth2-mat(dest))) < 1e-5);
A.set_size(3,4);
B.set_size(3,4);
A = matrix_cast<float>(gaussian_randm(3,4, 1));
B = matrix_cast<float>(gaussian_randm(3,4, 2));
affine_transform(dest, src, A, B);
dlog << LINFO << mat(dest);
matrix<float> truth3 = pointwise_multiply(mat(src), mat(A)) + mat(B);
DLIB_TEST(max(abs(truth3-mat(dest))) < 1e-5);
matrix<float> truth4 = pointwise_multiply(mat(A), mat(B));
multiply(A, B);
DLIB_TEST(max(abs(truth4-mat(A))) < 1e-5);
matrix<float> truth5 = mat(B) > 0.1;
dlog << LINFO << truth5;
threshold(B, 0.1);
DLIB_TEST(max(abs(truth5-mat(B))) < 1e-5);
}
// ----------------------------------------------------------------------------------------
class dnn_tester : public tester
@ -206,6 +253,7 @@ namespace
{
test_batch_normalize();
test_batch_normalize_conv();
test_basic_tensor_ops();
}
} a;