mirror of https://github.com/davisking/dlib.git
Added overloads of the kernel_derivative object for all the kernels in dlib.
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403067
This commit is contained in:
parent
4ed6922b37
commit
925a9be91c
|
@ -318,6 +318,28 @@ namespace dlib
|
|||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
struct kernel_derivative<sigmoid_kernel<T> >
|
||||
{
|
||||
typedef typename T::type scalar_type;
|
||||
typedef T sample_type;
|
||||
typedef typename T::mem_manager_type mem_manager_type;
|
||||
|
||||
kernel_derivative(const sigmoid_kernel<T>& k_) : k(k_){}
|
||||
|
||||
const sample_type& operator() (const sample_type& x, const sample_type& y) const
|
||||
{
|
||||
// return the derivative of the rbf kernel
|
||||
temp = k.gamma*x*(1-std::pow(k(x,y),2));
|
||||
return temp;
|
||||
}
|
||||
|
||||
const sigmoid_kernel<T>& k;
|
||||
mutable sample_type temp;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
|
||||
|
@ -359,6 +381,25 @@ namespace dlib
|
|||
std::istream& in
|
||||
){}
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
struct kernel_derivative<linear_kernel<T> >
|
||||
{
|
||||
typedef typename T::type scalar_type;
|
||||
typedef T sample_type;
|
||||
typedef typename T::mem_manager_type mem_manager_type;
|
||||
|
||||
kernel_derivative(const linear_kernel<T>& k_) : k(k_){}
|
||||
|
||||
const sample_type& operator() (const sample_type& x, const sample_type& y) const
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
const linear_kernel<T>& k;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
|
||||
|
@ -442,6 +483,25 @@ namespace dlib
|
|||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T
|
||||
>
|
||||
struct kernel_derivative<offset_kernel<T> >
|
||||
{
|
||||
typedef typename T::scalar_type scalar_type;
|
||||
typedef typename T::sample_type sample_type;
|
||||
typedef typename T::mem_manager_type mem_manager_type;
|
||||
|
||||
kernel_derivative(const offset_kernel<T>& k) : der(k.kernel){}
|
||||
|
||||
const sample_type operator() (const sample_type& x, const sample_type& y) const
|
||||
{
|
||||
return der(x,y);
|
||||
}
|
||||
|
||||
kernel_derivative<T> der;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
|
|
@ -537,6 +537,9 @@ namespace dlib
|
|||
kernel_type must be one of the following kernel types:
|
||||
- radial_basis_kernel
|
||||
- polynomial_kernel
|
||||
- sigmoid_kernel
|
||||
- linear_kernel
|
||||
- offset_kernel
|
||||
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is a function object that computes the derivative of a kernel
|
||||
|
@ -562,7 +565,7 @@ namespace dlib
|
|||
) const;
|
||||
/*!
|
||||
ensures
|
||||
- returns the derivative of k with respect to y. Or in other words, k(x, y+dy)/dy
|
||||
- returns the derivative of k with respect to y.
|
||||
!*/
|
||||
|
||||
const kernel_type& k;
|
||||
|
|
|
@ -399,9 +399,60 @@ namespace
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename kernel_type>
|
||||
struct kernel_der_obj
|
||||
{
|
||||
typename kernel_type::sample_type x;
|
||||
kernel_type k;
|
||||
|
||||
double operator()(const typename kernel_type::sample_type& y) const { return k(x,y); }
|
||||
};
|
||||
|
||||
|
||||
template <typename kernel_type>
|
||||
void test_kernel_derivative (
|
||||
const kernel_type& k,
|
||||
const typename kernel_type::sample_type& x,
|
||||
const typename kernel_type::sample_type& y
|
||||
)
|
||||
{
|
||||
kernel_der_obj<kernel_type> obj;
|
||||
obj.x = x;
|
||||
obj.k = k;
|
||||
kernel_derivative<kernel_type> der(obj.k);
|
||||
DLIB_CASSERT(dlib::equal(derivative(obj)(y) , der(obj.x,y), 1e-5), "");
|
||||
}
|
||||
|
||||
void test_kernel_derivative (
|
||||
)
|
||||
{
|
||||
typedef matrix<double, 2, 1> sample_type;
|
||||
|
||||
sigmoid_kernel<sample_type> k1;
|
||||
radial_basis_kernel<sample_type> k2;
|
||||
linear_kernel<sample_type> k3;
|
||||
polynomial_kernel<sample_type> k4(2,3,4);
|
||||
|
||||
offset_kernel<sigmoid_kernel<sample_type> > k5;
|
||||
offset_kernel<radial_basis_kernel<sample_type> > k6;
|
||||
|
||||
dlib::rand::float_1a rnd;
|
||||
|
||||
sample_type x, y;
|
||||
for (int i = 0; i < 10; ++i)
|
||||
{
|
||||
x = randm(2,1,rnd);
|
||||
y = randm(2,1,rnd);
|
||||
test_kernel_derivative(k1, x, y);
|
||||
test_kernel_derivative(k2, x, y);
|
||||
test_kernel_derivative(k3, x, y);
|
||||
test_kernel_derivative(k4, x, y);
|
||||
test_kernel_derivative(k5, x, y);
|
||||
test_kernel_derivative(k6, x, y);
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class svm_tester : public tester
|
||||
{
|
||||
|
@ -415,6 +466,7 @@ namespace
|
|||
void perform_test (
|
||||
)
|
||||
{
|
||||
test_kernel_derivative();
|
||||
test_binary_classification();
|
||||
test_clutering();
|
||||
test_regression();
|
||||
|
|
Loading…
Reference in New Issue