From 2ab4add4a9d5c61f5c3c89c8f7d07034f3c26663 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sun, 30 Oct 2011 11:51:49 -0400 Subject: [PATCH] Added some missing validation of the user supplied number of folds to the cross_validate_multiclass_trainer() routine. Not it will throw an exception if the number of folds is too big rather than just producing a confusing result. --- dlib/svm/cross_validate_multiclass_trainer.h | 16 ++++++++++++++++ .../cross_validate_multiclass_trainer_abstract.h | 14 ++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/dlib/svm/cross_validate_multiclass_trainer.h b/dlib/svm/cross_validate_multiclass_trainer.h index ad1850404..5aa0d0bee 100644 --- a/dlib/svm/cross_validate_multiclass_trainer.h +++ b/dlib/svm/cross_validate_multiclass_trainer.h @@ -7,6 +7,7 @@ #include "../matrix.h" #include "one_vs_one_trainer.h" #include "cross_validate_multiclass_trainer_abstract.h" +#include namespace dlib { @@ -67,6 +68,12 @@ namespace dlib // ---------------------------------------------------------------------------------------- + class cross_validation_error : public dlib::error + { + public: + cross_validation_error(const std::string& msg) : dlib::error(msg){}; + }; + template < typename trainer_type, typename sample_type, @@ -104,6 +111,15 @@ namespace dlib for (typename std::map::iterator i = label_counts.begin(); i != label_counts.end(); ++i) { const long in_test = i->second/folds; + if (in_test == 0) + { + std::ostringstream sout; + sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl; + sout << "than the number of elements of one of the training classes." << std::endl; + sout << " folds: "<< folds << std::endl; + sout << " size of class " << i->first << ": "<< i->second << std::endl; + throw cross_validation_error(sout.str()); + } num_in_test[i->first] = in_test; num_in_train[i->first] = i->second - in_test; } diff --git a/dlib/svm/cross_validate_multiclass_trainer_abstract.h b/dlib/svm/cross_validate_multiclass_trainer_abstract.h index c9fb7c8b0..a00b23cc4 100644 --- a/dlib/svm/cross_validate_multiclass_trainer_abstract.h +++ b/dlib/svm/cross_validate_multiclass_trainer_abstract.h @@ -38,6 +38,16 @@ namespace dlib with labels the decision function hasn't ever seen before are ignored. !*/ +// ---------------------------------------------------------------------------------------- + + class cross_validation_error : public dlib::error + { + /*! + This is the exception class used by the cross_validate_multiclass_trainer() + routine. + !*/ + }; + // ---------------------------------------------------------------------------------------- template < @@ -74,6 +84,10 @@ namespace dlib samples in a class is not an even multiple of folds. This is because each fold has the same number of test samples in it and so if the number of samples in a class isn't a multiple of folds then a few are not tested. + throws + - cross_validation_error + This exception is thrown if one of the classes has fewer samples than + the number of requested folds. !*/ }