diff --git a/dlib/svm/sparse_vector.h b/dlib/svm/sparse_vector.h index 4abe5cab3..dc1a9ebff 100644 --- a/dlib/svm/sparse_vector.h +++ b/dlib/svm/sparse_vector.h @@ -195,19 +195,16 @@ namespace dlib // ------------------------------------------------------------------------------------ - template + template typename T::value_type::second_type dot ( const T& a, - const U& b + const T& b ) { typedef typename T::value_type::second_type scalar_type; - typedef typename U::value_type::second_type scalar_typeU; - // Both T and U must contain the same kinds of elements - COMPILE_TIME_ASSERT((is_same_type::value)); typename T::const_iterator ai = a.begin(); - typename U::const_iterator bi = b.begin(); + typename T::const_iterator bi = b.begin(); scalar_type sum = 0; while (ai != a.end() && bi != b.end()) @@ -231,6 +228,43 @@ namespace dlib return sum; } + // ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type dot ( + const T& a, + const matrix_exp& b + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_vector(b) && max_index_plus_one(a) <= (unsigned long)b.size(), + "\t scalar_type dot(sparse_vector a, dense_vector b)" + << "\n\t 'b' must be a vector to be used in a dot product and the sparse vector 'a'" + << "\n\t can't be bigger that the dense vector 'b'." + ); + + typedef typename T::value_type::second_type scalar_type; + + scalar_type sum = 0; + for (typename T::const_iterator ai = a.begin(); ai != a.end(); ++ai) + { + sum += ai->second * b(ai->first); + } + + return sum; + } + + // ------------------------------------------------------------------------------------ + + template + typename T::value_type::second_type dot ( + const matrix_exp& b, + const T& a + ) + { + return dot(a,b); + } + // ------------------------------------------------------------------------------------ template diff --git a/dlib/svm/sparse_vector_abstract.h b/dlib/svm/sparse_vector_abstract.h index bdbe9d2d8..37ceec30d 100644 --- a/dlib/svm/sparse_vector_abstract.h +++ b/dlib/svm/sparse_vector_abstract.h @@ -138,15 +138,48 @@ namespace dlib // ---------------------------------------------------------------------------------------- - template + template typename T::value_type::second_type dot ( const T& a, - const U& b + const T& b ); /*! requires - - a is a sorted range of std::pair objects - - b is a sorted range of std::pair objects + - a and b are valid sparse vectors (as defined at the top of this file). + ensures + - returns the dot product between the vectors a and b + !*/ + + // ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type dot ( + const T& a, + const matrix_exp& b + ); + /*! + requires + - a is a valid sparse vector (as defined at the top of this file). + - is_vector(b) == true + - max_index_plus_one(a) <= b.size() + (i.e. a can't be bigger than b) + ensures + - returns the dot product between the vectors a and b + !*/ + + // ---------------------------------------------------------------------------------------- + + template + typename T::value_type::second_type dot ( + const matrix_exp& a, + const T& b + ); + /*! + requires + - b is a valid sparse vector (as defined at the top of this file). + - is_vector(a) == true + - max_index_plus_one(b) <= a.size() + (i.e. b can't be bigger than a) ensures - returns the dot product between the vectors a and b !*/ diff --git a/dlib/svm/svm_c_linear_trainer.h b/dlib/svm/svm_c_linear_trainer.h index a766a457a..0d1ee1c73 100644 --- a/dlib/svm/svm_c_linear_trainer.h +++ b/dlib/svm/svm_c_linear_trainer.h @@ -152,31 +152,6 @@ namespace dlib private: - // ----------------------------------------------------- - // ----------------------------------------------------- - - template - scalar_type dot_helper ( - const matrix_type& w, - const matrix_exp& sample - ) const - { - return dot(colm(w,0,sample.size()), sample); - } - - template - typename disable_if,scalar_type >::type dot_helper ( - const matrix_type& w, - const T& sample - ) const - { - // compute a dot product between a dense column vector and a sparse vector - scalar_type temp = 0; - for (typename T::const_iterator i = sample.begin(); i != sample.end(); ++i) - temp += w(i->first) * i->second; - return temp; - } - // ----------------------------------------------------- // ----------------------------------------------------- @@ -189,9 +164,9 @@ namespace dlib - for all i: #dot_prods[i] == dot(colm(#w,0,w.size()-1), samples(i)) - #w(w.size()-1) !*/ { - + using sparse_vector::dot; for (long i = 0; i < samples.size(); ++i) - dot_prods[i] = dot_helper(w,samples(i)) - w(w.size()-1); + dot_prods[i] = dot(colm(w,0,w.size()-1), samples(i)) - w(w.size()-1); if (is_first_call) { diff --git a/dlib/test/svm_c_linear.cpp b/dlib/test/svm_c_linear.cpp index 881aadc6a..a3f75e998 100644 --- a/dlib/test/svm_c_linear.cpp +++ b/dlib/test/svm_c_linear.cpp @@ -163,6 +163,23 @@ namespace DLIB_TEST(m(1) == 1 - 2*samples[3][1].second); DLIB_TEST(m(2) == 1); + // test mixed sparse and dense dot products + { + std::map sv; + matrix dv(4); + + dv = 1,2,3,4; + + sv[0] = 1; + sv[3] = 1; + + using namespace sparse_vector; + + DLIB_TEST(dot(sv,dv) == 5); + DLIB_TEST(dot(dv,sv) == 5); + DLIB_TEST(dot(dv,dv) == 30); + DLIB_TEST(dot(sv,sv) == 2); + } } // ---------------------------------------------------------------------------------------- @@ -202,10 +219,10 @@ namespace // ---------------------------------------------------------------------------------------- - class svm_c_linear_tester : public tester + class tester_svm_c_linear : public tester { public: - svm_c_linear_tester ( + tester_svm_c_linear ( ) : tester ("test_svm_c_linear", "Runs tests on the svm_c_linear_trainer.")