mirror of https://github.com/davisking/dlib.git
Added a dot() function which can dot a sparse vector with a dense vector.
--HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404177
This commit is contained in:
parent
57aed828d4
commit
02ec16db94
|
@ -195,19 +195,16 @@ namespace dlib
|
|||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T>
|
||||
typename T::value_type::second_type dot (
|
||||
const T& a,
|
||||
const U& b
|
||||
const T& b
|
||||
)
|
||||
{
|
||||
typedef typename T::value_type::second_type scalar_type;
|
||||
typedef typename U::value_type::second_type scalar_typeU;
|
||||
// Both T and U must contain the same kinds of elements
|
||||
COMPILE_TIME_ASSERT((is_same_type<scalar_type, scalar_typeU>::value));
|
||||
|
||||
typename T::const_iterator ai = a.begin();
|
||||
typename U::const_iterator bi = b.begin();
|
||||
typename T::const_iterator bi = b.begin();
|
||||
|
||||
scalar_type sum = 0;
|
||||
while (ai != a.end() && bi != b.end())
|
||||
|
@ -231,6 +228,43 @@ namespace dlib
|
|||
return sum;
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T, typename EXP>
|
||||
typename T::value_type::second_type dot (
|
||||
const T& a,
|
||||
const matrix_exp<EXP>& b
|
||||
)
|
||||
{
|
||||
// make sure requires clause is not broken
|
||||
DLIB_ASSERT(is_vector(b) && max_index_plus_one(a) <= (unsigned long)b.size(),
|
||||
"\t scalar_type dot(sparse_vector a, dense_vector b)"
|
||||
<< "\n\t 'b' must be a vector to be used in a dot product and the sparse vector 'a'"
|
||||
<< "\n\t can't be bigger that the dense vector 'b'."
|
||||
);
|
||||
|
||||
typedef typename T::value_type::second_type scalar_type;
|
||||
|
||||
scalar_type sum = 0;
|
||||
for (typename T::const_iterator ai = a.begin(); ai != a.end(); ++ai)
|
||||
{
|
||||
sum += ai->second * b(ai->first);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T, typename EXP>
|
||||
typename T::value_type::second_type dot (
|
||||
const matrix_exp<EXP>& b,
|
||||
const T& a
|
||||
)
|
||||
{
|
||||
return dot(a,b);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -138,15 +138,48 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T, typename U>
|
||||
template <typename T>
|
||||
typename T::value_type::second_type dot (
|
||||
const T& a,
|
||||
const U& b
|
||||
const T& b
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- a is a sorted range of std::pair objects
|
||||
- b is a sorted range of std::pair objects
|
||||
- a and b are valid sparse vectors (as defined at the top of this file).
|
||||
ensures
|
||||
- returns the dot product between the vectors a and b
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T, typename EXP>
|
||||
typename T::value_type::second_type dot (
|
||||
const T& a,
|
||||
const matrix_exp<EXP>& b
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- a is a valid sparse vector (as defined at the top of this file).
|
||||
- is_vector(b) == true
|
||||
- max_index_plus_one(a) <= b.size()
|
||||
(i.e. a can't be bigger than b)
|
||||
ensures
|
||||
- returns the dot product between the vectors a and b
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <typename T, typename EXP>
|
||||
typename T::value_type::second_type dot (
|
||||
const matrix_exp<EXP>& a,
|
||||
const T& b
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- b is a valid sparse vector (as defined at the top of this file).
|
||||
- is_vector(a) == true
|
||||
- max_index_plus_one(b) <= a.size()
|
||||
(i.e. b can't be bigger than a)
|
||||
ensures
|
||||
- returns the dot product between the vectors a and b
|
||||
!*/
|
||||
|
|
|
@ -152,31 +152,6 @@ namespace dlib
|
|||
|
||||
private:
|
||||
|
||||
// -----------------------------------------------------
|
||||
// -----------------------------------------------------
|
||||
|
||||
template <typename EXP>
|
||||
scalar_type dot_helper (
|
||||
const matrix_type& w,
|
||||
const matrix_exp<EXP>& sample
|
||||
) const
|
||||
{
|
||||
return dot(colm(w,0,sample.size()), sample);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename disable_if<is_matrix<T>,scalar_type >::type dot_helper (
|
||||
const matrix_type& w,
|
||||
const T& sample
|
||||
) const
|
||||
{
|
||||
// compute a dot product between a dense column vector and a sparse vector
|
||||
scalar_type temp = 0;
|
||||
for (typename T::const_iterator i = sample.begin(); i != sample.end(); ++i)
|
||||
temp += w(i->first) * i->second;
|
||||
return temp;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------
|
||||
// -----------------------------------------------------
|
||||
|
||||
|
@ -189,9 +164,9 @@ namespace dlib
|
|||
- for all i: #dot_prods[i] == dot(colm(#w,0,w.size()-1), samples(i)) - #w(w.size()-1)
|
||||
!*/
|
||||
{
|
||||
|
||||
using sparse_vector::dot;
|
||||
for (long i = 0; i < samples.size(); ++i)
|
||||
dot_prods[i] = dot_helper(w,samples(i)) - w(w.size()-1);
|
||||
dot_prods[i] = dot(colm(w,0,w.size()-1), samples(i)) - w(w.size()-1);
|
||||
|
||||
if (is_first_call)
|
||||
{
|
||||
|
|
|
@ -163,6 +163,23 @@ namespace
|
|||
DLIB_TEST(m(1) == 1 - 2*samples[3][1].second);
|
||||
DLIB_TEST(m(2) == 1);
|
||||
|
||||
// test mixed sparse and dense dot products
|
||||
{
|
||||
std::map<unsigned int, double> sv;
|
||||
matrix<double,0,1> dv(4);
|
||||
|
||||
dv = 1,2,3,4;
|
||||
|
||||
sv[0] = 1;
|
||||
sv[3] = 1;
|
||||
|
||||
using namespace sparse_vector;
|
||||
|
||||
DLIB_TEST(dot(sv,dv) == 5);
|
||||
DLIB_TEST(dot(dv,sv) == 5);
|
||||
DLIB_TEST(dot(dv,dv) == 30);
|
||||
DLIB_TEST(dot(sv,sv) == 2);
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -202,10 +219,10 @@ namespace
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class svm_c_linear_tester : public tester
|
||||
class tester_svm_c_linear : public tester
|
||||
{
|
||||
public:
|
||||
svm_c_linear_tester (
|
||||
tester_svm_c_linear (
|
||||
) :
|
||||
tester ("test_svm_c_linear",
|
||||
"Runs tests on the svm_c_linear_trainer.")
|
||||
|
|
Loading…
Reference in New Issue