Relaxed the requirements on the feature extractor interface and also

added some tests to make sure the code really does work with the
relaxed interface.
This commit is contained in:
Davis King 2011-12-07 23:03:47 -05:00
parent 6d48b166a2
commit 44c79bcb91
2 changed files with 94 additions and 1 deletions

View File

@ -137,7 +137,10 @@ namespace dlib
set_feature(55,1);
Therefore, the first argument to set_feature is the index of the feature
to be set while the second argument is the value the feature should take.
- This function only calls set_feature() once for each feature index.
Additionally, note that calling set_feature() multiple times with the same
feature index does NOT overwrite the old value, it adds to the previous
value. For example, if you call set_feature(55) 3 times then it will
result in feature 55 having a value of 3.
- This function only calls set_feature() with feature_index values < num_features()
!*/

View File

@ -62,6 +62,47 @@ namespace
}
};
class feature_extractor_partial
{
public:
typedef unsigned long sample_type;
unsigned long num_features() const
{
return num_label_states*num_label_states + num_label_states*num_sample_states;
}
unsigned long order() const
{
return 1;
}
unsigned long num_labels() const
{
return num_label_states;
}
template <typename feature_setter, typename EXP>
void get_features (
feature_setter& set_feature,
const std::vector<sample_type>& x,
const matrix_exp<EXP>& y,
unsigned long position
) const
{
if (y.size() > 1)
{
set_feature(y(1)*num_label_states + y(0), 0.5);
set_feature(y(1)*num_label_states + y(0), 0.5);
}
set_feature(num_label_states*num_label_states +
y(0)*num_sample_states + x[position],0.4);
set_feature(num_label_states*num_label_states +
y(0)*num_sample_states + x[position],0.6);
}
};
bool called_rejct_labeling = false;
class feature_extractor2
{
@ -324,6 +365,53 @@ namespace
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
}
// ----------------------------------------------------------------------------------------
void test2()
{
/*
The point of this test is to make sure calling set_feature() multiple
times works the way it is supposed to.
*/
print_spinner();
std::vector<std::vector<unsigned long> > samples;
std::vector<std::vector<unsigned long> > labels;
matrix<double> transition_probabilities(num_label_states, num_label_states);
transition_probabilities = 0.05, 0.90, 0.05,
0.05, 0.05, 0.90,
0.90, 0.05, 0.05;
matrix<double> emission_probabilities(num_label_states,num_sample_states);
emission_probabilities = 0.5, 0.5, 0.0,
0.0, 0.5, 0.5,
0.5, 0.0, 0.5;
make_dataset(transition_probabilities,emission_probabilities,
samples, labels, 1000);
dlog << LINFO << "samples.size(): "<< samples.size();
structural_sequence_labeling_trainer<feature_extractor> trainer;
structural_sequence_labeling_trainer<feature_extractor_partial> trainer_part;
trainer.set_c(4);
trainer_part.set_c(4);
trainer.set_num_threads(4);
trainer_part.set_num_threads(4);
// Learn to do sequence labeling from the dataset
sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
sequence_labeler<feature_extractor_partial> labeler_part = trainer_part.train(samples, labels);
// Both feature extractors should be equivalent.
DLIB_TEST(length(labeler.get_weights() - labeler_part.get_weights()) < 1e-10);
}
// ----------------------------------------------------------------------------------------
class sequence_labeler_tester : public tester
@ -342,6 +430,8 @@ namespace
DLIB_TEST(called_rejct_labeling == false);
do_test<feature_extractor2>();
DLIB_TEST(called_rejct_labeling == true);
test2();
}
} a;