Added some missing validation of the user supplied number of folds

to the cross_validate_multiclass_trainer() routine.  Not it will
throw an exception if the number of folds is too big rather than
just producing a confusing result.
This commit is contained in:
Davis King 2011-10-30 11:51:49 -04:00
parent 7c4361534a
commit 2ab4add4a9
2 changed files with 30 additions and 0 deletions

View File

@ -7,6 +7,7 @@
#include "../matrix.h"
#include "one_vs_one_trainer.h"
#include "cross_validate_multiclass_trainer_abstract.h"
#include <sstream>
namespace dlib
{
@ -67,6 +68,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
class cross_validation_error : public dlib::error
{
public:
cross_validation_error(const std::string& msg) : dlib::error(msg){};
};
template <
typename trainer_type,
typename sample_type,
@ -104,6 +111,15 @@ namespace dlib
for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i)
{
const long in_test = i->second/folds;
if (in_test == 0)
{
std::ostringstream sout;
sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl;
sout << "than the number of elements of one of the training classes." << std::endl;
sout << " folds: "<< folds << std::endl;
sout << " size of class " << i->first << ": "<< i->second << std::endl;
throw cross_validation_error(sout.str());
}
num_in_test[i->first] = in_test;
num_in_train[i->first] = i->second - in_test;
}

View File

@ -38,6 +38,16 @@ namespace dlib
with labels the decision function hasn't ever seen before are ignored.
!*/
// ----------------------------------------------------------------------------------------
class cross_validation_error : public dlib::error
{
/*!
This is the exception class used by the cross_validate_multiclass_trainer()
routine.
!*/
};
// ----------------------------------------------------------------------------------------
template <
@ -74,6 +84,10 @@ namespace dlib
samples in a class is not an even multiple of folds. This is because each fold has the
same number of test samples in it and so if the number of samples in a class isn't a
multiple of folds then a few are not tested.
throws
- cross_validation_error
This exception is thrown if one of the classes has fewer samples than
the number of requested folds.
!*/
}