mirror of https://github.com/davisking/dlib.git
Added some missing validation of the user supplied number of folds
to the cross_validate_multiclass_trainer() routine. Not it will throw an exception if the number of folds is too big rather than just producing a confusing result.
This commit is contained in:
parent
7c4361534a
commit
2ab4add4a9
|
@ -7,6 +7,7 @@
|
|||
#include "../matrix.h"
|
||||
#include "one_vs_one_trainer.h"
|
||||
#include "cross_validate_multiclass_trainer_abstract.h"
|
||||
#include <sstream>
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
@ -67,6 +68,12 @@ namespace dlib
|
|||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class cross_validation_error : public dlib::error
|
||||
{
|
||||
public:
|
||||
cross_validation_error(const std::string& msg) : dlib::error(msg){};
|
||||
};
|
||||
|
||||
template <
|
||||
typename trainer_type,
|
||||
typename sample_type,
|
||||
|
@ -104,6 +111,15 @@ namespace dlib
|
|||
for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i)
|
||||
{
|
||||
const long in_test = i->second/folds;
|
||||
if (in_test == 0)
|
||||
{
|
||||
std::ostringstream sout;
|
||||
sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl;
|
||||
sout << "than the number of elements of one of the training classes." << std::endl;
|
||||
sout << " folds: "<< folds << std::endl;
|
||||
sout << " size of class " << i->first << ": "<< i->second << std::endl;
|
||||
throw cross_validation_error(sout.str());
|
||||
}
|
||||
num_in_test[i->first] = in_test;
|
||||
num_in_train[i->first] = i->second - in_test;
|
||||
}
|
||||
|
|
|
@ -38,6 +38,16 @@ namespace dlib
|
|||
with labels the decision function hasn't ever seen before are ignored.
|
||||
!*/
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class cross_validation_error : public dlib::error
|
||||
{
|
||||
/*!
|
||||
This is the exception class used by the cross_validate_multiclass_trainer()
|
||||
routine.
|
||||
!*/
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
template <
|
||||
|
@ -74,6 +84,10 @@ namespace dlib
|
|||
samples in a class is not an even multiple of folds. This is because each fold has the
|
||||
same number of test samples in it and so if the number of samples in a class isn't a
|
||||
multiple of folds then a few are not tested.
|
||||
throws
|
||||
- cross_validation_error
|
||||
This exception is thrown if one of the classes has fewer samples than
|
||||
the number of requested folds.
|
||||
!*/
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue