diff --git a/dlib/svm/svm_multiclass_linear_trainer.h b/dlib/svm/svm_multiclass_linear_trainer.h index dfda0b1d6..436042b2d 100644 --- a/dlib/svm/svm_multiclass_linear_trainer.h +++ b/dlib/svm/svm_multiclass_linear_trainer.h @@ -177,7 +177,8 @@ namespace dlib num_threads(4), C(1), eps(0.001), - verbose(false) + verbose(false), + learn_nonnegative_weights(false) { } @@ -243,6 +244,16 @@ namespace dlib return kernel_type(); } + bool learns_nonnegative_weights ( + ) const { return learn_nonnegative_weights; } + + void set_learns_nonnegative_weights ( + bool value + ) + { + learn_nonnegative_weights = value; + } + void set_c ( scalar_type C_ ) @@ -297,7 +308,13 @@ namespace dlib problem.set_c(C); problem.set_epsilon(eps); - svm_objective = solver(problem, weights); + unsigned long num_nonnegative = 0; + if (learn_nonnegative_weights) + { + num_nonnegative = problem.get_num_dimensions(); + } + + svm_objective = solver(problem, weights, num_nonnegative); trained_function_type df; @@ -315,6 +332,7 @@ namespace dlib scalar_type eps; bool verbose; oca solver; + bool learn_nonnegative_weights; }; // ---------------------------------------------------------------------------------------- diff --git a/dlib/svm/svm_multiclass_linear_trainer_abstract.h b/dlib/svm/svm_multiclass_linear_trainer_abstract.h index 2e3db35d9..5bfc089b8 100644 --- a/dlib/svm/svm_multiclass_linear_trainer_abstract.h +++ b/dlib/svm/svm_multiclass_linear_trainer_abstract.h @@ -32,6 +32,7 @@ namespace dlib INITIAL VALUE - get_num_threads() == 4 + - learns_nonnegative_weights() == false - get_epsilon() == 0.001 - get_c() == 1 - this object will not be verbose unless be_verbose() is called @@ -155,6 +156,26 @@ namespace dlib generalization. !*/ + bool learns_nonnegative_weights ( + ) const; + /*! + ensures + - The output of training is a set of weights and bias values that together + define the behavior of a multiclass_linear_decision_function object. If + learns_nonnegative_weights() == true then the resulting weights and bias + values will always have non-negative values. That is, if this function + returns true then all the numbers in the multiclass_linear_decision_function + objects output by train() will be non-negative. + !*/ + + void set_learns_nonnegative_weights ( + bool value + ); + /*! + ensures + - #learns_nonnegative_weights() == value + !*/ + trained_function_type train ( const std::vector& all_samples, const std::vector& all_labels