Added a dot() function which can dot a sparse vector with a dense vector.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404177
This commit is contained in:
Davis King 2011-03-20 14:13:32 +00:00
parent 57aed828d4
commit 02ec16db94
4 changed files with 98 additions and 39 deletions

View File

@ -195,19 +195,16 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <typename T, typename U> template <typename T>
typename T::value_type::second_type dot ( typename T::value_type::second_type dot (
const T& a, const T& a,
const U& b const T& b
) )
{ {
typedef typename T::value_type::second_type scalar_type; typedef typename T::value_type::second_type scalar_type;
typedef typename U::value_type::second_type scalar_typeU;
// Both T and U must contain the same kinds of elements
COMPILE_TIME_ASSERT((is_same_type<scalar_type, scalar_typeU>::value));
typename T::const_iterator ai = a.begin(); typename T::const_iterator ai = a.begin();
typename U::const_iterator bi = b.begin(); typename T::const_iterator bi = b.begin();
scalar_type sum = 0; scalar_type sum = 0;
while (ai != a.end() && bi != b.end()) while (ai != a.end() && bi != b.end())
@ -231,6 +228,43 @@ namespace dlib
return sum; return sum;
} }
// ------------------------------------------------------------------------------------
template <typename T, typename EXP>
typename T::value_type::second_type dot (
const T& a,
const matrix_exp<EXP>& b
)
{
// make sure requires clause is not broken
DLIB_ASSERT(is_vector(b) && max_index_plus_one(a) <= (unsigned long)b.size(),
"\t scalar_type dot(sparse_vector a, dense_vector b)"
<< "\n\t 'b' must be a vector to be used in a dot product and the sparse vector 'a'"
<< "\n\t can't be bigger that the dense vector 'b'."
);
typedef typename T::value_type::second_type scalar_type;
scalar_type sum = 0;
for (typename T::const_iterator ai = a.begin(); ai != a.end(); ++ai)
{
sum += ai->second * b(ai->first);
}
return sum;
}
// ------------------------------------------------------------------------------------
template <typename T, typename EXP>
typename T::value_type::second_type dot (
const matrix_exp<EXP>& b,
const T& a
)
{
return dot(a,b);
}
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <typename T> template <typename T>

View File

@ -138,15 +138,48 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T, typename U> template <typename T>
typename T::value_type::second_type dot ( typename T::value_type::second_type dot (
const T& a, const T& a,
const U& b const T& b
); );
/*! /*!
requires requires
- a is a sorted range of std::pair objects - a and b are valid sparse vectors (as defined at the top of this file).
- b is a sorted range of std::pair objects ensures
- returns the dot product between the vectors a and b
!*/
// ----------------------------------------------------------------------------------------
template <typename T, typename EXP>
typename T::value_type::second_type dot (
const T& a,
const matrix_exp<EXP>& b
);
/*!
requires
- a is a valid sparse vector (as defined at the top of this file).
- is_vector(b) == true
- max_index_plus_one(a) <= b.size()
(i.e. a can't be bigger than b)
ensures
- returns the dot product between the vectors a and b
!*/
// ----------------------------------------------------------------------------------------
template <typename T, typename EXP>
typename T::value_type::second_type dot (
const matrix_exp<EXP>& a,
const T& b
);
/*!
requires
- b is a valid sparse vector (as defined at the top of this file).
- is_vector(a) == true
- max_index_plus_one(b) <= a.size()
(i.e. b can't be bigger than a)
ensures ensures
- returns the dot product between the vectors a and b - returns the dot product between the vectors a and b
!*/ !*/

View File

@ -152,31 +152,6 @@ namespace dlib
private: private:
// -----------------------------------------------------
// -----------------------------------------------------
template <typename EXP>
scalar_type dot_helper (
const matrix_type& w,
const matrix_exp<EXP>& sample
) const
{
return dot(colm(w,0,sample.size()), sample);
}
template <typename T>
typename disable_if<is_matrix<T>,scalar_type >::type dot_helper (
const matrix_type& w,
const T& sample
) const
{
// compute a dot product between a dense column vector and a sparse vector
scalar_type temp = 0;
for (typename T::const_iterator i = sample.begin(); i != sample.end(); ++i)
temp += w(i->first) * i->second;
return temp;
}
// ----------------------------------------------------- // -----------------------------------------------------
// ----------------------------------------------------- // -----------------------------------------------------
@ -189,9 +164,9 @@ namespace dlib
- for all i: #dot_prods[i] == dot(colm(#w,0,w.size()-1), samples(i)) - #w(w.size()-1) - for all i: #dot_prods[i] == dot(colm(#w,0,w.size()-1), samples(i)) - #w(w.size()-1)
!*/ !*/
{ {
using sparse_vector::dot;
for (long i = 0; i < samples.size(); ++i) for (long i = 0; i < samples.size(); ++i)
dot_prods[i] = dot_helper(w,samples(i)) - w(w.size()-1); dot_prods[i] = dot(colm(w,0,w.size()-1), samples(i)) - w(w.size()-1);
if (is_first_call) if (is_first_call)
{ {

View File

@ -163,6 +163,23 @@ namespace
DLIB_TEST(m(1) == 1 - 2*samples[3][1].second); DLIB_TEST(m(1) == 1 - 2*samples[3][1].second);
DLIB_TEST(m(2) == 1); DLIB_TEST(m(2) == 1);
// test mixed sparse and dense dot products
{
std::map<unsigned int, double> sv;
matrix<double,0,1> dv(4);
dv = 1,2,3,4;
sv[0] = 1;
sv[3] = 1;
using namespace sparse_vector;
DLIB_TEST(dot(sv,dv) == 5);
DLIB_TEST(dot(dv,sv) == 5);
DLIB_TEST(dot(dv,dv) == 30);
DLIB_TEST(dot(sv,sv) == 2);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
@ -202,10 +219,10 @@ namespace
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class svm_c_linear_tester : public tester class tester_svm_c_linear : public tester
{ {
public: public:
svm_c_linear_tester ( tester_svm_c_linear (
) : ) :
tester ("test_svm_c_linear", tester ("test_svm_c_linear",
"Runs tests on the svm_c_linear_trainer.") "Runs tests on the svm_c_linear_trainer.")