From 925a9be91cfd89122d01083ac7827baa664f28e3 Mon Sep 17 00:00:00 2001 From: Davis King Date: Wed, 27 May 2009 02:21:00 +0000 Subject: [PATCH] Added overloads of the kernel_derivative object for all the kernels in dlib. --HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403067 --- dlib/svm/kernel.h | 60 ++++++++++++++++++++++++++++++++++++++ dlib/svm/kernel_abstract.h | 5 +++- dlib/test/svm.cpp | 52 +++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/dlib/svm/kernel.h b/dlib/svm/kernel.h index 15535deb8..48340fb1b 100644 --- a/dlib/svm/kernel.h +++ b/dlib/svm/kernel.h @@ -318,6 +318,28 @@ namespace dlib } } + template < + typename T + > + struct kernel_derivative > + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + kernel_derivative(const sigmoid_kernel& k_) : k(k_){} + + const sample_type& operator() (const sample_type& x, const sample_type& y) const + { + // return the derivative of the rbf kernel + temp = k.gamma*x*(1-std::pow(k(x,y),2)); + return temp; + } + + const sigmoid_kernel& k; + mutable sample_type temp; + }; + // ---------------------------------------------------------------------------------------- template @@ -359,6 +381,25 @@ namespace dlib std::istream& in ){} + template < + typename T + > + struct kernel_derivative > + { + typedef typename T::type scalar_type; + typedef T sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + kernel_derivative(const linear_kernel& k_) : k(k_){} + + const sample_type& operator() (const sample_type& x, const sample_type& y) const + { + return x; + } + + const linear_kernel& k; + }; + // ---------------------------------------------------------------------------------------- template @@ -442,6 +483,25 @@ namespace dlib } } + template < + typename T + > + struct kernel_derivative > + { + typedef typename T::scalar_type scalar_type; + typedef typename T::sample_type sample_type; + typedef typename T::mem_manager_type mem_manager_type; + + kernel_derivative(const offset_kernel& k) : der(k.kernel){} + + const sample_type operator() (const sample_type& x, const sample_type& y) const + { + return der(x,y); + } + + kernel_derivative der; + }; + // ---------------------------------------------------------------------------------------- } diff --git a/dlib/svm/kernel_abstract.h b/dlib/svm/kernel_abstract.h index 588037832..b8c0a9776 100644 --- a/dlib/svm/kernel_abstract.h +++ b/dlib/svm/kernel_abstract.h @@ -537,6 +537,9 @@ namespace dlib kernel_type must be one of the following kernel types: - radial_basis_kernel - polynomial_kernel + - sigmoid_kernel + - linear_kernel + - offset_kernel WHAT THIS OBJECT REPRESENTS This is a function object that computes the derivative of a kernel @@ -562,7 +565,7 @@ namespace dlib ) const; /*! ensures - - returns the derivative of k with respect to y. Or in other words, k(x, y+dy)/dy + - returns the derivative of k with respect to y. !*/ const kernel_type& k; diff --git a/dlib/test/svm.cpp b/dlib/test/svm.cpp index 9f5682310..5b1e61b22 100644 --- a/dlib/test/svm.cpp +++ b/dlib/test/svm.cpp @@ -399,9 +399,60 @@ namespace // ---------------------------------------------------------------------------------------- + template + struct kernel_der_obj + { + typename kernel_type::sample_type x; + kernel_type k; + + double operator()(const typename kernel_type::sample_type& y) const { return k(x,y); } + }; + template + void test_kernel_derivative ( + const kernel_type& k, + const typename kernel_type::sample_type& x, + const typename kernel_type::sample_type& y + ) + { + kernel_der_obj obj; + obj.x = x; + obj.k = k; + kernel_derivative der(obj.k); + DLIB_CASSERT(dlib::equal(derivative(obj)(y) , der(obj.x,y), 1e-5), ""); + } + void test_kernel_derivative ( + ) + { + typedef matrix sample_type; + + sigmoid_kernel k1; + radial_basis_kernel k2; + linear_kernel k3; + polynomial_kernel k4(2,3,4); + + offset_kernel > k5; + offset_kernel > k6; + + dlib::rand::float_1a rnd; + + sample_type x, y; + for (int i = 0; i < 10; ++i) + { + x = randm(2,1,rnd); + y = randm(2,1,rnd); + test_kernel_derivative(k1, x, y); + test_kernel_derivative(k2, x, y); + test_kernel_derivative(k3, x, y); + test_kernel_derivative(k4, x, y); + test_kernel_derivative(k5, x, y); + test_kernel_derivative(k6, x, y); + } + } + +// ---------------------------------------------------------------------------------------- class svm_tester : public tester { @@ -415,6 +466,7 @@ namespace void perform_test ( ) { + test_kernel_derivative(); test_binary_classification(); test_clutering(); test_regression();